* The network library guarantees that a single thread will call these methods at a time, but
* different calls may be made by different threads.
diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
similarity index 86%
rename from network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
rename to common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
index 02230a00e69f..b0e85bae7c30 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
@@ -30,13 +30,18 @@
*/
class StreamInterceptor implements TransportFrameDecoder.Interceptor {
+ private final TransportResponseHandler handler;
private final String streamId;
private final long byteCount;
private final StreamCallback callback;
+ private long bytesRead;
- private volatile long bytesRead;
-
- StreamInterceptor(String streamId, long byteCount, StreamCallback callback) {
+ StreamInterceptor(
+ TransportResponseHandler handler,
+ String streamId,
+ long byteCount,
+ StreamCallback callback) {
+ this.handler = handler;
this.streamId = streamId;
this.byteCount = byteCount;
this.callback = callback;
@@ -45,11 +50,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {
@Override
public void exceptionCaught(Throwable cause) throws Exception {
+ handler.deactivateStream();
callback.onFailure(streamId, cause);
}
@Override
public void channelInactive() throws Exception {
+ handler.deactivateStream();
callback.onFailure(streamId, new ClosedChannelException());
}
@@ -65,8 +72,10 @@ public boolean handle(ByteBuf buf) throws Exception {
RuntimeException re = new IllegalStateException(String.format(
"Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
callback.onFailure(streamId, re);
+ handler.deactivateStream();
throw re;
} else if (bytesRead == byteCount) {
+ handler.deactivateStream();
callback.onComplete(streamId);
}
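// A minimal, framework-free sketch of the byte accounting that handle() performs above.
// ByteCountingModel and StreamSink are illustrative names, not Spark classes; the real
// interceptor additionally deactivates the stream on the response handler before signaling
// the callback, as the surrounding diff shows.
class ByteCountingModel {
  interface StreamSink {
    void onComplete(String streamId);
    void onFailure(String streamId, Throwable cause);
  }

  private final String streamId;
  private final long byteCount;
  private final StreamSink sink;
  private long bytesRead;

  ByteCountingModel(String streamId, long byteCount, StreamSink sink) {
    this.streamId = streamId;
    this.byteCount = byteCount;
    this.sink = sink;
  }

  /** Returns true while more frames are expected, loosely mirroring Interceptor.handle(). */
  boolean handle(int frameSize) {
    bytesRead += frameSize;
    if (bytesRead > byteCount) {
      sink.onFailure(streamId, new IllegalStateException(
          String.format("Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)));
      return false;
    } else if (bytesRead == byteCount) {
      sink.onComplete(streamId);
      return false;
    }
    return true; // keep intercepting until the full stream has been consumed
  }
}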
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
new file mode 100644
index 000000000000..8f354ad78bba
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -0,0 +1,322 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.nio.ByteBuffer;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Objects;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.util.concurrent.SettableFuture;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.OneWayMessage;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamRequest;
+import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
+
+/**
+ * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow
+ * efficient transfer of a large amount of data, broken up into chunks with size ranging from
+ * hundreds of KB to a few MB.
+ *
+ * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane),
+ * the actual setup of the streams is done outside the scope of the transport layer. The convenience
+ * method "sendRPC" is provided to enable control plane communication between the client and server
+ * to perform this setup.
+ *
+ * For example, a typical workflow might be:
+ * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100
+ * client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
+ * client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
+ * ...
+ * client.sendRPC(new CloseStream(100))
+ *
+ * Construct an instance of TransportClient using {@link TransportClientFactory}. A single
+ * TransportClient may be used for multiple streams, but any given stream must be restricted to a
+ * single client, in order to avoid out-of-order responses.
+ *
+ * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is
+ * responsible for handling responses from the server.
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportClient implements Closeable {
+ private static final Logger logger = LoggerFactory.getLogger(TransportClient.class);
+
+ private final Channel channel;
+ private final TransportResponseHandler handler;
+ @Nullable private String clientId;
+ private volatile boolean timedOut;
+
+ public TransportClient(Channel channel, TransportResponseHandler handler) {
+ this.channel = Preconditions.checkNotNull(channel);
+ this.handler = Preconditions.checkNotNull(handler);
+ this.timedOut = false;
+ }
+
+ public Channel getChannel() {
+ return channel;
+ }
+
+ public boolean isActive() {
+ return !timedOut && (channel.isOpen() || channel.isActive());
+ }
+
+ public SocketAddress getSocketAddress() {
+ return channel.remoteAddress();
+ }
+
+ /**
+ * Returns the ID used by the client to authenticate itself when authentication is enabled.
+ *
+ * @return The client ID, or null if authentication is disabled.
+ */
+ public String getClientId() {
+ return clientId;
+ }
+
+ /**
+ * Sets the authenticated client ID. This is meant to be used by the authentication layer.
+ *
+ * Trying to set a different client ID after it's been set will result in an exception.
+ */
+ public void setClientId(String id) {
+ Preconditions.checkState(clientId == null, "Client ID has already been set.");
+ this.clientId = id;
+ }
+
+ /**
+ * Requests a single chunk from the remote side, from the pre-negotiated streamId.
+ *
+ * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
+ * some streams may not support this.
+ *
+ * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed
+ * to be returned in the same order that they were requested, assuming only a single
+ * TransportClient is used to fetch the chunks.
+ *
+ * @param streamId Identifier that refers to a stream in the remote StreamManager. This should
+ * be agreed upon by client and server beforehand.
+ * @param chunkIndex 0-based index of the chunk to fetch
+ * @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
+ */
+ public void fetchChunk(
+ long streamId,
+ int chunkIndex,
+ ChunkReceivedCallback callback) {
+ long startTime = System.currentTimeMillis();
+ if (logger.isDebugEnabled()) {
+ logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel));
+ }
+
+ StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
+ handler.addFetchRequest(streamChunkId, callback);
+
+ channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ if (logger.isTraceEnabled()) {
+ logger.trace("Sending request {} to {} took {} ms", streamChunkId,
+ getRemoteAddress(channel), timeTaken);
+ }
+ } else {
+ String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
+ getRemoteAddress(channel), future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeFetchRequest(streamChunkId);
+ channel.close();
+ try {
+ callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
+ }
+ });
+ }
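// A hedged usage sketch for fetchChunk() above, following the workflow outlined in the class
// comment (negotiate a stream via an RPC, then fetch its chunks). Only fetchChunk's signature
// comes from this file; ChunkReceivedCallback is assumed to expose onSuccess(int, ManagedBuffer)
// and onFailure(int, Throwable), matching the onFailure call made in the error branch above.
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.TransportClient;

class ChunkFetchExample {
  static void fetchFirstChunk(TransportClient client, long streamId) {
    client.fetchChunk(streamId, 0, new ChunkReceivedCallback() {
      @Override
      public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
        // Consume or retain the buffer here; it may be released once this callback returns.
        System.out.println("Received chunk " + chunkIndex + " (" + buffer.size() + " bytes)");
      }

      @Override
      public void onFailure(int chunkIndex, Throwable e) {
        System.err.println("Failed to fetch chunk " + chunkIndex + ": " + e);
      }
    });
  }
}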
+
+ /**
+ * Request to stream the data with the given stream ID from the remote end.
+ *
+ * @param streamId The stream to fetch.
+ * @param callback Object to call with the stream data.
+ */
+ public void stream(String streamId, StreamCallback callback) {
+ long startTime = System.currentTimeMillis();
+ if (logger.isDebugEnabled()) {
+ logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel));
+ }
+
+ // Need to synchronize here so that the callback is added to the queue and the RPC is
+ // written to the socket atomically, so that callbacks are called in the right order
+ // when responses arrive.
+ synchronized (this) {
+ handler.addStreamCallback(streamId, callback);
+ channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ if (logger.isTraceEnabled()) {
+ logger.trace("Sending request for {} to {} took {} ms", streamId,
+ getRemoteAddress(channel), timeTaken);
+ }
+ } else {
+ String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId,
+ getRemoteAddress(channel), future.cause());
+ logger.error(errorMsg, future.cause());
+ channel.close();
+ try {
+ callback.onFailure(streamId, new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
+ }
+ });
+ }
+ }
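// A similar hedged sketch for stream() above. StreamCallback is assumed to expose
// onData(String, ByteBuffer), onComplete(String) and onFailure(String, Throwable), the last of
// which matches the failure call in the listener above; the callback body is illustrative only.
import java.nio.ByteBuffer;

import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportClient;

class StreamExample {
  static void readStream(TransportClient client, String streamId) {
    client.stream(streamId, new StreamCallback() {
      @Override
      public void onData(String id, ByteBuffer buf) {
        System.out.println("Got " + buf.remaining() + " bytes for stream " + id);
      }

      @Override
      public void onComplete(String id) {
        System.out.println("Stream " + id + " finished");
      }

      @Override
      public void onFailure(String id, Throwable cause) {
        System.err.println("Stream " + id + " failed: " + cause);
      }
    });
  }
}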
+
+ /**
+ * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
+ * with the server's response or upon any failure.
+ *
+ * @param message The message to send.
+ * @param callback Callback to handle the RPC's reply.
+ * @return The RPC's id.
+ */
+ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
+ long startTime = System.currentTimeMillis();
+ if (logger.isTraceEnabled()) {
+ logger.trace("Sending RPC to {}", getRemoteAddress(channel));
+ }
+
+ long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
+ handler.addRpcRequest(requestId, callback);
+
+ channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)))
+ .addListener(future -> {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ if (logger.isTraceEnabled()) {
+ logger.trace("Sending request {} to {} took {} ms", requestId,
+ getRemoteAddress(channel), timeTaken);
+ }
+ } else {
+ String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
+ getRemoteAddress(channel), future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeRpcRequest(requestId);
+ channel.close();
+ try {
+ callback.onFailure(new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
+ }
+ });
+
+ return requestId;
+ }
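// The synchronous sendRpcSync variant that follows builds on sendRpc(). Below is a hedged sketch
// of the same blocking pattern under stated assumptions: Guava's SettableFuture (already imported
// by this file) and an RpcResponseCallback exposing onSuccess(ByteBuffer)/onFailure(Throwable).
// It is an illustration, not this file's exact code.
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;

import com.google.common.util.concurrent.SettableFuture;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;

class BlockingRpcExample {
  static ByteBuffer sendAndWait(TransportClient client, ByteBuffer message, long timeoutMs)
      throws Exception {
    final SettableFuture<ByteBuffer> result = SettableFuture.create();
    client.sendRpc(message, new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        // Copy the reply, since the underlying buffer may be reused after this callback returns.
        ByteBuffer copy = ByteBuffer.allocate(response.remaining());
        copy.put(response);
        copy.flip();
        result.set(copy);
      }

      @Override
      public void onFailure(Throwable e) {
        result.setException(e);
      }
    });
    return result.get(timeoutMs, TimeUnit.MILLISECONDS);
  }
}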
+
+ /**
+ * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
+ * a specified timeout for a response.
+ */
+ public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
+ final SettableFuture
+ * Like {@code java.util.Optional} in Java 8, {@code scala.Option} in Scala, and
+ * {@code com.google.common.base.Optional} in Google Guava, this class represents a
+ * value of a given type that may or may not exist. It is used in methods that wish
+ * to optionally return a value, in preference to returning {@code null}.
+ *
+ * In fact, the class here is a reimplementation of the essential API of both
+ * {@code java.util.Optional} and {@code com.google.common.base.Optional}. From
+ * {@code java.util.Optional}, it implements:
+ *
+ * From {@code com.google.common.base.Optional} it implements:
+ *
+ * {@code java.util.Optional} itself was not used because at the time, the
+ * project did not require Java 8. Using {@code com.google.common.base.Optional}
+ * has in the past caused serious library version conflicts with Guava that can't
+ * be resolved by shading. Hence this work-alike clone.
+ *
+ * @param
* This write path is inefficient for shuffles with large numbers of reduce partitions because it
@@ -61,7 +60,7 @@
* {@link SortShuffleManager} only selects this write path when
* spark.shuffle.sort.bypassMergeThreshold.
@@ -28,7 +26,7 @@
*
* This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that
* our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
- * 13-bit page numbers assigned by {@link TaskMemoryManager}), this
+ * 13-bit page numbers assigned by {@link org.apache.spark.memory.TaskMemoryManager}), this
* implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
*
* Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
@@ -44,6 +42,16 @@ final class PackedRecordPointer {
*/
static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
+ /**
+ * The index of the first byte of the partition id, counting from the least significant byte.
+ */
+ static final int PARTITION_ID_START_BYTE_INDEX = 5;
+
+ /**
+ * The index of the last byte of the partition id, counting from the least significant byte.
+ */
+ static final int PARTITION_ID_END_BYTE_INDEX = 7;
+
/** Bit mask for the lower 40 bits of a long. */
private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
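// A self-contained sketch of the packing scheme the constants above describe: the 24-bit
// partition id occupies bytes 5-7 (bits 40-63) of the long, while the 13-bit page number and
// 27-bit in-page offset occupy the lower 40 bits, presumably so the radix sort added in this
// patch can order records by partition by sorting only those three bytes. pack() and
// partitionId() are illustrative stand-ins, not the class's real API.
class PackedPointerExample {
  private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;

  static long pack(long compressedRecordPointer, int partitionId) {
    // compressedRecordPointer already holds (13-bit page number | 27-bit in-page offset).
    return (((long) partitionId) << 40) | (compressedRecordPointer & MASK_LONG_LOWER_40_BITS);
  }

  static int partitionId(long packed) {
    return (int) (packed >>> 40); // bytes 5-7 of the long
  }

  public static void main(String[] args) {
    long recordPointer = (5L << 27) | 1234L;               // page 5, offset 1234
    long packed = pack(recordPointer, 42);
    System.out.println(partitionId(packed));               // 42
    System.out.println(packed & MASK_LONG_LOWER_40_BITS);  // the original 40-bit pointer
  }
}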
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 400d8520019b..c33d1e33f030 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -37,8 +37,10 @@
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
+import org.apache.spark.storage.FileSegment;
import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.Utils;
@@ -60,7 +62,7 @@
*/
final class ShuffleExternalSorter extends MemoryConsumer {
- private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
+ private static final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
@VisibleForTesting
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
@@ -71,7 +73,10 @@ final class ShuffleExternalSorter extends MemoryConsumer {
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
- /** Force this sorter to spill when there are this many elements in memory. For testing only */
+ /**
+ * Force this sorter to spill when there are this many elements in memory. The default value is
+ * 1024 * 1024 * 1024, which allows the maximum size of the pointer array to be 8G.
+ */
private final long numElementsForSpillThreshold;
/** The buffer size to use when writing spills using DiskBlockObjectWriter */
@@ -83,9 +88,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
* this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
* itself).
*/
- private final LinkedList
* It is only valid to call this method immediately after calling `lookup()` using the same key.
*
* After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
- * will return information on the data stored by this `putNewKey` call.
+ * will return information on the data stored by this `append` call.
*
* As an example usage, here's the proper way to store a new key:
@@ -620,7 +678,7 @@ public int getValueLength() {
*
* Location loc = map.lookup(keyBase, keyOffset, keyLength);
* if (!loc.isDefined()) {
- * if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
+ * if (!loc.append(keyBase, keyOffset, keyLength, ...)) {
* // handle failure to grow map (by spilling, for example)
* }
* }
@@ -632,28 +690,26 @@ public int getValueLength() {
* @return true if the put() was successful and false if the put() failed because memory could
* not be acquired.
*/
- public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
- Object valueBase, long valueOffset, int valueLength) {
- assert (!isDefined) : "Can only set value once for a key";
- assert (keyLength % 8 == 0);
- assert (valueLength % 8 == 0);
- assert(longArray != null);
-
+ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) {
+ assert (klen % 8 == 0);
+ assert (vlen % 8 == 0);
+ assert (longArray != null);
- if (numElements == MAX_CAPACITY
+ if (numKeys == MAX_CAPACITY
// The map could be reused from the last spill (because there was not enough memory to grow),
// then we don't try to grow again if hit the `growthThreshold`.
- || !canGrowArray && numElements > growthThreshold) {
+ || !canGrowArray && numKeys >= growthThreshold) {
return false;
}
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
- // (8 byte key length) (key) (value)
- final long recordLength = 8 + keyLength + valueLength;
+ // (8 byte key length) (key) (value) (8 byte pointer to next value)
+ int uaoSize = UnsafeAlignedOffset.getUaoSize();
+ final long recordLength = (2 * uaoSize) + klen + vlen + 8;
if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
- if (!acquireNewPage(recordLength + 4L)) {
+ if (!acquireNewPage(recordLength + uaoSize)) {
return false;
}
}
@@ -662,30 +718,36 @@ public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
final Object base = currentPage.getBaseObject();
long offset = currentPage.getBaseOffset() + pageCursor;
final long recordOffset = offset;
- Platform.putInt(base, offset, keyLength + valueLength + 4);
- Platform.putInt(base, offset + 4, keyLength);
- offset += 8;
- Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
- offset += keyLength;
- Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
-
- // --- Update bookkeeping data structures -----------------------------------------------------
+ UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize);
+ UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen);
+ offset += (2 * uaoSize);
+ Platform.copyMemory(kbase, koff, base, offset, klen);
+ offset += klen;
+ Platform.copyMemory(vbase, voff, base, offset, vlen);
+ offset += vlen;
+ // put this value at the beginning of the list
+ Platform.putLong(base, offset, isDefined ? longArray.get(pos * 2) : 0);
+
+ // --- Update bookkeeping data structures ----------------------------------------------------
offset = currentPage.getBaseOffset();
- Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
+ UnsafeAlignedOffset.putSize(base, offset, UnsafeAlignedOffset.getSize(base, offset) + 1);
pageCursor += recordLength;
- numElements++;
final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
currentPage, recordOffset);
longArray.set(pos * 2, storedKeyAddress);
- longArray.set(pos * 2 + 1, keyHashcode);
updateAddressesAndSizes(storedKeyAddress);
- isDefined = true;
+ numValues++;
+ if (!isDefined) {
+ numKeys++;
+ longArray.set(pos * 2 + 1, keyHashcode);
+ isDefined = true;
- if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
- try {
- growAndRehash();
- } catch (OutOfMemoryError oom) {
- canGrowArray = false;
+ if (numKeys >= growthThreshold && longArray.size() < MAX_CAPACITY) {
+ try {
+ growAndRehash();
+ } catch (OutOfMemoryError oom) {
+ canGrowArray = false;
+ }
}
}
return true;
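// A worked example of the record layout written by append() above, assuming the compact 4-byte
// size prefix (i.e. UnsafeAlignedOffset.getUaoSize() == 4). The figures are illustrative; the
// final 8 bytes hold the address of the record previously stored for the same key (0 if none),
// which is how multiple values per key are chained into a list.
class AppendLayoutExample {
  public static void main(String[] args) {
    int uaoSize = 4;   // assumed size-prefix width
    int klen = 16;     // key bytes (multiple of 8)
    int vlen = 24;     // value bytes (multiple of 8)

    long recordLength = (2L * uaoSize) + klen + vlen + 8;
    System.out.println("record length = " + recordLength);   // 56 bytes reserved in the page

    // Offsets relative to the start of the record:
    long totalLenAt = 0;              // stores klen + vlen + uaoSize (44 here)
    long keyLenAt   = uaoSize;        // stores klen
    long keyAt      = 2L * uaoSize;   // key bytes
    long valueAt    = keyAt + klen;   // value bytes
    long nextPtrAt  = valueAt + vlen; // 8-byte pointer to the previous value for this key
    System.out.printf("total@%d keyLen@%d key@%d value@%d next@%d%n",
        totalLenAt, keyLenAt, keyAt, valueAt, nextPtrAt);
  }
}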
@@ -703,8 +765,8 @@ private boolean acquireNewPage(long required) {
return false;
}
dataPages.add(currentPage);
- Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0);
- pageCursor = 4;
+ UnsafeAlignedOffset.putSize(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0);
+ pageCursor = UnsafeAlignedOffset.getUaoSize();
return true;
}
@@ -724,11 +786,10 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
*/
private void allocate(int capacity) {
assert (capacity >= 0);
- // The capacity needs to be divisible by 64 so that our bit set can be sized properly
capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
assert (capacity <= MAX_CAPACITY);
- acquireMemory(capacity * 16);
- longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
+ longArray = allocateArray(capacity * 2);
+ longArray.zeroOut();
this.growthThreshold = (int) (capacity * loadFactor);
this.mask = capacity - 1;
@@ -743,9 +804,8 @@ private void allocate(int capacity) {
public void free() {
updatePeakMemoryUsed();
if (longArray != null) {
- long used = longArray.memoryBlock().size();
+ freeArray(longArray);
longArray = null;
- releaseMemory(used);
}
Iterator dataPagesIterator = dataPages.iterator();
while (dataPagesIterator.hasNext()) {
@@ -834,21 +894,24 @@ public int getNumDataPages() {
/**
* Returns the underlying long[] of longArray.
*/
- public long[] getArray() {
+ public LongArray getArray() {
assert(longArray != null);
- return (long[]) longArray.memoryBlock().getBaseObject();
+ return longArray;
}
/**
* Reset this map to initialized state.
*/
public void reset() {
- numElements = 0;
- Arrays.fill(getArray(), 0);
+ numKeys = 0;
+ numValues = 0;
+ freeArray(longArray);
while (dataPages.size() > 0) {
MemoryBlock dataPage = dataPages.removeLast();
freePage(dataPage);
}
+ allocate(initialCapacity);
+ canGrowArray = true;
currentPage = null;
pageCursor = 0;
}
@@ -887,7 +950,7 @@ void growAndRehash() {
longArray.set(newPos * 2, keyPointer);
longArray.set(newPos * 2 + 1, hashcode);
}
- releaseMemory(oldLongArray.memoryBlock().size());
+ freeArray(oldLongArray);
if (enablePerfMetrics) {
timeSpentResizingNs += System.nanoTime() - resizeStartTime;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index d2bf297c6c17..0910db22af00 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -20,97 +20,153 @@
import com.google.common.primitives.UnsignedLongs;
import org.apache.spark.annotation.Private;
-import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.ByteArray;
import org.apache.spark.unsafe.types.UTF8String;
-import org.apache.spark.util.Utils;
@Private
public class PrefixComparators {
private PrefixComparators() {}
- public static final StringPrefixComparator STRING = new StringPrefixComparator();
- public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
- public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator();
- public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc();
- public static final LongPrefixComparator LONG = new LongPrefixComparator();
- public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
- public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
- public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc();
-
- public static final class StringPrefixComparator extends PrefixComparator {
- @Override
- public int compare(long aPrefix, long bPrefix) {
- return UnsignedLongs.compare(aPrefix, bPrefix);
- }
-
+ public static final PrefixComparator STRING = new UnsignedPrefixComparator();
+ public static final PrefixComparator STRING_DESC = new UnsignedPrefixComparatorDesc();
+ public static final PrefixComparator STRING_NULLS_LAST = new UnsignedPrefixComparatorNullsLast();
+ public static final PrefixComparator STRING_DESC_NULLS_FIRST =
+ new UnsignedPrefixComparatorDescNullsFirst();
+
+ public static final PrefixComparator BINARY = new UnsignedPrefixComparator();
+ public static final PrefixComparator BINARY_DESC = new UnsignedPrefixComparatorDesc();
+ public static final PrefixComparator BINARY_NULLS_LAST = new UnsignedPrefixComparatorNullsLast();
+ public static final PrefixComparator BINARY_DESC_NULLS_FIRST =
+ new UnsignedPrefixComparatorDescNullsFirst();
+
+ public static final PrefixComparator LONG = new SignedPrefixComparator();
+ public static final PrefixComparator LONG_DESC = new SignedPrefixComparatorDesc();
+ public static final PrefixComparator LONG_NULLS_LAST = new SignedPrefixComparatorNullsLast();
+ public static final PrefixComparator LONG_DESC_NULLS_FIRST =
+ new SignedPrefixComparatorDescNullsFirst();
+
+ public static final PrefixComparator DOUBLE = new UnsignedPrefixComparator();
+ public static final PrefixComparator DOUBLE_DESC = new UnsignedPrefixComparatorDesc();
+ public static final PrefixComparator DOUBLE_NULLS_LAST = new UnsignedPrefixComparatorNullsLast();
+ public static final PrefixComparator DOUBLE_DESC_NULLS_FIRST =
+ new UnsignedPrefixComparatorDescNullsFirst();
+
+ public static final class StringPrefixComparator {
public static long computePrefix(UTF8String value) {
return value == null ? 0L : value.getPrefix();
}
}
- public static final class StringPrefixComparatorDesc extends PrefixComparator {
- @Override
- public int compare(long bPrefix, long aPrefix) {
+ public static final class BinaryPrefixComparator {
+ public static long computePrefix(byte[] bytes) {
+ return ByteArray.getPrefix(bytes);
+ }
+ }
+
+ public static final class DoublePrefixComparator {
+ /**
+ * Converts the double into a value that compares correctly as an unsigned long. For more
+ * details see http://stereopsis.com/radix.html.
+ */
+ public static long computePrefix(double value) {
+ // Java's doubleToLongBits already canonicalizes all NaN values to the smallest possible
+ // positive NaN, so there's nothing special we need to do for NaNs.
+ long bits = Double.doubleToLongBits(value);
+ // Negative floats compare backwards due to their sign-magnitude representation, so flip
+ // all the bits in this case.
+ long mask = -(bits >>> 63) | 0x8000000000000000L;
+ return bits ^ mask;
+ }
+ }
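// A small, self-contained check of the bit-flipping trick described above: after computePrefix,
// comparing the prefixes as unsigned longs orders the doubles numerically, negatives included.
// computePrefix is duplicated here so the demo runs on its own.
class DoublePrefixDemo {
  static long computePrefix(double value) {
    long bits = Double.doubleToLongBits(value);
    long mask = -(bits >>> 63) | 0x8000000000000000L;
    return bits ^ mask;
  }

  public static void main(String[] args) {
    double[] values = {-3.5, -0.0, 0.0, 1.0, 2.5};
    for (int i = 1; i < values.length; i++) {
      long prev = computePrefix(values[i - 1]);
      long cur = computePrefix(values[i]);
      // Long.compareUnsigned matches the UnsignedLongs.compare used by UnsignedPrefixComparator.
      System.out.println(values[i - 1] + " <= " + values[i] + " : "
          + (Long.compareUnsigned(prev, cur) <= 0));   // prints true for each adjacent pair
    }
  }
}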
+
+ /**
+ * Provides radix sort parameters. Comparators implementing this also are indicating that the
+ * ordering they define is compatible with radix sort.
+ */
+ public abstract static class RadixSortSupport extends PrefixComparator {
+ /** @return Whether the sort should be descending in binary sort order. */
+ public abstract boolean sortDescending();
+
+ /** @return Whether the sort should take into account the sign bit. */
+ public abstract boolean sortSigned();
+
+ /** @return Whether the sort should put nulls first or last. */
+ public abstract boolean nullsFirst();
+ }
+
+ //
+ // Standard prefix comparator implementations
+ //
+
+ public static final class UnsignedPrefixComparator extends RadixSortSupport {
+ @Override public boolean sortDescending() { return false; }
+ @Override public boolean sortSigned() { return false; }
+ @Override public boolean nullsFirst() { return true; }
+ public int compare(long aPrefix, long bPrefix) {
return UnsignedLongs.compare(aPrefix, bPrefix);
}
}
- public static final class BinaryPrefixComparator extends PrefixComparator {
- @Override
+ public static final class UnsignedPrefixComparatorNullsLast extends RadixSortSupport {
+ @Override public boolean sortDescending() { return false; }
+ @Override public boolean sortSigned() { return false; }
+ @Override public boolean nullsFirst() { return false; }
public int compare(long aPrefix, long bPrefix) {
return UnsignedLongs.compare(aPrefix, bPrefix);
}
+ }
- public static long computePrefix(byte[] bytes) {
- return ByteArray.getPrefix(bytes);
+ public static final class UnsignedPrefixComparatorDescNullsFirst extends RadixSortSupport {
+ @Override public boolean sortDescending() { return true; }
+ @Override public boolean sortSigned() { return false; }
+ @Override public boolean nullsFirst() { return true; }
+ public int compare(long bPrefix, long aPrefix) {
+ return UnsignedLongs.compare(aPrefix, bPrefix);
}
}
- public static final class BinaryPrefixComparatorDesc extends PrefixComparator {
- @Override
+ public static final class UnsignedPrefixComparatorDesc extends RadixSortSupport {
+ @Override public boolean sortDescending() { return true; }
+ @Override public boolean sortSigned() { return false; }
+ @Override public boolean nullsFirst() { return false; }
public int compare(long bPrefix, long aPrefix) {
return UnsignedLongs.compare(aPrefix, bPrefix);
}
}
- public static final class LongPrefixComparator extends PrefixComparator {
- @Override
+ public static final class SignedPrefixComparator extends RadixSortSupport {
+ @Override public boolean sortDescending() { return false; }
+ @Override public boolean sortSigned() { return true; }
+ @Override public boolean nullsFirst() { return true; }
public int compare(long a, long b) {
return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}
- public static final class LongPrefixComparatorDesc extends PrefixComparator {
- @Override
- public int compare(long b, long a) {
+ public static final class SignedPrefixComparatorNullsLast extends RadixSortSupport {
+ @Override public boolean sortDescending() { return false; }
+ @Override public boolean sortSigned() { return true; }
+ @Override public boolean nullsFirst() { return false; }
+ public int compare(long a, long b) {
return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}
- public static final class DoublePrefixComparator extends PrefixComparator {
- @Override
- public int compare(long aPrefix, long bPrefix) {
- double a = Double.longBitsToDouble(aPrefix);
- double b = Double.longBitsToDouble(bPrefix);
- return Utils.nanSafeCompareDoubles(a, b);
- }
-
- public static long computePrefix(double value) {
- return Double.doubleToLongBits(value);
+ public static final class SignedPrefixComparatorDescNullsFirst extends RadixSortSupport {
+ @Override public boolean sortDescending() { return true; }
+ @Override public boolean sortSigned() { return true; }
+ @Override public boolean nullsFirst() { return true; }
+ public int compare(long b, long a) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}
- public static final class DoublePrefixComparatorDesc extends PrefixComparator {
- @Override
- public int compare(long bPrefix, long aPrefix) {
- double a = Double.longBitsToDouble(aPrefix);
- double b = Double.longBitsToDouble(bPrefix);
- return Utils.nanSafeCompareDoubles(a, b);
- }
-
- public static long computePrefix(double value) {
- return Double.doubleToLongBits(value);
+ public static final class SignedPrefixComparatorDesc extends RadixSortSupport {
+ @Override public boolean sortDescending() { return true; }
+ @Override public boolean sortSigned() { return true; }
+ @Override public boolean nullsFirst() { return false; }
+ public int compare(long b, long a) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
new file mode 100644
index 000000000000..3dd318471008
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import com.google.common.primitives.Ints;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+
+public class RadixSort {
+
+ /**
+ * Sorts a given array of longs using least-significant-digit radix sort. This routine assumes
+ * you have extra space at the end of the array at least equal to the number of records. The
+ * sort is destructive and may relocate the data positioned within the array.
+ *
+ * @param array array of long elements followed by at least that many empty slots.
+ * @param numRecords number of data records in the array.
+ * @param startByteIndex the first byte (in range [0, 7]) to sort each long by, counting from the
+ * least significant byte.
+ * @param endByteIndex the last byte (in range [0, 7]) to sort each long by, counting from the
+ * least significant byte. Must be greater than startByteIndex.
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort.
+ *
+ * @return The starting index of the sorted data within the given array. We return this instead
+ * of always copying the data back to position zero for efficiency.
+ */
+ public static int sort(
+ LongArray array, long numRecords, int startByteIndex, int endByteIndex,
+ boolean desc, boolean signed) {
+ assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
+ assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
+ assert endByteIndex > startByteIndex;
+ assert numRecords * 2 <= array.size();
+ long inIndex = 0;
+ long outIndex = numRecords;
+ if (numRecords > 0) {
+ long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (counts[i] != null) {
+ sortAtByte(
+ array, numRecords, counts[i], i, inIndex, outIndex,
+ desc, signed && i == endByteIndex);
+ long tmp = inIndex;
+ inIndex = outIndex;
+ outIndex = tmp;
+ }
+ }
+ }
+ return Ints.checkedCast(inIndex);
+ }
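// A hedged usage sketch for sort() above. LongArray and MemoryBlock come from Spark's unsafe
// module; the values are arbitrary. The backing array is sized at twice the record count to
// provide the scratch space the routine requires, and the returned index tells the caller where
// the sorted run begins (it may be in the scratch half).
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;

class RadixSortExample {
  public static void main(String[] args) {
    long[] backing = new long[8];                 // 4 records + 4 empty scratch slots
    backing[0] = 30; backing[1] = 5; backing[2] = 17; backing[3] = 42;
    LongArray array = new LongArray(MemoryBlock.fromLongArray(backing));

    // Unsigned, ascending, over all eight bytes of each long.
    int start = RadixSort.sort(array, 4, 0, 7, false, false);
    for (int i = start; i < start + 4; i++) {
      System.out.println(array.get(i));           // 5, 17, 30, 42
    }
  }
}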
+
+ /**
+ * Performs a partial sort by copying data into destination offsets for each byte value at the
+ * specified byte offset.
+ *
+ * @param array array to partially sort.
+ * @param numRecords number of data records in the array.
+ * @param counts counts for each byte value. This routine destructively modifies this array.
+ * @param byteIdx the byte in a long to sort at, counting from the least significant byte.
+ * @param inIndex the starting index in the array where input data is located.
+ * @param outIndex the starting index where sorted output data should be written.
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort (only applies to last byte).
+ */
+ private static void sortAtByte(
+ LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ long[] offsets = transformCountsToOffsets(
+ counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed);
+ Object baseObject = array.getBaseObject();
+ long baseOffset = array.getBaseOffset() + inIndex * 8L;
+ long maxOffset = baseOffset + numRecords * 8L;
+ for (long offset = baseOffset; offset < maxOffset; offset += 8) {
+ long value = Platform.getLong(baseObject, offset);
+ int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
+ Platform.putLong(baseObject, offsets[bucket], value);
+ offsets[bucket] += 8;
+ }
+ }
+
+ /**
+ * Computes a value histogram for each byte in the given array.
+ *
+ * @param array array to count records in.
+ * @param numRecords number of data records in the array.
+ * @param startByteIndex the first byte to compute counts for (the prior are skipped).
+ * @param endByteIndex the last byte to compute counts for.
+ *
+ * @return an array of eight 256-byte count arrays, one for each byte starting from the least
+ * significant byte. If the byte does not need sorting the array will be null.
+ */
+ private static long[][] getCounts(
+ LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
+ long[][] counts = new long[8][];
+ // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
+ // If all the byte values at a particular index are the same we don't need to count it.
+ long bitwiseMax = 0;
+ long bitwiseMin = -1L;
+ long maxOffset = array.getBaseOffset() + numRecords * 8L;
+ Object baseObject = array.getBaseObject();
+ for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
+ long value = Platform.getLong(baseObject, offset);
+ bitwiseMax |= value;
+ bitwiseMin &= value;
+ }
+ long bitsChanged = bitwiseMin ^ bitwiseMax;
+ // Compute counts for each byte index.
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
+ counts[i] = new long[256];
+ // TODO(ekl) consider computing all the counts in one pass.
+ for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
+ counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++;
+ }
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Transforms counts into the proper unsafe output offsets for the sort type.
+ *
+ * @param counts counts for each byte value. This routine destructively modifies this array.
+ * @param numRecords number of data records in the original data array.
+ * @param outputOffset output offset in bytes from the base array object.
+ * @param bytesPerRecord size of each record (8 for plain sort, 16 for key-prefix sort).
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort.
+ *
+ * @return the input counts array.
+ */
+ private static long[] transformCountsToOffsets(
+ long[] counts, long numRecords, long outputOffset, long bytesPerRecord,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ int start = signed ? 128 : 0; // output the negative records first (values 129-255).
+ if (desc) {
+ long pos = numRecords;
+ for (int i = start; i < start + 256; i++) {
+ pos -= counts[i & 0xff];
+ counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
+ }
+ } else {
+ long pos = 0;
+ for (int i = start; i < start + 256; i++) {
+ long tmp = counts[i & 0xff];
+ counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
+ pos += tmp;
+ }
+ }
+ return counts;
+ }
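// A tiny worked example of the ascending branch above (desc = false, signed = false): the
// offsets become an exclusive prefix sum of the counts, scaled by the record width and shifted
// by the destination's base offset. The numbers are illustrative only.
class CountsToOffsetsExample {
  public static void main(String[] args) {
    long[] counts = new long[256];
    counts[0x01] = 2;    // two records whose sort byte is 0x01
    counts[0x05] = 1;    // one record whose sort byte is 0x05
    counts[0xff] = 3;    // three records whose sort byte is 0xff

    long outputOffset = 1024;   // arbitrary base offset of the scratch region, in bytes
    long bytesPerRecord = 8;    // plain long sort (16 for key-prefix sort)

    long pos = 0;
    for (int i = 0; i < 256; i++) {
      long tmp = counts[i];
      counts[i] = outputOffset + pos * bytesPerRecord;
      pos += tmp;
    }
    // Byte 0x01 records are written from 1024, byte 0x05 from 1040, byte 0xff from 1048.
    System.out.println(counts[0x01] + " " + counts[0x05] + " " + counts[0xff]);
  }
}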
+
+ /**
+ * Specialization of sort() for key-prefix arrays. In this type of array, each record consists
+ * of two longs, only the second of which is sorted on.
+ *
+ * @param startIndex starting index in the array to sort from. This parameter is not supported
+ * in the plain sort() implementation.
+ */
+ public static int sortKeyPrefixArray(
+ LongArray array,
+ long startIndex,
+ long numRecords,
+ int startByteIndex,
+ int endByteIndex,
+ boolean desc,
+ boolean signed) {
+ assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
+ assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
+ assert endByteIndex > startByteIndex;
+ assert numRecords * 4 <= array.size();
+ long inIndex = startIndex;
+ long outIndex = startIndex + numRecords * 2L;
+ if (numRecords > 0) {
+ long[][] counts = getKeyPrefixArrayCounts(
+ array, startIndex, numRecords, startByteIndex, endByteIndex);
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (counts[i] != null) {
+ sortKeyPrefixArrayAtByte(
+ array, numRecords, counts[i], i, inIndex, outIndex,
+ desc, signed && i == endByteIndex);
+ long tmp = inIndex;
+ inIndex = outIndex;
+ outIndex = tmp;
+ }
+ }
+ }
+ return Ints.checkedCast(inIndex);
+ }
+
+ /**
+ * Specialization of getCounts() for key-prefix arrays. We could probably combine this with
+ * getCounts with some added parameters but that seems to hurt in benchmarks.
+ */
+ private static long[][] getKeyPrefixArrayCounts(
+ LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) {
+ long[][] counts = new long[8][];
+ long bitwiseMax = 0;
+ long bitwiseMin = -1L;
+ long baseOffset = array.getBaseOffset() + startIndex * 8L;
+ long limit = baseOffset + numRecords * 16L;
+ Object baseObject = array.getBaseObject();
+ for (long offset = baseOffset; offset < limit; offset += 16) {
+ long value = Platform.getLong(baseObject, offset + 8);
+ bitwiseMax |= value;
+ bitwiseMin &= value;
+ }
+ long bitsChanged = bitwiseMin ^ bitwiseMax;
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
+ counts[i] = new long[256];
+ for (long offset = baseOffset; offset < limit; offset += 16) {
+ counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
+ }
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Specialization of sortAtByte() for key-prefix arrays.
+ */
+ private static void sortKeyPrefixArrayAtByte(
+ LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ long[] offsets = transformCountsToOffsets(
+ counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed);
+ Object baseObject = array.getBaseObject();
+ long baseOffset = array.getBaseOffset() + inIndex * 8L;
+ long maxOffset = baseOffset + numRecords * 16L;
+ for (long offset = baseOffset; offset < maxOffset; offset += 16) {
+ long key = Platform.getLong(baseObject, offset);
+ long prefix = Platform.getLong(baseObject, offset + 8);
+ int bucket = (int)((prefix >>> (byteIdx * 8)) & 0xff);
+ long dest = offsets[bucket];
+ Platform.putLong(baseObject, dest, key);
+ Platform.putLong(baseObject, dest + 8, prefix);
+ offsets[bucket] += 16;
+ }
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
index 09e425879220..02b5de8e128c 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
@@ -32,6 +32,8 @@ public abstract class RecordComparator {
public abstract int compare(
Object leftBaseObject,
long leftBaseOffset,
+ int leftBaseLength,
Object rightBaseObject,
- long rightBaseOffset);
+ long rightBaseOffset,
+ int rightBaseLength);
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
index dbf6770e0739..e9571aa8bb05 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
@@ -17,11 +17,9 @@
package org.apache.spark.util.collection.unsafe.sort;
-import org.apache.spark.memory.TaskMemoryManager;
-
-final class RecordPointerAndKeyPrefix {
+public final class RecordPointerAndKeyPrefix {
/**
- * A pointer to a record; see {@link TaskMemoryManager} for a
+ * A pointer to a record; see {@link org.apache.spark.memory.TaskMemoryManager} for a
* description of how these addresses are encoded.
*/
public long recordPointer;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index cba043bc48cc..e1ba58871bbe 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -21,6 +21,7 @@
import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
+import java.util.Queue;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
@@ -30,10 +31,12 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.util.TaskCompletionListener;
import org.apache.spark.util.Utils;
/**
@@ -41,18 +44,28 @@
*/
public final class UnsafeExternalSorter extends MemoryConsumer {
- private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
+ private static final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
+ @Nullable
private final PrefixComparator prefixComparator;
+ @Nullable
private final RecordComparator recordComparator;
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
+ private final SerializerManager serializerManager;
private final TaskContext taskContext;
private ShuffleWriteMetrics writeMetrics;
/** The buffer size to use when writing spills using DiskBlockObjectWriter */
private final int fileBufferSizeBytes;
+ /**
+ * Force this sorter to spill when there are this many elements in memory. The default value is
+ * 1024 * 1024 * 1024 / 2 which allows the maximum size of the pointer array to be 8G.
+ */
+ public static final long DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD = 1024 * 1024 * 1024 / 2;
+
+ private final long numElementsForSpillThreshold;
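// The arithmetic behind the comment above, assuming each entry in the in-memory sorter's pointer
// array holds a record pointer plus its key prefix (two longs, 16 bytes, as the key-prefix radix
// sort in this patch also assumes): the default threshold of 1024 * 1024 * 1024 / 2 records caps
// the array at roughly 8 GiB before a spill is forced.
class SpillThresholdMath {
  public static void main(String[] args) {
    long defaultThreshold = 1024L * 1024 * 1024 / 2;        // 536,870,912 records
    long bytesPerEntry = 2L * Long.BYTES;                   // pointer + prefix
    System.out.println(defaultThreshold * bytesPerEntry);   // 8589934592 bytes = 8 GiB
  }
}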
/**
* Memory pages that hold the records being sorted. The pages in this list are freed when
* spilling, although in principle we could recycle these pages across spills (on the other hand,
@@ -69,19 +82,24 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
private MemoryBlock currentPage = null;
private long pageCursor = -1;
private long peakMemoryUsedBytes = 0;
+ private long totalSpillBytes = 0L;
+ private long totalSortTimeNanos = 0L;
private volatile SpillableIterator readingIterator = null;
public static UnsafeExternalSorter createWithExistingInMemorySorter(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
+ SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
+ long numElementsForSpillThreshold,
UnsafeInMemorySorter inMemorySorter) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
- taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
+ serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
+ numElementsForSpillThreshold, pageSizeBytes, inMemorySorter, false /* ignored */);
sorter.spill(Long.MAX_VALUE, sorter);
// The external sorter will be used to insert records, in-memory sorter is not needed.
sorter.inMemSorter = null;
@@ -91,57 +109,61 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
public static UnsafeExternalSorter create(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
+ SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
- long pageSizeBytes) {
- return new UnsafeExternalSorter(taskMemoryManager, blockManager,
- taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
+ long pageSizeBytes,
+ long numElementsForSpillThreshold,
+ boolean canUseRadixSort) {
+ return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
+ taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes,
+ numElementsForSpillThreshold, null, canUseRadixSort);
}
private UnsafeExternalSorter(
TaskMemoryManager taskMemoryManager,
BlockManager blockManager,
+ SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
- @Nullable UnsafeInMemorySorter existingInMemorySorter) {
- super(taskMemoryManager, pageSizeBytes);
+ long numElementsForSpillThreshold,
+ @Nullable UnsafeInMemorySorter existingInMemorySorter,
+ boolean canUseRadixSort) {
+ super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode());
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
+ this.serializerManager = serializerManager;
this.taskContext = taskContext;
this.recordComparator = recordComparator;
this.prefixComparator = prefixComparator;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
- // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024
this.fileBufferSizeBytes = 32 * 1024;
- // TODO: metrics tracking + integration with shuffle write metrics
- // need to connect the write metrics to task metrics so we count the spill IO somewhere.
+ // The spill metrics are stored in a new ShuffleWriteMetrics,
+ // and then discarded (this fixes SPARK-16827).
+ // TODO: Instead, separate spill metrics should be stored and reported (tracked in SPARK-3577).
this.writeMetrics = new ShuffleWriteMetrics();
if (existingInMemorySorter == null) {
- this.inMemSorter =
- new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
- acquireMemory(inMemSorter.getMemoryUsage());
+ this.inMemSorter = new UnsafeInMemorySorter(
+ this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort);
} else {
this.inMemSorter = existingInMemorySorter;
}
this.peakMemoryUsedBytes = getMemoryUsage();
+ this.numElementsForSpillThreshold = numElementsForSpillThreshold;
// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
// the end of the task. This is necessary to avoid memory leaks when the downstream operator
// does not fully consume the sorter's output (e.g. sort followed by limit).
- taskContext.addTaskCompletionListener(
- new TaskCompletionListener() {
- @Override
- public void onTaskCompletion(TaskContext context) {
- cleanupResources();
- }
- }
- );
+ taskContext.addTaskCompletionListener(context -> {
+ cleanupResources();
+ });
}
/**
@@ -192,16 +214,19 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
}
spillWriter.close();
-
- inMemSorter.reset();
}
final long spillSize = freeMemory();
// Note that this is more-or-less going to be a multiple of the page size, so wasted space in
// pages will currently be counted as memory spilled even though that space isn't actually
// written to disk. This also counts the space needed to store the sorter's pointer array.
- taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+ inMemSorter.reset();
+ // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
+ // records. Otherwise, if the task is over allocated memory, then without freeing the memory
+ // pages, we might not be able to get memory for the pointer array.
+ taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+ totalSpillBytes += spillSize;
return spillSize;
}
@@ -232,6 +257,24 @@ public long getPeakMemoryUsedBytes() {
return peakMemoryUsedBytes;
}
+ /**
+ * @return the total amount of time spent sorting data (in-memory only).
+ */
+ public long getSortTimeNanos() {
+ UnsafeInMemorySorter sorter = inMemSorter;
+ if (sorter != null) {
+ return sorter.getSortTimeNanos();
+ }
+ return totalSortTimeNanos;
+ }
+
+ /**
+ * Return the total number of bytes that has been spilled into disk so far.
+ */
+ public long getSpillSize() {
+ return totalSpillBytes;
+ }
+
@VisibleForTesting
public int getNumberOfAllocatedPages() {
return allocatedPages.size();
@@ -277,9 +320,8 @@ public void cleanupResources() {
deleteSpillFiles();
freeMemory();
if (inMemSorter != null) {
- long used = inMemSorter.getMemoryUsage();
+ inMemSorter.free();
inMemSorter = null;
- releaseMemory(used);
}
}
}
@@ -293,26 +335,23 @@ private void growPointerArrayIfNecessary() throws IOException {
assert(inMemSorter != null);
if (!inMemSorter.hasSpaceForAnotherRecord()) {
long used = inMemSorter.getMemoryUsage();
- long needed = used + inMemSorter.getMemoryToExpand();
+ LongArray array;
try {
- acquireMemory(needed); // could trigger spilling
+ // could trigger spilling
+ array = allocateArray(used / 8 * 2);
} catch (OutOfMemoryError e) {
// should have trigger spilling
- assert(inMemSorter.hasSpaceForAnotherRecord());
+ if (!inMemSorter.hasSpaceForAnotherRecord()) {
+ logger.error("Unable to grow the pointer array");
+ throw e;
+ }
return;
}
// check if spilling is triggered or not
if (inMemSorter.hasSpaceForAnotherRecord()) {
- releaseMemory(needed);
+ freeArray(array);
} else {
- try {
- inMemSorter.expandPointerArray();
- releaseMemory(used);
- } catch (OutOfMemoryError oom) {
- // Just in case that JVM had run out of memory
- releaseMemory(needed);
- spill();
- }
+ inMemSorter.expandPointerArray(array);
}
}
}
@@ -339,22 +378,30 @@ private void acquireNewPageIfNecessary(int required) {
/**
* Write a record to the sorter.
*/
- public void insertRecord(Object recordBase, long recordOffset, int length, long prefix)
+ public void insertRecord(
+ Object recordBase, long recordOffset, int length, long prefix, boolean prefixIsNull)
throws IOException {
+ assert(inMemSorter != null);
+ if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
+ logger.info("Spilling data because number of spilledRecords crossed the threshold " +
+ numElementsForSpillThreshold);
+ spill();
+ }
+
growPointerArrayIfNecessary();
+ int uaoSize = UnsafeAlignedOffset.getUaoSize();
// Need 4 bytes to store the record length.
- final int required = length + 4;
+ final int required = length + uaoSize;
acquireNewPageIfNecessary(required);
final Object base = currentPage.getBaseObject();
final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
- Platform.putInt(base, pageCursor, length);
- pageCursor += 4;
+ UnsafeAlignedOffset.putSize(base, pageCursor, length);
+ pageCursor += uaoSize;
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
pageCursor += length;
- assert(inMemSorter != null);
- inMemSorter.insertRecord(recordAddress, prefix);
+ inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
}
/**
@@ -366,26 +413,27 @@ public void insertRecord(Object recordBase, long recordOffset, int length, long
* record length = key length + value length + 4
*/
public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
- Object valueBase, long valueOffset, int valueLen, long prefix)
+ Object valueBase, long valueOffset, int valueLen, long prefix, boolean prefixIsNull)
throws IOException {
growPointerArrayIfNecessary();
- final int required = keyLen + valueLen + 4 + 4;
+ int uaoSize = UnsafeAlignedOffset.getUaoSize();
+ final int required = keyLen + valueLen + (2 * uaoSize);
acquireNewPageIfNecessary(required);
final Object base = currentPage.getBaseObject();
final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
- Platform.putInt(base, pageCursor, keyLen + valueLen + 4);
- pageCursor += 4;
- Platform.putInt(base, pageCursor, keyLen);
- pageCursor += 4;
+ UnsafeAlignedOffset.putSize(base, pageCursor, keyLen + valueLen + uaoSize);
+ pageCursor += uaoSize;
+ UnsafeAlignedOffset.putSize(base, pageCursor, keyLen);
+ pageCursor += uaoSize;
Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen);
pageCursor += keyLen;
Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen);
pageCursor += valueLen;
assert(inMemSorter != null);
- inMemSorter.insertRecord(recordAddress, prefix);
+ inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
}
/**
@@ -406,6 +454,7 @@ public void merge(UnsafeExternalSorter other) throws IOException {
* after consuming this iterator.
*/
public UnsafeSorterIterator getSortedIterator() throws IOException {
+ assert(recordComparator != null);
if (spillWriters.isEmpty()) {
assert(inMemSorter != null);
readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
@@ -414,7 +463,7 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
final UnsafeSorterSpillMerger spillMerger =
new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
- spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
+ spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
}
if (inMemSorter != null) {
readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
@@ -424,6 +473,10 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
}
}
+ @VisibleForTesting boolean hasSpaceForAnotherRecord() {
+ return inMemSorter.hasSpaceForAnotherRecord();
+ }
+
/**
* An UnsafeSorterIterator that supports spilling.
*/
@@ -434,9 +487,13 @@ class SpillableIterator extends UnsafeSorterIterator {
private boolean loaded = false;
private int numRecords = 0;
- public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+ SpillableIterator(UnsafeSorterIterator inMemIterator) {
this.upstream = inMemIterator;
- this.numRecords = inMemIterator.numRecordsLeft();
+ this.numRecords = inMemIterator.getNumRecords();
+ }
+
+ public int getNumRecords() {
+ return numRecords;
}
public long spill() throws IOException {
@@ -449,6 +506,7 @@ public long spill() throws IOException {
UnsafeInMemorySorter.SortedIterator inMemIterator =
((UnsafeInMemorySorter.SortedIterator) upstream).clone();
+ // Iterate over the records that have not been returned and spill them.
final UnsafeSorterSpillWriter spillWriter =
new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
while (inMemIterator.hasNext()) {
@@ -460,13 +518,16 @@ public long spill() throws IOException {
}
spillWriter.close();
spillWriters.add(spillWriter);
- nextUpstream = spillWriter.getReader(blockManager);
+ nextUpstream = spillWriter.getReader(serializerManager);
long released = 0L;
synchronized (UnsafeExternalSorter.this) {
- // release the pages except the one that is used
+ // release the pages except the one that is used. There can still be a caller that
+ // is accessing the current record. We free this page in that caller's next loadNext()
+ // call.
for (MemoryBlock page : allocatedPages) {
- if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) {
+ if (!loaded || page.pageNumber !=
+ ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) {
released += page.size();
freePage(page);
} else {
@@ -475,6 +536,15 @@ public long spill() throws IOException {
}
allocatedPages.clear();
}
+
+ // in-memory sorter will not be used after spilling
+ assert(inMemSorter != null);
+ released += inMemSorter.getMemoryUsage();
+ totalSortTimeNanos += inMemSorter.getSortTimeNanos();
+ inMemSorter.free();
+ inMemSorter = null;
+ taskContext.taskMetrics().incMemoryBytesSpilled(released);
+ totalSpillBytes += released;
return released;
}
}
@@ -496,11 +566,6 @@ public void loadNext() throws IOException {
}
upstream = nextUpstream;
nextUpstream = null;
-
- assert(inMemSorter != null);
- long used = inMemSorter.getMemoryUsage();
- inMemSorter = null;
- releaseMemory(used);
}
numRecords--;
upstream.loadNext();
@@ -527,4 +592,81 @@ public long getKeyPrefix() {
return upstream.getKeyPrefix();
}
}
+
+ /**
+ * Returns an iterator, which will return the rows in the order in which they were inserted.
+ *
+ * It is the caller's responsibility to call `cleanupResources()`
+ * after consuming this iterator.
+ *
+ * TODO: support forced spilling
+ */
+ public UnsafeSorterIterator getIterator() throws IOException {
+ if (spillWriters.isEmpty()) {
+ assert(inMemSorter != null);
+ return inMemSorter.getSortedIterator();
+ } else {
+ LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
+ for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+ queue.add(spillWriter.getReader(serializerManager));
+ }
+ if (inMemSorter != null) {
+ queue.add(inMemSorter.getSortedIterator());
+ }
+ return new ChainedIterator(queue);
+ }
+ }
+
+ /**
+ * Chain multiple UnsafeSorterIterator together as single one.
+ */
+ static class ChainedIterator extends UnsafeSorterIterator {
+
+ private final Queue<UnsafeSorterIterator> iterators;
+ private UnsafeSorterIterator current;
+ private int numRecords;
+
+ ChainedIterator(Queue<UnsafeSorterIterator> iterators) {
+ assert iterators.size() > 0;
+ this.numRecords = 0;
+ for (UnsafeSorterIterator iter: iterators) {
+ this.numRecords += iter.getNumRecords();
+ }
+ this.iterators = iterators;
+ this.current = iterators.remove();
+ }
+
+ @Override
+ public int getNumRecords() {
+ return numRecords;
+ }
+
+ @Override
+ public boolean hasNext() {
+ while (!current.hasNext() && !iterators.isEmpty()) {
+ current = iterators.remove();
+ }
+ return current.hasNext();
+ }
+
+ @Override
+ public void loadNext() throws IOException {
+ while (!current.hasNext() && !iterators.isEmpty()) {
+ current = iterators.remove();
+ }
+ current.loadNext();
+ }
+
+ @Override
+ public Object getBaseObject() { return current.getBaseObject(); }
+
+ @Override
+ public long getBaseOffset() { return current.getBaseOffset(); }
+
+ @Override
+ public int getRecordLength() { return current.getRecordLength(); }
+
+ @Override
+ public long getKeyPrefix() { return current.getKeyPrefix(); }
+ }
}
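Note on the growPointerArrayIfNecessary() change above: allocating the doubled pointer array may itself trigger a spill, so the code has to re-check whether the old array regained free space before deciding to keep the old array, free the new allocation, or copy into it. A minimal sketch of that control flow, using plain Java and hypothetical names (SketchSorter, tryAllocate) rather than Spark's MemoryConsumer/LongArray:

// Editor's illustrative sketch, not part of the patch.
final class SketchSorter {
  private long[] pointerArray = new long[8];
  private int pos = 0;

  // Stand-in allocator: in the patch, allocateArray() may spill this sorter to disk
  // (freeing slots in the existing array) before throwing OutOfMemoryError.
  private long[] tryAllocate(int size) {
    return new long[size];
  }

  private boolean hasSpaceForAnotherRecord() {
    return pos + 2 <= pointerArray.length;
  }

  void growPointerArrayIfNecessary() {
    if (hasSpaceForAnotherRecord()) {
      return;
    }
    long[] newArray;
    try {
      // May trigger spilling, which can free space in the existing array.
      newArray = tryAllocate(pointerArray.length * 2);
    } catch (OutOfMemoryError e) {
      if (!hasSpaceForAnotherRecord()) {
        throw e;   // spilling did not help; propagate the error
      }
      return;      // spilling freed space; keep the old array
    }
    if (hasSpaceForAnotherRecord()) {
      // Spilling already made room; the new allocation is released (freeArray) in the patch.
    } else {
      System.arraycopy(pointerArray, 0, newArray, 0, pos);
      pointerArray = newArray;   // expandPointerArray(newArray) in the patch
    }
  }
}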
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index d57213b9b8bf..b02581199295 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -18,9 +18,17 @@
package org.apache.spark.util.collection.unsafe.sort;
import java.util.Comparator;
+import java.util.LinkedList;
+import org.apache.avro.reflect.Nullable;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.Sorter;
/**
@@ -50,54 +58,137 @@ private static final class SortComparator implements Comparator<RecordPointerAndKeyPrefix> {
+ private final MemoryConsumer consumer;
private final TaskMemoryManager memoryManager;
- private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
+ @Nullable
private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
/**
- * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at
+ * If non-null, specifies the radix sort parameters and that radix sort will be used.
+ */
+ @Nullable
+ private final PrefixComparators.RadixSortSupport radixSortSupport;
+
+ /**
+ * Within this buffer, position {@code 2 * i} holds a pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ *
+ * Only part of the array will be used to store the pointers; the rest is reserved as a
+ * temporary buffer for sorting.
*/
- private long[] array;
+ private LongArray array;
/**
* The position in the sort buffer where new records can be inserted.
*/
private int pos = 0;
+ /**
+ * If sorting with radix sort, specifies the starting position in the sort buffer where records
+ * with non-null prefixes are kept. Positions [0..nullBoundaryPos) will contain null-prefixed
+ * records, and positions [nullBoundaryPos..pos) non-null prefixed records. This lets us avoid
+ * radix sorting over null values.
+ */
+ private int nullBoundaryPos = 0;
+
+ /*
+ * The maximum number of records that can be inserted, because part of the array must be left as a sort buffer.
+ */
+ private int usableCapacity = 0;
+
+ private long initialSize;
+
+ private long totalSortTimeNanos = 0L;
+
public UnsafeInMemorySorter(
+ final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
- int initialSize) {
- this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]);
+ int initialSize,
+ boolean canUseRadixSort) {
+ this(consumer, memoryManager, recordComparator, prefixComparator,
+ consumer.allocateArray(initialSize * 2), canUseRadixSort);
}
public UnsafeInMemorySorter(
+ final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
- long[] array) {
- this.array = array;
+ LongArray array,
+ boolean canUseRadixSort) {
+ this.consumer = consumer;
this.memoryManager = memoryManager;
- this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
- this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ this.initialSize = array.size();
+ if (recordComparator != null) {
+ this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) {
+ this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator;
+ } else {
+ this.radixSortSupport = null;
+ }
+ } else {
+ this.sortComparator = null;
+ this.radixSortSupport = null;
+ }
+ this.array = array;
+ this.usableCapacity = getUsableCapacity();
+ }
+
+ private int getUsableCapacity() {
+ // Radix sort requires the same amount of memory as the data for its buffer, while Tim sort
+ // requires half of the used memory as its buffer.
+ return (int) (array.size() / (radixSortSupport != null ? 2 : 1.5));
+ }
+
+ /**
+ * Free the memory used by pointer array.
+ */
+ public void free() {
+ if (consumer != null) {
+ if (array != null) {
+ consumer.freeArray(array);
+ }
+ array = null;
+ }
}
public void reset() {
+ if (consumer != null) {
+ consumer.freeArray(array);
+ // the call to consumer.allocateArray may trigger a spill
+ // which in turn access this instance and eventually re-enter this method
+ // and try to free the array again.
+ // By setting the array to null and its length to 0
+ // we effectively make the spill code-path a no-op.
+ // Setting the array to null also indicates that it has already been
+ // de-allocated which prevents a double de-allocation in free().
+ array = null;
+ usableCapacity = 0;
+ pos = 0;
+ nullBoundaryPos = 0;
+ array = consumer.allocateArray(initialSize);
+ usableCapacity = getUsableCapacity();
+ }
pos = 0;
+ nullBoundaryPos = 0;
}
/**
@@ -107,26 +198,34 @@ public int numRecords() {
return pos / 2;
}
- private int newLength() {
- return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
- }
-
- public long getMemoryToExpand() {
- return (long) (newLength() - array.length) * 8L;
+ /**
+ * @return the total amount of time spent sorting data (in-memory only).
+ */
+ public long getSortTimeNanos() {
+ return totalSortTimeNanos;
}
public long getMemoryUsage() {
- return array.length * 8L;
+ return array.size() * 8;
}
public boolean hasSpaceForAnotherRecord() {
- return pos + 2 <= array.length;
+ return pos + 1 < usableCapacity;
}
- public void expandPointerArray() {
- final long[] oldArray = array;
- array = new long[newLength()];
- System.arraycopy(oldArray, 0, array, 0, oldArray.length);
+ public void expandPointerArray(LongArray newArray) {
+ if (newArray.size() < array.size()) {
+ throw new OutOfMemoryError("Not enough memory to grow pointer array");
+ }
+ Platform.copyMemory(
+ array.getBaseObject(),
+ array.getBaseOffset(),
+ newArray.getBaseObject(),
+ newArray.getBaseOffset(),
+ pos * 8L);
+ consumer.freeArray(array);
+ array = newArray;
+ usableCapacity = getUsableCapacity();
}
/**
@@ -136,63 +235,87 @@ public void expandPointerArray() {
* @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
* @param keyPrefix a user-defined key prefix
*/
- public void insertRecord(long recordPointer, long keyPrefix) {
+ public void insertRecord(long recordPointer, long keyPrefix, boolean prefixIsNull) {
if (!hasSpaceForAnotherRecord()) {
- expandPointerArray();
+ throw new IllegalStateException("There is no space for new record");
+ }
+ if (prefixIsNull && radixSortSupport != null) {
+ // Swap forward a non-null record to make room for this one at the beginning of the array.
+ array.set(pos, array.get(nullBoundaryPos));
+ pos++;
+ array.set(pos, array.get(nullBoundaryPos + 1));
+ pos++;
+ // Place this record in the vacated position.
+ array.set(nullBoundaryPos, recordPointer);
+ nullBoundaryPos++;
+ array.set(nullBoundaryPos, keyPrefix);
+ nullBoundaryPos++;
+ } else {
+ array.set(pos, recordPointer);
+ pos++;
+ array.set(pos, keyPrefix);
+ pos++;
}
- array[pos] = recordPointer;
- pos++;
- array[pos] = keyPrefix;
- pos++;
}
- public static final class SortedIterator extends UnsafeSorterIterator {
+ public final class SortedIterator extends UnsafeSorterIterator implements Cloneable {
- private final TaskMemoryManager memoryManager;
- private final int sortBufferInsertPosition;
- private final long[] sortBuffer;
- private int position = 0;
+ private final int numRecords;
+ private int position;
+ private int offset;
private Object baseObject;
private long baseOffset;
private long keyPrefix;
private int recordLength;
+ private long currentPageNumber;
+ private final TaskContext taskContext = TaskContext.get();
- private SortedIterator(
- TaskMemoryManager memoryManager,
- int sortBufferInsertPosition,
- long[] sortBuffer) {
- this.memoryManager = memoryManager;
- this.sortBufferInsertPosition = sortBufferInsertPosition;
- this.sortBuffer = sortBuffer;
+ private SortedIterator(int numRecords, int offset) {
+ this.numRecords = numRecords;
+ this.position = 0;
+ this.offset = offset;
}
- public SortedIterator clone () {
- SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer);
+ public SortedIterator clone() {
+ SortedIterator iter = new SortedIterator(numRecords, offset);
iter.position = position;
iter.baseObject = baseObject;
iter.baseOffset = baseOffset;
iter.keyPrefix = keyPrefix;
iter.recordLength = recordLength;
+ iter.currentPageNumber = currentPageNumber;
return iter;
}
@Override
- public boolean hasNext() {
- return position < sortBufferInsertPosition;
+ public int getNumRecords() {
+ return numRecords;
}
- public int numRecordsLeft() {
- return (sortBufferInsertPosition - position) / 2;
+ @Override
+ public boolean hasNext() {
+ return position / 2 < numRecords;
}
@Override
public void loadNext() {
+ // Kill the task in case it has been marked as killed. This logic is from
+ // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+ // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+ // `hasNext()` because it's technically possible for the caller to be relying on
+ // `getNumRecords()` instead of `hasNext()` to know when to stop.
+ if (taskContext != null) {
+ taskContext.killTaskIfInterrupted();
+ }
// This pointer points to a 4-byte record length, followed by the record's bytes
- final long recordPointer = sortBuffer[position];
+ final long recordPointer = array.get(offset + position);
+ currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);
+ int uaoSize = UnsafeAlignedOffset.getUaoSize();
baseObject = memoryManager.getPage(recordPointer);
- baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
- recordLength = Platform.getInt(baseObject, baseOffset - 4);
- keyPrefix = sortBuffer[position + 1];
+ // Skip over record length
+ baseOffset = memoryManager.getOffsetInPage(recordPointer) + uaoSize;
+ recordLength = UnsafeAlignedOffset.getSize(baseObject, baseOffset - uaoSize);
+ keyPrefix = array.get(offset + position + 1);
position += 2;
}
@@ -202,6 +325,10 @@ public void loadNext() {
@Override
public long getBaseOffset() { return baseOffset; }
+ public long getCurrentPageNumber() {
+ return currentPageNumber;
+ }
+
@Override
public int getRecordLength() { return recordLength; }
@@ -213,8 +340,41 @@ public void loadNext() {
* Return an iterator over record pointers in sorted order. For efficiency, all calls to
* {@code next()} will return the same mutable object.
*/
- public SortedIterator getSortedIterator() {
- sorter.sort(array, 0, pos / 2, sortComparator);
- return new SortedIterator(memoryManager, pos, array);
+ public UnsafeSorterIterator getSortedIterator() {
+ int offset = 0;
+ long start = System.nanoTime();
+ if (sortComparator != null) {
+ if (this.radixSortSupport != null) {
+ offset = RadixSort.sortKeyPrefixArray(
+ array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7,
+ radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
+ } else {
+ MemoryBlock unused = new MemoryBlock(
+ array.getBaseObject(),
+ array.getBaseOffset() + pos * 8L,
+ (array.size() - pos) * 8L);
+ LongArray buffer = new LongArray(unused);
+ Sorter<RecordPointerAndKeyPrefix, LongArray> sorter =
+ new Sorter<>(new UnsafeSortDataFormat(buffer));
+ sorter.sort(array, 0, pos / 2, sortComparator);
+ }
+ }
+ totalSortTimeNanos += System.nanoTime() - start;
+ if (nullBoundaryPos > 0) {
+ assert radixSortSupport != null : "Nulls are only stored separately with radix sort";
+ LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
+
+ // The null order is either LAST or FIRST, regardless of sorting direction (ASC|DESC)
+ if (radixSortSupport.nullsFirst()) {
+ queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+ queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+ } else {
+ queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+ queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+ }
+ return new UnsafeExternalSorter.ChainedIterator(queue);
+ } else {
+ return new SortedIterator(pos / 2, offset);
+ }
}
}
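Note on the nullBoundaryPos handling above: null-prefixed records are packed into [0, nullBoundaryPos) by swapping a non-null pair to the end at insert time, so radix sort only has to run over [nullBoundaryPos, pos) and the two regions are later chained in nulls-first or nulls-last order. A reduced sketch of the insertion step, using a plain long[] and hypothetical names in place of Spark's LongArray:

// Editor's illustrative sketch, not part of the patch.
final class NullBoundarySketch {
  private final long[] array = new long[64];
  private int pos = 0;             // next free slot; each record uses two longs
  private int nullBoundaryPos = 0; // end of the null-prefixed region

  // In the patch this swap is only done when radix sort is enabled; with a comparator-only
  // sort, null prefixes are handled by the prefix comparator instead.
  void insert(long recordPointer, long keyPrefix, boolean prefixIsNull) {
    if (prefixIsNull) {
      // Move the first non-null record to the end to vacate a slot at the boundary.
      array[pos++] = array[nullBoundaryPos];
      array[pos++] = array[nullBoundaryPos + 1];
      // Place the null-prefixed record in the vacated slot and advance the boundary.
      array[nullBoundaryPos++] = recordPointer;
      array[nullBoundaryPos++] = keyPrefix;
    } else {
      array[pos++] = recordPointer;
      array[pos++] = keyPrefix;
    }
  }
}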
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
index d09c728a7a63..d9f84d10e905 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -17,23 +17,28 @@
package org.apache.spark.util.collection.unsafe.sort;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.SortDataFormat;
/**
* Supports sorting an array of (record pointer, key prefix) pairs.
* Used in {@link UnsafeInMemorySorter}.
*
- * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at
+ * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*/
-final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {
+public final class UnsafeSortDataFormat
+ extends SortDataFormat<RecordPointerAndKeyPrefix, LongArray> {
- public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
+ private final LongArray buffer;
- private UnsafeSortDataFormat() { }
+ public UnsafeSortDataFormat(LongArray buffer) {
+ this.buffer = buffer;
+ }
@Override
- public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
+ public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) {
// Since we re-use keys, this method shouldn't be called.
throw new UnsupportedOperationException();
}
@@ -44,37 +49,44 @@ public RecordPointerAndKeyPrefix newKey() {
}
@Override
- public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) {
- reuse.recordPointer = data[pos * 2];
- reuse.keyPrefix = data[pos * 2 + 1];
+ public RecordPointerAndKeyPrefix getKey(LongArray data, int pos,
+ RecordPointerAndKeyPrefix reuse) {
+ reuse.recordPointer = data.get(pos * 2);
+ reuse.keyPrefix = data.get(pos * 2 + 1);
return reuse;
}
@Override
- public void swap(long[] data, int pos0, int pos1) {
- long tempPointer = data[pos0 * 2];
- long tempKeyPrefix = data[pos0 * 2 + 1];
- data[pos0 * 2] = data[pos1 * 2];
- data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
- data[pos1 * 2] = tempPointer;
- data[pos1 * 2 + 1] = tempKeyPrefix;
+ public void swap(LongArray data, int pos0, int pos1) {
+ long tempPointer = data.get(pos0 * 2);
+ long tempKeyPrefix = data.get(pos0 * 2 + 1);
+ data.set(pos0 * 2, data.get(pos1 * 2));
+ data.set(pos0 * 2 + 1, data.get(pos1 * 2 + 1));
+ data.set(pos1 * 2, tempPointer);
+ data.set(pos1 * 2 + 1, tempKeyPrefix);
}
@Override
- public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
- dst[dstPos * 2] = src[srcPos * 2];
- dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
+ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {
+ dst.set(dstPos * 2, src.get(srcPos * 2));
+ dst.set(dstPos * 2 + 1, src.get(srcPos * 2 + 1));
}
@Override
- public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
- System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
+ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) {
+ Platform.copyMemory(
+ src.getBaseObject(),
+ src.getBaseOffset() + srcPos * 16L,
+ dst.getBaseObject(),
+ dst.getBaseOffset() + dstPos * 16L,
+ length * 16L);
}
@Override
- public long[] allocate(int length) {
- assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
- return new long[length * 2];
+ public LongArray allocate(int length) {
+ assert (length * 2 <= buffer.size()) :
+ "the buffer is smaller than required: " + buffer.size() + " < " + (length * 2);
+ return buffer;
}
}
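Note on the UnsafeSortDataFormat change above: the format still treats logical record i as the pair (pointer, prefix) stored at slots 2*i and 2*i + 1, which is why swap moves two longs and copyRange moves 16 bytes per record; allocate() now just hands back the pre-reserved tail of the same LongArray instead of allocating on the heap. A plain-long[] sketch of the pair layout (hypothetical class name, not part of the patch):

// Editor's illustrative sketch, not part of the patch.
final class PairLayoutSketch {
  // Logical record i occupies slots 2*i (record pointer) and 2*i + 1 (key prefix).
  static void swap(long[] data, int pos0, int pos1) {
    long tmpPointer = data[pos0 * 2];
    long tmpPrefix = data[pos0 * 2 + 1];
    data[pos0 * 2] = data[pos1 * 2];
    data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
    data[pos1 * 2] = tmpPointer;
    data[pos1 * 2 + 1] = tmpPrefix;
  }

  // Two longs (16 bytes) per record, matching the srcPos * 16L offsets in copyRange above.
  static void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
    System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
  }
}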
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
index 16ac2e8d821b..1b3167fcc250 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
@@ -32,4 +32,6 @@ public abstract class UnsafeSorterIterator {
public abstract int getRecordLength();
public abstract long getKeyPrefix();
+
+ public abstract int getNumRecords();
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index 3874a9f9cbdb..ff0dcc259a4a 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -23,28 +23,25 @@
final class UnsafeSorterSpillMerger {
+ private int numRecords = 0;
private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
- public UnsafeSorterSpillMerger(
- final RecordComparator recordComparator,
- final PrefixComparator prefixComparator,
- final int numSpills) {
- final Comparator<UnsafeSorterIterator> comparator = new Comparator<UnsafeSorterIterator>() {
-
- @Override
- public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) {
- final int prefixComparisonResult =
- prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
- if (prefixComparisonResult == 0) {
- return recordComparator.compare(
- left.getBaseObject(), left.getBaseOffset(),
- right.getBaseObject(), right.getBaseOffset());
- } else {
- return prefixComparisonResult;
- }
+ UnsafeSorterSpillMerger(
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ int numSpills) {
+ Comparator<UnsafeSorterIterator> comparator = (left, right) -> {
+ int prefixComparisonResult =
+ prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
+ if (prefixComparisonResult == 0) {
+ return recordComparator.compare(
+ left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(),
+ right.getBaseObject(), right.getBaseOffset(), right.getRecordLength());
+ } else {
+ return prefixComparisonResult;
}
};
- priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator);
+ priorityQueue = new PriorityQueue<>(numSpills, comparator);
}
/**
@@ -56,9 +53,10 @@ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOExcept
// make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator
// does not return a wrong result, because hasNext will return true
// at least priorityQueue.size() times. If we allow n spillReaders in the
- // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator.
+ // priorityQueue, we will have n extra empty records in the result of UnsafeSorterIterator.
spillReader.loadNext();
priorityQueue.add(spillReader);
+ numRecords += spillReader.getNumRecords();
}
}
@@ -67,6 +65,11 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
private UnsafeSorterIterator spillReader;
+ @Override
+ public int getNumRecords() {
+ return numRecords;
+ }
+
@Override
public boolean hasNext() {
return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
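Note on the merger above: it is a standard k-way merge in which a priority queue is keyed by each iterator's current head (prefix comparison first, then full record comparison), and an iterator is re-queued after its head is consumed. A self-contained sketch of that shape over plain Iterator<Long> inputs (hypothetical names; the real merger pre-loads each spill in addSpillIfNotEmpty and compares prefixes before records):

// Editor's illustrative sketch, not part of the patch.
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

final class KWayMergeSketch {
  private static final class Head {
    long value;
    final Iterator<Long> source;
    Head(long value, Iterator<Long> source) { this.value = value; this.source = source; }
  }

  static List<Long> merge(List<Iterator<Long>> sortedInputs) {
    PriorityQueue<Head> queue = new PriorityQueue<>(Comparator.comparingLong(h -> h.value));
    for (Iterator<Long> it : sortedInputs) {
      if (it.hasNext()) {                  // mirrors addSpillIfNotEmpty(): skip empty spills
        queue.add(new Head(it.next(), it));
      }
    }
    List<Long> out = new ArrayList<>();
    while (!queue.isEmpty()) {
      Head smallest = queue.poll();
      out.add(smallest.value);
      if (smallest.source.hasNext()) {     // re-queue the iterator with its next head
        smallest.value = smallest.source.next();
        queue.add(smallest);
      }
    }
    return out;
  }
}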
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 039e940a357e..9521ab86a12d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -20,43 +20,73 @@
import java.io.*;
import com.google.common.io.ByteStreams;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import com.google.common.io.Closeables;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.io.NioBufferedFileInputStream;
+import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
-import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
* of the file format).
*/
-public final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class);
+ private static final int DEFAULT_BUFFER_SIZE_BYTES = 1024 * 1024; // 1 MB
+ private static final int MAX_BUFFER_SIZE_BYTES = 16777216; // 16 mb
- private final File file;
private InputStream in;
private DataInputStream din;
// Variables that change with every record read:
private int recordLength;
private long keyPrefix;
+ private int numRecords;
private int numRecordsRemaining;
private byte[] arr = new byte[1024 * 1024];
private Object baseObject = arr;
private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
+ private final TaskContext taskContext = TaskContext.get();
public UnsafeSorterSpillReader(
- BlockManager blockManager,
+ SerializerManager serializerManager,
File file,
BlockId blockId) throws IOException {
assert (file.length() > 0);
- this.file = file;
- final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
- this.in = blockManager.wrapForCompression(blockId, bs);
- this.din = new DataInputStream(this.in);
- numRecordsRemaining = din.readInt();
+ long bufferSizeBytes =
+ SparkEnv.get() == null ?
+ DEFAULT_BUFFER_SIZE_BYTES:
+ SparkEnv.get().conf().getSizeAsBytes("spark.unsafe.sorter.spill.reader.buffer.size",
+ DEFAULT_BUFFER_SIZE_BYTES);
+ if (bufferSizeBytes > MAX_BUFFER_SIZE_BYTES || bufferSizeBytes < DEFAULT_BUFFER_SIZE_BYTES) {
+ // fall back to a sane default value
+ logger.warn("Value of config \"spark.unsafe.sorter.spill.reader.buffer.size\" = {} not in " +
+ "allowed range [{}, {}). Falling back to default value : {} bytes", bufferSizeBytes,
+ DEFAULT_BUFFER_SIZE_BYTES, MAX_BUFFER_SIZE_BYTES, DEFAULT_BUFFER_SIZE_BYTES);
+ bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES;
+ }
+
+ final InputStream bs =
+ new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
+ try {
+ this.in = serializerManager.wrapStream(blockId, bs);
+ this.din = new DataInputStream(this.in);
+ numRecords = numRecordsRemaining = din.readInt();
+ } catch (IOException e) {
+ Closeables.close(bs, /* swallowIOException = */ true);
+ throw e;
+ }
+ }
+
+ @Override
+ public int getNumRecords() {
+ return numRecords;
}
@Override
@@ -66,6 +96,14 @@ public boolean hasNext() {
@Override
public void loadNext() throws IOException {
+ // Kill the task in case it has been marked as killed. This logic is from
+ // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+ // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+ // `hasNext()` because it's technically possible for the caller to be relying on
+ // `getNumRecords()` instead of `hasNext()` to know when to stop.
+ if (taskContext != null) {
+ taskContext.killTaskIfInterrupted();
+ }
recordLength = din.readInt();
keyPrefix = din.readLong();
if (recordLength > arr.length) {
@@ -75,12 +113,7 @@ public void loadNext() throws IOException {
ByteStreams.readFully(in, arr, 0, recordLength);
numRecordsRemaining--;
if (numRecordsRemaining == 0) {
- in.close();
- if (!file.delete() && file.exists()) {
- logger.warn("Unable to delete spill file {}", file.getPath());
- }
- in = null;
- din = null;
+ close();
}
}
@@ -103,4 +136,16 @@ public int getRecordLength() {
public long getKeyPrefix() {
return keyPrefix;
}
+
+ @Override
+ public void close() throws IOException {
+ if (in != null) {
+ try {
+ in.close();
+ } finally {
+ in = null;
+ din = null;
+ }
+ }
+ }
}
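Note on the constructor change above: because the decompression/encryption wrapping can fail after the file stream has been opened, the raw stream is closed before the exception is rethrown so the file handle is not leaked. A reduced sketch of that close-on-failure pattern with plain java.io types (wrapStream here is a hypothetical stand-in for serializerManager.wrapStream):

// Editor's illustrative sketch, not part of the patch.
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

final class SpillReaderOpenSketch {
  static DataInputStream open(File file) throws IOException {
    InputStream bs = new BufferedInputStream(new FileInputStream(file));
    try {
      InputStream wrapped = wrapStream(bs);   // may throw, e.g. a bad compression header
      return new DataInputStream(wrapped);
    } catch (IOException e) {
      try {
        bs.close();                           // Closeables.close(bs, true) in the patch
      } catch (IOException suppressed) {
        e.addSuppressed(suppressed);
      }
      throw e;
    }
  }

  private static InputStream wrapStream(InputStream in) throws IOException {
    return in;  // placeholder for decompression/encryption wrapping
  }
}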
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 234e21140a1d..164b9d70b79d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -20,6 +20,7 @@
import java.io.File;
import java.io.IOException;
+import org.apache.spark.serializer.SerializerManager;
import scala.Tuple2;
import org.apache.spark.executor.ShuffleWriteMetrics;
@@ -135,7 +136,8 @@ public void write(
}
public void close() throws IOException {
- writer.commitAndClose();
+ writer.commitAndGet();
+ writer.close();
writer = null;
writeBuffer = null;
}
@@ -144,7 +146,7 @@ public File getFile() {
return file;
}
- public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
- return new UnsafeSorterSpillReader(blockManager, file, blockId);
+ public UnsafeSorterSpillReader getReader(SerializerManager serializerManager) throws IOException {
+ return new UnsafeSorterSpillReader(serializerManager, file, blockId);
}
}
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
deleted file mode 100644
index c85abc35b93b..000000000000
--- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
+++ /dev/null
@@ -1,33 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Set everything to be logged to the console
-log4j.rootCategory=WARN, console
-log4j.appender.console=org.apache.log4j.ConsoleAppender
-log4j.appender.console.target=System.err
-log4j.appender.console.layout=org.apache.log4j.PatternLayout
-log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
-
-# Settings to quiet third party logs that are too verbose
-log4j.logger.org.spark-project.jetty=WARN
-log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
-log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
-log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
-
-# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
-log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
-log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index d44cc85dcbd8..277010015072 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -22,12 +22,21 @@ log4j.appender.console.target=System.err
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+# Set the default spark-shell log level to WARN. When running the spark-shell, the
+# log level for this class is used to overwrite the root logger's log level, so that
+# the user can have different defaults for the shell and regular Spark apps.
+log4j.logger.org.apache.spark.repl.Main=WARN
+
# Settings to quiet third party logs that are too verbose
-log4j.logger.org.spark-project.jetty=WARN
-log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.spark_project.jetty=WARN
+log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR
+
+# Parquet related logging
+log4j.logger.org.apache.parquet.CorruptStatistics=ERROR
+log4j.logger.parquet.CorruptStatistics=ERROR
diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
index 2d9262b972a5..6fe8136c87ae 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
@@ -1,4 +1,5 @@
[minified dagre-d3.min.js content omitted: this hunk updates the custom dagre-d3 bundle built on top of v0.4.3 (http://github.com/andrewor14/dagre-d3/); the single-line minified JavaScript is not reproduced here]
_=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" 
buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){
+normalize.run(g)});time(" parentDummyChains",function(){parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function 
translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return 
newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var 
bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function 
consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdxwLabel.lim){tailLabel=wLabel;flip=true}var 
candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function 
buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new 
Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var 
keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return 
_.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var 
json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var 
cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{};
+})}function enterEdge(t,g,edge){var v=edge.v,w=edge.w;if(!g.hasEdge(v,w)){v=edge.w;w=edge.v}var vLabel=t.node(v),wLabel=t.node(w),tailLabel=vLabel,flip=false;if(vLabel.lim>wLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new 
Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function 
runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var 
index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return 
_.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return 
edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var 
moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){
+return isObject(prototype)?nativeCreate(prototype):{}}if(!nativeCreate){baseCreate=function(){function Object(){}return function(prototype){if(isObject(prototype)){Object.prototype=prototype;var result=new Object;Object.prototype=null}return result||context.Object()}}()}function baseCreateCallback(func,thisArg,argCount){if(typeof func!="function"){return identity}if(typeof thisArg=="undefined"||!("prototype"in func)){return func}var bindData=func.__bindData__;if(typeof bindData=="undefined"){if(support.funcNames){bindData=!func.name}bindData=bindData||!support.funcDecomp;if(!bindData){var source=fnToString.call(func);if(!support.funcNames){bindData=!reFuncName.test(source)}if(!bindData){bindData=reThis.test(source);setBindData(func,bindData)}}}if(bindData===false||bindData!==true&&bindData[1]&1){return func}switch(argCount){case 1:return function(value){return func.call(thisArg,value)};case 2:return function(a,b){return func.call(thisArg,a,b)};case 3:return function(value,index,collection){return func.call(thisArg,value,index,collection)};case 4:return function(accumulator,value,index,collection){return func.call(thisArg,accumulator,value,index,collection)}}return bind(func,thisArg)}function baseCreateWrapper(bindData){var func=bindData[0],bitmask=bindData[1],partialArgs=bindData[2],partialRightArgs=bindData[3],thisArg=bindData[4],arity=bindData[5];var isBind=bitmask&1,isBindKey=bitmask&2,isCurry=bitmask&4,isCurryBound=bitmask&8,key=func;function bound(){var thisBinding=isBind?thisArg:this;if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(partialRightArgs||isCurry){args||(args=slice(arguments));if(partialRightArgs){push.apply(args,partialRightArgs)}if(isCurry&&args.length=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof 
guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var 
index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var 
isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var 
chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? 
'' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else 
if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return 
object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? '' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index tr > th {
+ padding-left: 18px;
+ padding-right: 18px;
+}
+
+table.dataTable th:active {
+ outline: none;
+}
+
+/* Scrolling */
+div.dataTables_scrollHead table {
+ margin-bottom: 0 !important;
+ border-bottom-left-radius: 0;
+ border-bottom-right-radius: 0;
+}
+
+div.dataTables_scrollHead table thead tr:last-child th:first-child,
+div.dataTables_scrollHead table thead tr:last-child td:first-child {
+ border-bottom-left-radius: 0 !important;
+ border-bottom-right-radius: 0 !important;
+}
+
+div.dataTables_scrollBody table {
+ border-top: none;
+ margin-top: 0 !important;
+ margin-bottom: 0 !important;
+}
+
+div.dataTables_scrollBody tbody tr:first-child th,
+div.dataTables_scrollBody tbody tr:first-child td {
+ border-top: none;
+}
+
+div.dataTables_scrollFoot table {
+ margin-top: 0 !important;
+ border-top: none;
+}
+
+/* Frustratingly, the border-collapse: collapse used by Bootstrap makes it
+ impossible to align column widths when scrolling is enabled, so we have
+ to use border-collapse: separate.
+ */
+table.table-bordered.dataTable {
+ border-collapse: separate !important;
+}
+table.table-bordered thead th,
+table.table-bordered thead td {
+ border-left-width: 0;
+ border-top-width: 0;
+}
+table.table-bordered tbody th,
+table.table-bordered tbody td {
+ border-left-width: 0;
+ border-bottom-width: 0;
+}
+table.table-bordered th:last-child,
+table.table-bordered td:last-child {
+ border-right-width: 0;
+}
+div.dataTables_scrollHead table.table-bordered {
+ border-bottom-width: 0;
+}
+
+
+
+
+/*
+ * TableTools styles
+ */
+.table.dataTable tbody tr.active td,
+.table.dataTable tbody tr.active th {
+ background-color: #08C;
+ color: white;
+}
+
+.table.dataTable tbody tr.active:hover td,
+.table.dataTable tbody tr.active:hover th {
+ background-color: #0075b0 !important;
+}
+
+.table.dataTable tbody tr.active th > a,
+.table.dataTable tbody tr.active td > a {
+ color: white;
+}
+
+.table-striped.dataTable tbody tr.active:nth-child(odd) td,
+.table-striped.dataTable tbody tr.active:nth-child(odd) th {
+ background-color: #017ebc;
+}
+
+table.DTTT_selectable tbody tr {
+ cursor: pointer;
+}
+
+div.DTTT .btn {
+ color: #333 !important;
+ font-size: 12px;
+}
+
+div.DTTT .btn:hover {
+ text-decoration: none !important;
+}
+
+ul.DTTT_dropdown.dropdown-menu {
+ z-index: 2003;
+}
+
+ul.DTTT_dropdown.dropdown-menu a {
+ color: #333 !important; /* needed only when demo_page.css is included */
+}
+
+ul.DTTT_dropdown.dropdown-menu li {
+ position: relative;
+}
+
+ul.DTTT_dropdown.dropdown-menu li:hover a {
+ background-color: #0088cc;
+ color: white !important;
+}
+
+div.DTTT_collection_background {
+ z-index: 2002;
+}
+
+/* TableTools information display */
+div.DTTT_print_info {
+ position: fixed;
+ top: 50%;
+ left: 50%;
+ width: 400px;
+ height: 150px;
+ margin-left: -200px;
+ margin-top: -75px;
+ text-align: center;
+ color: #333;
+ padding: 10px 30px;
+ opacity: 0.95;
+
+ background-color: white;
+ border: 1px solid rgba(0, 0, 0, 0.2);
+ border-radius: 6px;
+
+ -webkit-box-shadow: 0 3px 7px rgba(0, 0, 0, 0.5);
+ box-shadow: 0 3px 7px rgba(0, 0, 0, 0.5);
+}
+
+div.DTTT_print_info h6 {
+ font-weight: normal;
+ font-size: 28px;
+ line-height: 28px;
+ margin: 1em;
+}
+
+div.DTTT_print_info p {
+ font-size: 14px;
+ line-height: 20px;
+}
+
+div.dataTables_processing {
+ position: absolute;
+ top: 50%;
+ left: 50%;
+ width: 100%;
+ height: 60px;
+ margin-left: -50%;
+ margin-top: -25px;
+ padding-top: 20px;
+ padding-bottom: 20px;
+ text-align: center;
+ font-size: 1.2em;
+ background-color: white;
+ background: -webkit-gradient(linear, left top, right top, color-stop(0%, rgba(255,255,255,0)), color-stop(25%, rgba(255,255,255,0.9)), color-stop(75%, rgba(255,255,255,0.9)), color-stop(100%, rgba(255,255,255,0)));
+ background: -webkit-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);
+ background: -moz-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);
+ background: -ms-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);
+ background: -o-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);
+ background: linear-gradient(to right, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);
+}
+
+
+
+/*
+ * FixedColumns styles
+ */
+div.DTFC_LeftHeadWrapper table,
+div.DTFC_LeftFootWrapper table,
+div.DTFC_RightHeadWrapper table,
+div.DTFC_RightFootWrapper table,
+table.DTFC_Cloned tr.even {
+ background-color: white;
+ margin-bottom: 0;
+}
+
+div.DTFC_RightHeadWrapper table ,
+div.DTFC_LeftHeadWrapper table {
+ border-bottom: none !important;
+ margin-bottom: 0 !important;
+ border-top-right-radius: 0 !important;
+ border-bottom-left-radius: 0 !important;
+ border-bottom-right-radius: 0 !important;
+}
+
+div.DTFC_RightHeadWrapper table thead tr:last-child th:first-child,
+div.DTFC_RightHeadWrapper table thead tr:last-child td:first-child,
+div.DTFC_LeftHeadWrapper table thead tr:last-child th:first-child,
+div.DTFC_LeftHeadWrapper table thead tr:last-child td:first-child {
+ border-bottom-left-radius: 0 !important;
+ border-bottom-right-radius: 0 !important;
+}
+
+div.DTFC_RightBodyWrapper table,
+div.DTFC_LeftBodyWrapper table {
+ border-top: none;
+ margin: 0 !important;
+}
+
+div.DTFC_RightBodyWrapper tbody tr:first-child th,
+div.DTFC_RightBodyWrapper tbody tr:first-child td,
+div.DTFC_LeftBodyWrapper tbody tr:first-child th,
+div.DTFC_LeftBodyWrapper tbody tr:first-child td {
+ border-top: none;
+}
+
+div.DTFC_RightFootWrapper table,
+div.DTFC_LeftFootWrapper table {
+ border-top: none;
+ margin-top: 0 !important;
+}
+
+
+/*
+ * FixedHeader styles
+ */
+div.FixedHeader_Cloned table {
+ margin: 0 !important
+}
+
diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js
new file mode 100644
index 000000000000..f0d09b9d5266
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js
@@ -0,0 +1,8 @@
+/*!
+ DataTables Bootstrap 3 integration
+ ©2011-2014 SpryMedia Ltd - datatables.net/license
+*/
+(function(){var f=function(c,b){c.extend(!0,b.defaults,{dom:"<'row'<'col-sm-6'l><'col-sm-6'f>><'row'<'col-sm-12'tr>><'row'<'col-sm-6'i><'col-sm-6'p>>",renderer:"bootstrap"});c.extend(b.ext.classes,{sWrapper:"dataTables_wrapper form-inline dt-bootstrap",sFilterInput:"form-control input-sm",sLengthSelect:"form-control input-sm"});b.ext.renderer.pageButton.bootstrap=function(g,f,p,k,h,l){var q=new b.Api(g),r=g.oClasses,i=g.oLanguage.oPaginate,d,e,o=function(b,f){var j,m,n,a,k=function(a){a.preventDefault();
+c(a.currentTarget).hasClass("disabled")||q.page(a.data.action).draw(!1)};j=0;for(m=f.length;j",{"class":r.sPageButton+" "+
+e,"aria-controls":g.sTableId,tabindex:g.iTabIndex,id:0===p&&"string"===typeof a?g.sTableId+"_"+a:null}).append(c("",{href:"#"}).html(d)).appendTo(b),g.oApi._fnBindAction(n,{action:a},k))}};o(c(f).empty().html('
').children("ul"),k)};b.TableTools&&(c.extend(!0,b.TableTools.classes,{container:"DTTT btn-group",buttons:{normal:"btn btn-default",disabled:"disabled"},collection:{container:"DTTT_dropdown dropdown-menu",buttons:{normal:"",disabled:"disabled"}},print:{info:"DTTT_print_info"},
+select:{row:"active"}}),c.extend(!0,b.TableTools.DEFAULTS.oTags,{collection:{container:"ul",button:"li",liner:"a"}}))};"function"===typeof define&&define.amd?define(["jquery","datatables"],f):"object"===typeof exports?f(require("jquery"),require("datatables")):jQuery&&f(jQuery,jQuery.fn.dataTable)})(window,document);
diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js
new file mode 100644
index 000000000000..983c3a564fb1
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.rowsGroup.js
@@ -0,0 +1,224 @@
+/*! RowsGroup for DataTables v1.0.0
+ * 2015 Alexey Shildyakov ashl1future@gmail.com
+ */
+
+/**
+ * @summary RowsGroup
+ * @description Group rows by specified columns
+ * @version 1.0.0
+ * @file dataTables.rowsGroup.js
+ * @author Alexey Shildyakov (ashl1future@gmail.com)
+ * @contact ashl1future@gmail.com
+ * @copyright Alexey Shildyakov
+ *
+ * License MIT - http://datatables.net/license/mit
+ *
+ * This feature plug-in for DataTables automatically merges column cells
+ * whose values are equal. It supports multi-column row grouping, applied in
+ * the requested order, with each grouping depending on the previously
+ * requested columns. Ordering and searching are also supported.
+ * Please see the example.html for details.
+ *
+ * Rows grouping in DataTables can be enabled by using any one of the following
+ * options (a brief usage sketch follows this header comment):
+ *
+ * * Setting the `rowsGroup` parameter in the DataTables initialisation
+ * to an array of column selectors
+ * (https://datatables.net/reference/type/column-selector) used for grouping, i.e.
+ * rowsGroup = [1, 'columnName:name']
+ * * Setting the `rowsGroup` parameter in the DataTables defaults
+ * (thus causing all tables to have this feature) - i.e.
+ * `$.fn.dataTable.defaults.RowsGroup = [0]`.
+ * * Creating a new instance: `new $.fn.dataTable.RowsGroup( table, columnsForGrouping );`
+ * where `table` is a DataTable's API instance and `columnsForGrouping` is the array
+ * described above.
+ *
+ * For more detailed information please see:
+ *
+ */
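+
+/*
+ * Usage sketch (editor's illustration, not part of the upstream plug-in). The
+ * table id, column definitions and data fields below are hypothetical; the
+ * snippet shows option 1 from the list above:
+ *
+ *   // Group rows by the first column, then by the column named "executor":
+ *   var table = $('#executors-table').DataTable({
+ *     columns: [
+ *       { data: 'host' },
+ *       { data: 'executor', name: 'executor' },
+ *       { data: 'duration' }
+ *     ],
+ *     rowsGroup: [0, 'executor:name']
+ *   });
+ */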
+
+(function($){
+
+var ShowedDataSelectorModifier = {
+ order: 'current',
+ page: 'current',
+ search: 'applied',
+}
+
+var GroupedColumnsOrderDir = 'desc'; // default direction applied to grouped columns that have no explicit order
+
+
+/*
+ * columnsForGrouping: array of DataTables column-selectors identifying the columns on which rows grouping is applied
+ */
+var RowsGroup = function ( dt, columnsForGrouping )
+{
+ this.table = dt.table();
+ this.columnsForGrouping = columnsForGrouping;
+ // set to true when a new ordering is applied by RowsGroup itself, to prevent order() looping
+ this.orderOverrideNow = false;
+ this.order = [];
+
+ var self = this;
+ $(document).on('order.dt', function ( e, settings) {
+ if (!self.orderOverrideNow) {
+ self._updateOrderAndDraw()
+ }
+ self.orderOverrideNow = false;
+ })
+
+ $(document).on('draw.dt', function ( e, settings) {
+ self._mergeCells()
+ })
+
+ this._updateOrderAndDraw();
+};
+
+
+RowsGroup.prototype = {
+ _getOrderWithGroupColumns: function (order, groupedColumnsOrderDir)
+ {
+ if (groupedColumnsOrderDir === undefined)
+ groupedColumnsOrderDir = GroupedColumnsOrderDir
+
+ var self = this;
+ var groupedColumnsIndexes = this.columnsForGrouping.map(function(columnSelector){
+ return self.table.column(columnSelector).index()
+ })
+ var groupedColumnsKnownOrder = order.filter(function(columnOrder){
+ return groupedColumnsIndexes.indexOf(columnOrder[0]) >= 0
+ })
+ var nongroupedColumnsOrder = order.filter(function(columnOrder){
+ return groupedColumnsIndexes.indexOf(columnOrder[0]) < 0
+ })
+ var groupedColumnsKnownOrderIndexes = groupedColumnsKnownOrder.map(function(columnOrder){
+ return columnOrder[0]
+ })
+ var groupedColumnsOrder = groupedColumnsIndexes.map(function(iColumn){
+ var iInOrderIndexes = groupedColumnsKnownOrderIndexes.indexOf(iColumn)
+ if (iInOrderIndexes >= 0)
+ return [iColumn, groupedColumnsKnownOrder[iInOrderIndexes][1]]
+ else
+ return [iColumn, groupedColumnsOrderDir]
+ })
+
+ groupedColumnsOrder.push.apply(groupedColumnsOrder, nongroupedColumnsOrder)
+ return groupedColumnsOrder;
+ },
+
+ // Workaround: DataTables resets the ordering to 'desc' (dropping multi-ordering) when the user orders on a single column (without Shift),
+ // and because grouped rows force multi-ordering to always be active, this would otherwise happen on every such click
+ _getInjectedMonoSelectWorkaround: function(order)
+ {
+ if (order.length === 1) {
+ // got mono order - workaround here
+ var orderingColumn = order[0][0]
+ var previousOrder = this.order.map(function(val){
+ return val[0]
+ })
+ var iColumn = previousOrder.indexOf(orderingColumn);
+ if (iColumn >= 0) {
+ // assume the direction should be toggled, because this column was already present in the previous order
+ return [[orderingColumn, this._toogleDirection(this.order[iColumn][1])]];
+ } // else: this is a new ordering column - proceed as is
+ } // else: got a multi-column order - behave normally
+ return order;
+ },
+
+ _mergeCells: function()
+ {
+ var columnsIndexes = this.table.columns(this.columnsForGrouping, ShowedDataSelectorModifier).indexes().toArray()
+ var showedRowsCount = this.table.rows(ShowedDataSelectorModifier)[0].length
+ this._mergeColumn(0, showedRowsCount - 1, columnsIndexes)
+ },
+
+ // Row indexes are relative to the currently displayed data, i.e. to the
+ // selector-modifier {order: 'current', page: 'current', search: 'applied'}
+ _mergeColumn: function(iStartRow, iFinishRow, columnsIndexes)
+ {
+ var columnsIndexesCopy = columnsIndexes.slice()
+ var currentColumn = columnsIndexesCopy.shift();
+ currentColumn = this.table.column(currentColumn, ShowedDataSelectorModifier)
+
+ var columnNodes = currentColumn.nodes()
+ var columnValues = currentColumn.data()
+
+ var newSequenceRow = iStartRow,
+ iRow;
+ for (iRow = iStartRow + 1; iRow <= iFinishRow; ++iRow) {
+
+ if (columnValues[iRow] === columnValues[newSequenceRow]) {
+ $(columnNodes[iRow]).hide()
+ } else {
+ $(columnNodes[newSequenceRow]).show()
+ $(columnNodes[newSequenceRow]).attr('rowspan', (iRow-1) - newSequenceRow + 1)
+
+ if (columnsIndexesCopy.length > 0)
+ this._mergeColumn(newSequenceRow, (iRow-1), columnsIndexesCopy)
+
+ newSequenceRow = iRow;
+ }
+
+ }
+ $(columnNodes[newSequenceRow]).show()
+ $(columnNodes[newSequenceRow]).attr('rowspan', (iRow-1)- newSequenceRow + 1)
+ if (columnsIndexesCopy.length > 0)
+ this._mergeColumn(newSequenceRow, (iRow-1), columnsIndexesCopy)
+ },
+
+ _toogleDirection: function(dir)
+ {
+ return dir == 'asc'? 'desc': 'asc';
+ },
+
+ _updateOrderAndDraw: function()
+ {
+ this.orderOverrideNow = true;
+
+ var currentOrder = this.table.order();
+ currentOrder = this._getInjectedMonoSelectWorkaround(currentOrder);
+ this.order = this._getOrderWithGroupColumns(currentOrder)
+ // this.table.order($.extend(true, Array(), this.order)) // disable this line in order to support sorting on non-grouped columns
+ this.table.draw(false)
+ },
+};
+
+
+$.fn.dataTable.RowsGroup = RowsGroup;
+$.fn.DataTable.RowsGroup = RowsGroup;
+
+// Automatic initialisation listener
+$(document).on( 'init.dt', function ( e, settings ) {
+ if ( e.namespace !== 'dt' ) {
+ return;
+ }
+
+ var api = new $.fn.dataTable.Api( settings );
+
+ if ( settings.oInit.rowsGroup ||
+ $.fn.dataTable.defaults.rowsGroup )
+ {
+ var options = settings.oInit.rowsGroup?
+ settings.oInit.rowsGroup:
+ $.fn.dataTable.defaults.rowsGroup;
+ new RowsGroup( api, options );
+ }
+} );
+
+}(jQuery));
+
+/*
+
+TODO: Provide a function which determines whether the cells and rows carrying a "rowspan" html-attribute are the parent (grouped) ones for a specified cell or row. To use in selections, editing or hover styles.
+
+TODO: Feature
+Use saved order direction for grouped columns
+ Split the columns into grouped and ungrouped.
+
+ user = grouped+ungrouped
+ grouped = grouped
+ saved = grouped+ungrouped
+
+ For grouped columns, use the following order of precedence: user -> saved (because 'saved' includes 'grouped' after the first initialisation). This should be done by saving the order, as is done for 'groupedColumns'.
+ For ungrouped columns: use only the 'user' input ordering
+*/
\ No newline at end of file
diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html
new file mode 100644
index 000000000000..5c91304e49fd
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html
@@ -0,0 +1,126 @@
+
+
+
diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js
new file mode 100644
index 000000000000..d430d8c5fb35
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js
@@ -0,0 +1,603 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+var threadDumpEnabled = false;
+
+function setThreadDumpEnabled(val) {
+ threadDumpEnabled = val;
+}
+
+function getThreadDumpEnabled() {
+ return threadDumpEnabled;
+}
+
+function formatStatus(status, type) {
+ if (status) {
+ return "Active"
+ } else {
+ return "Dead"
+ }
+}
+
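+// Register a custom "title-numeric" DataTables sort type: cells that embed a number in a title attribute are sorted by that number.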
+jQuery.extend(jQuery.fn.dataTableExt.oSort, {
+ "title-numeric-pre": function (a) {
+ var x = a.match(/title="*(-?[0-9\.]+)/)[1];
+ return parseFloat(x);
+ },
+
+ "title-numeric-asc": function (a, b) {
+ return ((a < b) ? -1 : ((a > b) ? 1 : 0));
+ },
+
+ "title-numeric-desc": function (a, b) {
+ return ((a < b) ? 1 : ((a > b) ? -1 : 0));
+ }
+});
+
+$(document).ajaxStop($.unblockUI);
+$(document).ajaxStart(function () {
+ $.blockUI({message: 'Loading Executors Page...'});
+});
+
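+// Resolve the URL of the executors page template, depending on whether the UI is served behind a proxy, by the history server, or standalone.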
+function createTemplateURI(appId) {
+ var words = document.baseURI.split('/');
+ var ind = words.indexOf("proxy");
+ if (ind > 0) {
+ var baseURI = words.slice(0, ind + 1).join('/') + '/' + appId + '/static/executorspage-template.html';
+ return baseURI;
+ }
+ ind = words.indexOf("history");
+ if(ind > 0) {
+ var baseURI = words.slice(0, ind).join('/') + '/static/executorspage-template.html';
+ return baseURI;
+ }
+ return location.origin + "/static/executorspage-template.html";
+}
+
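+// Determine the application id from the URL in proxy or history server mode, or from the applications REST endpoint in standalone mode, and pass it to the callback.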
+function getStandAloneppId(cb) {
+ var words = document.baseURI.split('/');
+ var ind = words.indexOf("proxy");
+ if (ind > 0) {
+ var appId = words[ind + 1];
+ cb(appId);
+ return;
+ }
+ ind = words.indexOf("history");
+ if (ind > 0) {
+ var appId = words[ind + 1];
+ cb(appId);
+ return;
+ }
+ // Looks like the Web UI is running in standalone mode
+ // Get the application id from the applications REST endpoint
+ $.getJSON(location.origin + "/api/v1/applications", function(response, status, jqXHR) {
+ if (response && response.length > 0) {
+ var appId = response[0].id
+ cb(appId);
+ return;
+ }
+ });
+}
+
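+// Build the allexecutors REST endpoint URL for proxy, history server (with an optional attempt id), or standalone deployments.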
+function createRESTEndPoint(appId) {
+ var words = document.baseURI.split('/');
+ var ind = words.indexOf("proxy");
+ if (ind > 0) {
+ var appId = words[ind + 1];
+ var newBaseURI = words.slice(0, ind + 2).join('/');
+ return newBaseURI + "/api/v1/applications/" + appId + "/allexecutors"
+ }
+ ind = words.indexOf("history");
+ if (ind > 0) {
+ var appId = words[ind + 1];
+ var attemptId = words[ind + 2];
+ var newBaseURI = words.slice(0, ind).join('/');
+ if (isNaN(attemptId)) {
+ return newBaseURI + "/api/v1/applications/" + appId + "/allexecutors";
+ } else {
+ return newBaseURI + "/api/v1/applications/" + appId + "/" + attemptId + "/allexecutors";
+ }
+ }
+ return location.origin + "/api/v1/applications/" + appId + "/allexecutors";
+}
+
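+// For display, render one link per executor log; for other render types return the log names so sorting and filtering still work.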
+function formatLogsCells(execLogs, type) {
+ if (type !== 'display') return Object.keys(execLogs);
+ if (!execLogs) return;
+ var result = '';
+ $.each(execLogs, function (logName, logUrl) {
+ result += '<div><a href="' + logUrl + '">' + logName + '</a></div>'
+ });
+ return result;
+}
+
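+// True if any executor reports at least one log URL; used to decide whether to show the logs column.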
+function logsExist(execs) {
+ return execs.some(function(exec) {
+ return !($.isEmptyObject(exec["executorLogs"]));
+ });
+}
+
+// Determine color opacity, in the range 0.5-1
+// activeTasks ranges from 0 to maxTasks
+function activeTasksAlpha(activeTasks, maxTasks) {
+ return maxTasks > 0 ? ((activeTasks / maxTasks) * 0.5 + 0.5) : 1;
+}
+
+function activeTasksStyle(activeTasks, maxTasks) {
+ return activeTasks > 0 ? ("hsla(240, 100%, 50%, " + activeTasksAlpha(activeTasks, maxTasks) + ")") : "";
+}
+
+// failedTasks alpha saturates at a 10% failure rate (alpha max = 1)
+function failedTasksAlpha(failedTasks, totalTasks) {
+ return totalTasks > 0 ?
+ (Math.min(10 * failedTasks / totalTasks, 1) * 0.5 + 0.5) : 1;
+}
+
+function failedTasksStyle(failedTasks, totalTasks) {
+ return failedTasks > 0 ?
+ ("hsla(0, 100%, 50%, " + failedTasksAlpha(failedTasks, totalTasks) + ")") : "";
+}
+
+// alpha scales with the GC fraction of total duration, saturating at 1 when GC time reaches 50%
+function totalDurationAlpha(totalGCTime, totalDuration) {
+ return totalDuration > 0 ?
+ (Math.min(totalGCTime / totalDuration + 0.5, 1)) : 1;
+}
+
+// When GCTimePercent is edited change ToolTips.TASK_TIME to match
+var GCTimePercent = 0.1;
+
+function totalDurationStyle(totalGCTime, totalDuration) {
+ // Red if GC time over GCTimePercent of total time
+ return (totalGCTime > GCTimePercent * totalDuration) ?
+ ("hsla(0, 100%, 50%, " + totalDurationAlpha(totalGCTime, totalDuration) + ")") : "";
+}
+
+function totalDurationColor(totalGCTime, totalDuration) {
+ return (totalGCTime > GCTimePercent * totalDuration) ? "white" : "black";
+}
+
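+// On page load: fetch executor data from the REST API, aggregate totals for active, dead and all executors, and render the summary and per-executor tables.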
+$(document).ready(function () {
+ $.extend($.fn.dataTable.defaults, {
+ stateSave: true,
+ lengthMenu: [[20, 40, 60, 100, -1], [20, 40, 60, 100, "All"]],
+ pageLength: 20
+ });
+
+ var executorsSummary = $("#active-executors");
+
+ getStandAloneppId(function (appId) {
+
+ var endPoint = createRESTEndPoint(appId);
+ $.getJSON(endPoint, function (response, status, jqXHR) {
+ var summary = [];
+ var allExecCnt = 0;
+ var allRDDBlocks = 0;
+ var allMemoryUsed = 0;
+ var allMaxMemory = 0;
+ var allOnHeapMemoryUsed = 0;
+ var allOnHeapMaxMemory = 0;
+ var allOffHeapMemoryUsed = 0;
+ var allOffHeapMaxMemory = 0;
+ var allDiskUsed = 0;
+ var allTotalCores = 0;
+ var allMaxTasks = 0;
+ var allActiveTasks = 0;
+ var allFailedTasks = 0;
+ var allCompletedTasks = 0;
+ var allTotalTasks = 0;
+ var allTotalDuration = 0;
+ var allTotalGCTime = 0;
+ var allTotalInputBytes = 0;
+ var allTotalShuffleRead = 0;
+ var allTotalShuffleWrite = 0;
+ var allTotalBlacklisted = 0;
+
+ var activeExecCnt = 0;
+ var activeRDDBlocks = 0;
+ var activeMemoryUsed = 0;
+ var activeMaxMemory = 0;
+ var activeOnHeapMemoryUsed = 0;
+ var activeOnHeapMaxMemory = 0;
+ var activeOffHeapMemoryUsed = 0;
+ var activeOffHeapMaxMemory = 0;
+ var activeDiskUsed = 0;
+ var activeTotalCores = 0;
+ var activeMaxTasks = 0;
+ var activeActiveTasks = 0;
+ var activeFailedTasks = 0;
+ var activeCompletedTasks = 0;
+ var activeTotalTasks = 0;
+ var activeTotalDuration = 0;
+ var activeTotalGCTime = 0;
+ var activeTotalInputBytes = 0;
+ var activeTotalShuffleRead = 0;
+ var activeTotalShuffleWrite = 0;
+ var activeTotalBlacklisted = 0;
+
+ var deadExecCnt = 0;
+ var deadRDDBlocks = 0;
+ var deadMemoryUsed = 0;
+ var deadMaxMemory = 0;
+ var deadOnHeapMemoryUsed = 0;
+ var deadOnHeapMaxMemory = 0;
+ var deadOffHeapMemoryUsed = 0;
+ var deadOffHeapMaxMemory = 0;
+ var deadDiskUsed = 0;
+ var deadTotalCores = 0;
+ var deadMaxTasks = 0;
+ var deadActiveTasks = 0;
+ var deadFailedTasks = 0;
+ var deadCompletedTasks = 0;
+ var deadTotalTasks = 0;
+ var deadTotalDuration = 0;
+ var deadTotalGCTime = 0;
+ var deadTotalInputBytes = 0;
+ var deadTotalShuffleRead = 0;
+ var deadTotalShuffleWrite = 0;
+ var deadTotalBlacklisted = 0;
+
+ response.forEach(function (exec) {
+ var memoryMetrics = {
+ usedOnHeapStorageMemory: 0,
+ usedOffHeapStorageMemory: 0,
+ totalOnHeapStorageMemory: 0,
+ totalOffHeapStorageMemory: 0
+ };
+
+ exec.memoryMetrics = exec.hasOwnProperty('memoryMetrics') ? exec.memoryMetrics : memoryMetrics;
+ });
+
+ response.forEach(function (exec) {
+ allExecCnt += 1;
+ allRDDBlocks += exec.rddBlocks;
+ allMemoryUsed += exec.memoryUsed;
+ allMaxMemory += exec.maxMemory;
+ allOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory;
+ allOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory;
+ allOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory;
+ allOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory;
+ allDiskUsed += exec.diskUsed;
+ allTotalCores += exec.totalCores;
+ allMaxTasks += exec.maxTasks;
+ allActiveTasks += exec.activeTasks;
+ allFailedTasks += exec.failedTasks;
+ allCompletedTasks += exec.completedTasks;
+ allTotalTasks += exec.totalTasks;
+ allTotalDuration += exec.totalDuration;
+ allTotalGCTime += exec.totalGCTime;
+ allTotalInputBytes += exec.totalInputBytes;
+ allTotalShuffleRead += exec.totalShuffleRead;
+ allTotalShuffleWrite += exec.totalShuffleWrite;
+ allTotalBlacklisted += exec.isBlacklisted ? 1 : 0;
+ if (exec.isActive) {
+ activeExecCnt += 1;
+ activeRDDBlocks += exec.rddBlocks;
+ activeMemoryUsed += exec.memoryUsed;
+ activeMaxMemory += exec.maxMemory;
+ activeOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory;
+ activeOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory;
+ activeOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory;
+ activeOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory;
+ activeDiskUsed += exec.diskUsed;
+ activeTotalCores += exec.totalCores;
+ activeMaxTasks += exec.maxTasks;
+ activeActiveTasks += exec.activeTasks;
+ activeFailedTasks += exec.failedTasks;
+ activeCompletedTasks += exec.completedTasks;
+ activeTotalTasks += exec.totalTasks;
+ activeTotalDuration += exec.totalDuration;
+ activeTotalGCTime += exec.totalGCTime;
+ activeTotalInputBytes += exec.totalInputBytes;
+ activeTotalShuffleRead += exec.totalShuffleRead;
+ activeTotalShuffleWrite += exec.totalShuffleWrite;
+ activeTotalBlacklisted += exec.isBlacklisted ? 1 : 0;
+ } else {
+ deadExecCnt += 1;
+ deadRDDBlocks += exec.rddBlocks;
+ deadMemoryUsed += exec.memoryUsed;
+ deadMaxMemory += exec.maxMemory;
+ deadOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory;
+ deadOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory;
+ deadOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory;
+ deadOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory;
+ deadDiskUsed += exec.diskUsed;
+ deadTotalCores += exec.totalCores;
+ deadMaxTasks += exec.maxTasks;
+ deadActiveTasks += exec.activeTasks;
+ deadFailedTasks += exec.failedTasks;
+ deadCompletedTasks += exec.completedTasks;
+ deadTotalTasks += exec.totalTasks;
+ deadTotalDuration += exec.totalDuration;
+ deadTotalGCTime += exec.totalGCTime;
+ deadTotalInputBytes += exec.totalInputBytes;
+ deadTotalShuffleRead += exec.totalShuffleRead;
+ deadTotalShuffleWrite += exec.totalShuffleWrite;
+ deadTotalBlacklisted += exec.isBlacklisted ? 1 : 0;
+ }
+ });
+
+ var totalSummary = {
+ "execCnt": ( "Total(" + allExecCnt + ")"),
+ "allRDDBlocks": allRDDBlocks,
+ "allMemoryUsed": allMemoryUsed,
+ "allMaxMemory": allMaxMemory,
+ "allOnHeapMemoryUsed": allOnHeapMemoryUsed,
+ "allOnHeapMaxMemory": allOnHeapMaxMemory,
+ "allOffHeapMemoryUsed": allOffHeapMemoryUsed,
+ "allOffHeapMaxMemory": allOffHeapMaxMemory,
+ "allDiskUsed": allDiskUsed,
+ "allTotalCores": allTotalCores,
+ "allMaxTasks": allMaxTasks,
+ "allActiveTasks": allActiveTasks,
+ "allFailedTasks": allFailedTasks,
+ "allCompletedTasks": allCompletedTasks,
+ "allTotalTasks": allTotalTasks,
+ "allTotalDuration": allTotalDuration,
+ "allTotalGCTime": allTotalGCTime,
+ "allTotalInputBytes": allTotalInputBytes,
+ "allTotalShuffleRead": allTotalShuffleRead,
+ "allTotalShuffleWrite": allTotalShuffleWrite,
+ "allTotalBlacklisted": allTotalBlacklisted
+ };
+ var activeSummary = {
+ "execCnt": ( "Active(" + activeExecCnt + ")"),
+ "allRDDBlocks": activeRDDBlocks,
+ "allMemoryUsed": activeMemoryUsed,
+ "allMaxMemory": activeMaxMemory,
+ "allOnHeapMemoryUsed": activeOnHeapMemoryUsed,
+ "allOnHeapMaxMemory": activeOnHeapMaxMemory,
+ "allOffHeapMemoryUsed": activeOffHeapMemoryUsed,
+ "allOffHeapMaxMemory": activeOffHeapMaxMemory,
+ "allDiskUsed": activeDiskUsed,
+ "allTotalCores": activeTotalCores,
+ "allMaxTasks": activeMaxTasks,
+ "allActiveTasks": activeActiveTasks,
+ "allFailedTasks": activeFailedTasks,
+ "allCompletedTasks": activeCompletedTasks,
+ "allTotalTasks": activeTotalTasks,
+ "allTotalDuration": activeTotalDuration,
+ "allTotalGCTime": activeTotalGCTime,
+ "allTotalInputBytes": activeTotalInputBytes,
+ "allTotalShuffleRead": activeTotalShuffleRead,
+ "allTotalShuffleWrite": activeTotalShuffleWrite,
+ "allTotalBlacklisted": activeTotalBlacklisted
+ };
+ var deadSummary = {
+ "execCnt": ( "Dead(" + deadExecCnt + ")" ),
+ "allRDDBlocks": deadRDDBlocks,
+ "allMemoryUsed": deadMemoryUsed,
+ "allMaxMemory": deadMaxMemory,
+ "allOnHeapMemoryUsed": deadOnHeapMemoryUsed,
+ "allOnHeapMaxMemory": deadOnHeapMaxMemory,
+ "allOffHeapMemoryUsed": deadOffHeapMemoryUsed,
+ "allOffHeapMaxMemory": deadOffHeapMaxMemory,
+ "allDiskUsed": deadDiskUsed,
+ "allTotalCores": deadTotalCores,
+ "allMaxTasks": deadMaxTasks,
+ "allActiveTasks": deadActiveTasks,
+ "allFailedTasks": deadFailedTasks,
+ "allCompletedTasks": deadCompletedTasks,
+ "allTotalTasks": deadTotalTasks,
+ "allTotalDuration": deadTotalDuration,
+ "allTotalGCTime": deadTotalGCTime,
+ "allTotalInputBytes": deadTotalInputBytes,
+ "allTotalShuffleRead": deadTotalShuffleRead,
+ "allTotalShuffleWrite": deadTotalShuffleWrite,
+ "allTotalBlacklisted": deadTotalBlacklisted
+ };
+
+ var data = {executors: response, "execSummary": [activeSummary, deadSummary, totalSummary]};
+ $.get(createTemplateURI(appId), function (template) {
+
+ executorsSummary.append(Mustache.render($(template).filter("#executors-summary-template").html(), data));
+ var selector = "#active-executors-table";
+ var conf = {
+ "data": response,
+ "columns": [
+ {
+ data: function (row, type) {
+ return type !== 'display' ? (isNaN(row.id) ? 0 : row.id ) : row.id;
+ }
+ },
+ {data: 'hostPort'},
+ {data: 'isActive', render: function (data, type, row) {
+ if (row.isBlacklisted) return "Blacklisted";
+ else return formatStatus (data, type);
+ }
+ },
+ {data: 'rddBlocks'},
+ {
+ data: function (row, type) {
+ if (type !== 'display')
+ return row.memoryUsed;
+ else
+ return (formatBytes(row.memoryUsed, type) + ' / ' +
+ formatBytes(row.maxMemory, type));
+ }
+ },
+ {
+ data: function (row, type) {
+ if (type !== 'display')
+ return row.memoryMetrics.usedOnHeapStorageMemory;
+ else
+ return (formatBytes(row.memoryMetrics.usedOnHeapStorageMemory, type) + ' / ' +
+ formatBytes(row.memoryMetrics.totalOnHeapStorageMemory, type));
+ },
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ $(nTd).addClass('on_heap_memory')
+ }
+ },
+ {
+ data: function (row, type) {
+ if (type !== 'display')
+ return row.memoryMetrics.usedOffHeapStorageMemory;
+ else
+ return (formatBytes(row.memoryMetrics.usedOffHeapStorageMemory, type) + ' / ' +
+ formatBytes(row.memoryMetrics.totalOffHeapStorageMemory, type));
+ },
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ $(nTd).addClass('off_heap_memory')
+ }
+ },
+ {data: 'diskUsed', render: formatBytes},
+ {data: 'totalCores'},
+ {
+ data: 'activeTasks',
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ if (sData > 0) {
+ $(nTd).css('color', 'white');
+ $(nTd).css('background', activeTasksStyle(oData.activeTasks, oData.maxTasks));
+ }
+ }
+ },
+ {
+ data: 'failedTasks',
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ if (sData > 0) {
+ $(nTd).css('color', 'white');
+ $(nTd).css('background', failedTasksStyle(oData.failedTasks, oData.totalTasks));
+ }
+ }
+ },
+ {data: 'completedTasks'},
+ {data: 'totalTasks'},
+ {
+ data: function (row, type) {
+ return type === 'display' ? (formatDuration(row.totalDuration) + ' (' + formatDuration(row.totalGCTime) + ')') : row.totalDuration
+ },
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ if (oData.totalDuration > 0) {
+ $(nTd).css('color', totalDurationColor(oData.totalGCTime, oData.totalDuration));
+ $(nTd).css('background', totalDurationStyle(oData.totalGCTime, oData.totalDuration));
+ }
+ }
+ },
+ {data: 'totalInputBytes', render: formatBytes},
+ {data: 'totalShuffleRead', render: formatBytes},
+ {data: 'totalShuffleWrite', render: formatBytes},
+ {name: 'executorLogsCol', data: 'executorLogs', render: formatLogsCells},
+ {
+ name: 'threadDumpCol',
+ data: 'id', render: function (data, type) {
+ return type === 'display' ? ("Thread Dump" ) : data;
+ }
+ }
+ ],
+ "order": [[0, "asc"]]
+ };
+
+ var dt = $(selector).DataTable(conf);
+ dt.column('executorLogsCol:name').visible(logsExist(response));
+ dt.column('threadDumpCol:name').visible(getThreadDumpEnabled());
+ $('#active-executors [data-toggle="tooltip"]').tooltip();
+
+ var sumSelector = "#summary-execs-table";
+ var sumConf = {
+ "data": [activeSummary, deadSummary, totalSummary],
+ "columns": [
+ {
+ data: 'execCnt',
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ $(nTd).css('font-weight', 'bold');
+ }
+ },
+ {data: 'allRDDBlocks'},
+ {
+ data: function (row, type) {
+ if (type !== 'display')
+ return row.allMemoryUsed
+ else
+ return (formatBytes(row.allMemoryUsed, type) + ' / ' +
+ formatBytes(row.allMaxMemory, type));
+ }
+ },
+ {
+ data: function (row, type) {
+ if (type !== 'display')
+ return row.allOnHeapMemoryUsed;
+ else
+ return (formatBytes(row.allOnHeapMemoryUsed, type) + ' / ' +
+ formatBytes(row.allOnHeapMaxMemory, type));
+ },
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ $(nTd).addClass('on_heap_memory')
+ }
+ },
+ {
+ data: function (row, type) {
+ if (type !== 'display')
+ return row.allOffHeapMemoryUsed;
+ else
+ return (formatBytes(row.allOffHeapMemoryUsed, type) + ' / ' +
+ formatBytes(row.allOffHeapMaxMemory, type));
+ },
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ $(nTd).addClass('off_heap_memory')
+ }
+ },
+ {data: 'allDiskUsed', render: formatBytes},
+ {data: 'allTotalCores'},
+ {
+ data: 'allActiveTasks',
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ if (sData > 0) {
+ $(nTd).css('color', 'white');
+ $(nTd).css('background', activeTasksStyle(oData.allActiveTasks, oData.allMaxTasks));
+ }
+ }
+ },
+ {
+ data: 'allFailedTasks',
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ if (sData > 0) {
+ $(nTd).css('color', 'white');
+ $(nTd).css('background', failedTasksStyle(oData.allFailedTasks, oData.allTotalTasks));
+ }
+ }
+ },
+ {data: 'allCompletedTasks'},
+ {data: 'allTotalTasks'},
+ {
+ data: function (row, type) {
+ return type === 'display' ? (formatDuration(row.allTotalDuration, type) + ' (' + formatDuration(row.allTotalGCTime, type) + ')') : row.allTotalDuration
+ },
+ "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) {
+ if (oData.allTotalDuration > 0) {
+ $(nTd).css('color', totalDurationColor(oData.allTotalGCTime, oData.allTotalDuration));
+ $(nTd).css('background', totalDurationStyle(oData.allTotalGCTime, oData.allTotalDuration));
+ }
+ }
+ },
+ {data: 'allTotalInputBytes', render: formatBytes},
+ {data: 'allTotalShuffleRead', render: formatBytes},
+ {data: 'allTotalShuffleWrite', render: formatBytes},
+ {data: 'allTotalBlacklisted'}
+ ],
+ "paging": false,
+ "searching": false,
+ "info": false
+
+ };
+
+ $(sumSelector).DataTable(sumConf);
+ $('#execSummary [data-toggle="tooltip"]').tooltip();
+
+ });
+ });
+ });
+});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js
new file mode 100644
index 000000000000..55d540d8317a
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
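+// Replace the epoch timestamp in the #last-updated element with a locale-formatted date and time.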
+$(document).ready(function() {
+ if ($('#last-updated').length) {
+ var lastUpdatedMillis = Number($('#last-updated').text());
+ var updatedDate = new Date(lastUpdatedMillis);
+ $('#last-updated').text(updatedDate.toLocaleDateString()+", "+updatedDate.toLocaleTimeString())
+ }
+});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
new file mode 100644
index 000000000000..20cd7bfdb223
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
@@ -0,0 +1,94 @@
+
+
+
diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
new file mode 100644
index 000000000000..3e2bba8a8941
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+var appLimit = -1;
+
+function setAppLimit(val) {
+ appLimit = val;
+}
+
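+// Zero-pad the application sequence number so that application ids compare correctly as strings.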
+function makeIdNumeric(id) {
+ var strs = id.split("_");
+ if (strs.length < 3) {
+ return id;
+ }
+ var appSeqNum = strs[2];
+ var resl = strs[0] + "_" + strs[1] + "_";
+ var diff = 10 - appSeqNum.length;
+ while (diff > 0) {
+ resl += "0"; // padding 0 before the app sequence number to make sure it has 10 characters
+ diff--;
+ }
+ resl += appSeqNum;
+ return resl;
+}
+
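+// Render an ISO-8601 timestamp as "YYYY-MM-DD HH:MM:SS", or "-" when the value is missing.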
+function formatDate(date) {
+ if (date <= 0) return "-";
+ else return date.split(".")[0].replace("T", " ");
+}
+
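+// Extract the value of a query-string parameter from the given search string.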
+function getParameterByName(name, searchString) {
+ var regex = new RegExp("[\\?&]" + name + "=([^&#]*)"),
+ results = regex.exec(searchString);
+ return results === null ? "" : decodeURIComponent(results[1].replace(/\+/g, " "));
+}
+
+function removeColumnByName(columns, columnName) {
+ return columns.filter(function(col) {return col.name != columnName})
+}
+
+function getColumnIndex(columns, columnName) {
+ for(var i = 0; i < columns.length; i++) {
+ if (columns[i].name == columnName)
+ return i;
+ }
+ return -1;
+}
+
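+// Custom "title-numeric" sort type, as in executorspage.js: sort by the number embedded in the cell's title attribute.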
+jQuery.extend( jQuery.fn.dataTableExt.oSort, {
+ "title-numeric-pre": function ( a ) {
+ var x = a.match(/title="*(-?[0-9\.]+)/)[1];
+ return parseFloat( x );
+ },
+
+ "title-numeric-asc": function ( a, b ) {
+ return ((a < b) ? -1 : ((a > b) ? 1 : 0));
+ },
+
+ "title-numeric-desc": function ( a, b ) {
+ return ((a < b) ? 1 : ((a > b) ? -1 : 0));
+ }
+} );
+
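+// Custom "appid-numeric" sort type: sort application id cells by the zero-padded id produced by makeIdNumeric.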
+jQuery.extend( jQuery.fn.dataTableExt.oSort, {
+ "appid-numeric-pre": function ( a ) {
+ var x = a.match(/title="*(-?[0-9a-zA-Z\-\_]+)/)[1];
+ return makeIdNumeric(x);
+ },
+
+ "appid-numeric-asc": function ( a, b ) {
+ return ((a < b) ? -1 : ((a > b) ? 1 : 0));
+ },
+
+ "appid-numeric-desc": function ( a, b ) {
+ return ((a < b) ? 1 : ((a > b) ? -1 : 0));
+ }
+} );
+
+jQuery.extend( jQuery.fn.dataTableExt.ofnSearch, {
+ "appid-numeric": function ( a ) {
+ return a.replace(/[\r\n]/g, " ").replace(/<.*?>/g, "");
+ }
+} );
+
+$(document).ajaxStop($.unblockUI);
+$(document).ajaxStart(function(){
+ $.blockUI({ message: 'Loading history summary...'});
+});
+
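+// On page load: fetch the application list from the REST API, flatten each attempt into its own row, and render the history summary table, grouping rows by application when any application has multiple attempts.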
+$(document).ready(function() {
+ $.extend( $.fn.dataTable.defaults, {
+ stateSave: true,
+ lengthMenu: [[20,40,60,100,-1], [20, 40, 60, 100, "All"]],
+ pageLength: 20
+ });
+
+ var historySummary = $("#history-summary");
+ var searchString = historySummary["context"]["location"]["search"];
+ var requestedIncomplete = getParameterByName("showIncomplete", searchString);
+ requestedIncomplete = (requestedIncomplete == "true" ? true : false);
+
+ $.getJSON("api/v1/applications?limit=" + appLimit, function(response,status,jqXHR) {
+ var array = [];
+ var hasMultipleAttempts = false;
+ for (i in response) {
+ var app = response[i];
+ if (app["attempts"][0]["completed"] == requestedIncomplete) {
+ continue; // if we want to show incomplete apps, skip the completed ones; otherwise skip the incomplete ones
+ }
+ var id = app["id"];
+ var name = app["name"];
+ if (app["attempts"].length > 1) {
+ hasMultipleAttempts = true;
+ }
+ var num = app["attempts"].length;
+ for (j in app["attempts"]) {
+ var attempt = app["attempts"][j];
+ attempt["startTime"] = formatDate(attempt["startTime"]);
+ attempt["endTime"] = formatDate(attempt["endTime"]);
+ attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]);
+ attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" +
+ (attempt.hasOwnProperty("attemptId") ? attempt["attemptId"] + "/" : "") + "logs";
+ attempt["durationMillisec"] = attempt["duration"];
+ attempt["duration"] = formatDuration(attempt["duration"]);
+ var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]};
+ array.push(app_clone);
+ }
+ }
+ if(array.length < 20) {
+ $.fn.dataTable.defaults.paging = false;
+ }
+
+ var data = {
+ "uiroot": uiRoot,
+ "applications": array,
+ "hasMultipleAttempts": hasMultipleAttempts,
+ "showCompletedColumn": !requestedIncomplete,
+ }
+
+ $.get("static/historypage-template.html", function(template) {
+ var sibling = historySummary.prev();
+ historySummary.detach();
+ var apps = $(Mustache.render($(template).filter("#history-summary-template").html(),data));
+ var attemptIdColumnName = 'attemptId';
+ var startedColumnName = 'started';
+ var completedColumnName = 'completed';
+ var durationColumnName = 'duration';
+ var conf = {
+ "columns": [
+ {name: 'appId', type: "appid-numeric"},
+ {name: 'appName'},
+ {name: attemptIdColumnName},
+ {name: startedColumnName},
+ {name: completedColumnName},
+ {name: durationColumnName, type: "title-numeric"},
+ {name: 'user'},
+ {name: 'lastUpdated'},
+ {name: 'eventLog'},
+ ],
+ "autoWidth": false,
+ };
+
+ if (hasMultipleAttempts) {
+ conf.rowsGroup = [
+ 'appId:name',
+ 'appName:name'
+ ];
+ } else {
+ conf.columns = removeColumnByName(conf.columns, attemptIdColumnName);
+ }
+
+ var defaultSortColumn = completedColumnName;
+ if (requestedIncomplete) {
+ defaultSortColumn = startedColumnName;
+ conf.columns = removeColumnByName(conf.columns, completedColumnName);
+ }
+ conf.order = [[ getColumnIndex(conf.columns, defaultSortColumn), "desc" ]];
+ conf.columnDefs = [
+ {"searchable": false, "targets": [getColumnIndex(conf.columns, durationColumnName)]}
+ ];
+ historySummary.append(apps);
+ apps.DataTable(conf);
+ sibling.after(historySummary);
+ $('#history-summary [data-toggle="tooltip"]').tooltip();
+ });
+ });
+});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js
new file mode 100644
index 000000000000..1e84b3ec21c4
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.blockUI.min.js
@@ -0,0 +1,6 @@
+/*
+* jQuery BlockUI; v20131009
+* http://jquery.malsup.com/block/
+* Copyright (c) 2013 M. Alsup; Dual licensed: MIT/GPL
+*/
+(function(){"use strict";function e(e){function o(o,i){var s,h,k=o==window,v=i&&void 0!==i.message?i.message:void 0;if(i=e.extend({},e.blockUI.defaults,i||{}),!i.ignoreIfBlocked||!e(o).data("blockUI.isBlocked")){if(i.overlayCSS=e.extend({},e.blockUI.defaults.overlayCSS,i.overlayCSS||{}),s=e.extend({},e.blockUI.defaults.css,i.css||{}),i.onOverlayClick&&(i.overlayCSS.cursor="pointer"),h=e.extend({},e.blockUI.defaults.themedCSS,i.themedCSS||{}),v=void 0===v?i.message:v,k&&b&&t(window,{fadeOut:0}),v&&"string"!=typeof v&&(v.parentNode||v.jquery)){var y=v.jquery?v[0]:v,m={};e(o).data("blockUI.history",m),m.el=y,m.parent=y.parentNode,m.display=y.style.display,m.position=y.style.position,m.parent&&m.parent.removeChild(y)}e(o).data("blockUI.onUnblock",i.onUnblock);var g,I,w,U,x=i.baseZ;g=r||i.forceIframe?e(''):e(''),I=i.theme?e(''):e(''),i.theme&&k?(U='"):i.theme?(U='"):U=k?'':'',w=e(U),v&&(i.theme?(w.css(h),w.addClass("ui-widget-content")):w.css(s)),i.theme||I.css(i.overlayCSS),I.css("position",k?"fixed":"absolute"),(r||i.forceIframe)&&g.css("opacity",0);var C=[g,I,w],S=k?e("body"):e(o);e.each(C,function(){this.appendTo(S)}),i.theme&&i.draggable&&e.fn.draggable&&w.draggable({handle:".ui-dialog-titlebar",cancel:"li"});var O=f&&(!e.support.boxModel||e("object,embed",k?null:o).length>0);if(u||O){if(k&&i.allowBodyStretch&&e.support.boxModel&&e("html,body").css("height","100%"),(u||!e.support.boxModel)&&!k)var E=d(o,"borderTopWidth"),T=d(o,"borderLeftWidth"),M=E?"(0 - "+E+")":0,B=T?"(0 - "+T+")":0;e.each(C,function(e,o){var t=o[0].style;if(t.position="absolute",2>e)k?t.setExpression("height","Math.max(document.body.scrollHeight, document.body.offsetHeight) - (jQuery.support.boxModel?0:"+i.quirksmodeOffsetHack+') + "px"'):t.setExpression("height",'this.parentNode.offsetHeight + "px"'),k?t.setExpression("width",'jQuery.support.boxModel && document.documentElement.clientWidth || document.body.clientWidth + "px"'):t.setExpression("width",'this.parentNode.offsetWidth + "px"'),B&&t.setExpression("left",B),M&&t.setExpression("top",M);else if(i.centerY)k&&t.setExpression("top",'(document.documentElement.clientHeight || document.body.clientHeight) / 2 - (this.offsetHeight / 2) + (blah = document.documentElement.scrollTop ? document.documentElement.scrollTop : document.body.scrollTop) + "px"'),t.marginTop=0;else if(!i.centerY&&k){var n=i.css&&i.css.top?parseInt(i.css.top,10):0,s="((document.documentElement.scrollTop ? 
document.documentElement.scrollTop : document.body.scrollTop) + "+n+') + "px"';t.setExpression("top",s)}})}if(v&&(i.theme?w.find(".ui-widget-content").append(v):w.append(v),(v.jquery||v.nodeType)&&e(v).show()),(r||i.forceIframe)&&i.showOverlay&&g.show(),i.fadeIn){var j=i.onBlock?i.onBlock:c,H=i.showOverlay&&!v?j:c,z=v?j:c;i.showOverlay&&I._fadeIn(i.fadeIn,H),v&&w._fadeIn(i.fadeIn,z)}else i.showOverlay&&I.show(),v&&w.show(),i.onBlock&&i.onBlock();if(n(1,o,i),k?(b=w[0],p=e(i.focusableElements,b),i.focusInput&&setTimeout(l,20)):a(w[0],i.centerX,i.centerY),i.timeout){var W=setTimeout(function(){k?e.unblockUI(i):e(o).unblock(i)},i.timeout);e(o).data("blockUI.timeout",W)}}}function t(o,t){var s,l=o==window,a=e(o),d=a.data("blockUI.history"),c=a.data("blockUI.timeout");c&&(clearTimeout(c),a.removeData("blockUI.timeout")),t=e.extend({},e.blockUI.defaults,t||{}),n(0,o,t),null===t.onUnblock&&(t.onUnblock=a.data("blockUI.onUnblock"),a.removeData("blockUI.onUnblock"));var r;r=l?e("body").children().filter(".blockUI").add("body > .blockUI"):a.find(">.blockUI"),t.cursorReset&&(r.length>1&&(r[1].style.cursor=t.cursorReset),r.length>2&&(r[2].style.cursor=t.cursorReset)),l&&(b=p=null),t.fadeOut?(s=r.length,r.stop().fadeOut(t.fadeOut,function(){0===--s&&i(r,d,t,o)})):i(r,d,t,o)}function i(o,t,i,n){var s=e(n);if(!s.data("blockUI.isBlocked")){o.each(function(){this.parentNode&&this.parentNode.removeChild(this)}),t&&t.el&&(t.el.style.display=t.display,t.el.style.position=t.position,t.parent&&t.parent.appendChild(t.el),s.removeData("blockUI.history")),s.data("blockUI.static")&&s.css("position","static"),"function"==typeof i.onUnblock&&i.onUnblock(n,i);var l=e(document.body),a=l.width(),d=l[0].style.width;l.width(a-1).width(a),l[0].style.width=d}}function n(o,t,i){var n=t==window,l=e(t);if((o||(!n||b)&&(n||l.data("blockUI.isBlocked")))&&(l.data("blockUI.isBlocked",o),n&&i.bindEvents&&(!o||i.showOverlay))){var a="mousedown mouseup keydown keypress keyup touchstart touchend touchmove";o?e(document).bind(a,i,s):e(document).unbind(a,s)}}function s(o){if("keydown"===o.type&&o.keyCode&&9==o.keyCode&&b&&o.data.constrainTabKey){var t=p,i=!o.shiftKey&&o.target===t[t.length-1],n=o.shiftKey&&o.target===t[0];if(i||n)return setTimeout(function(){l(n)},10),!1}var s=o.data,a=e(o.target);return a.hasClass("blockOverlay")&&s.onOverlayClick&&s.onOverlayClick(o),a.parents("div."+s.blockMsgClass).length>0?!0:0===a.parents().children().filter("div.blockUI").length}function l(e){if(p){var o=p[e===!0?p.length-1:0];o&&o.focus()}}function a(e,o,t){var i=e.parentNode,n=e.style,s=(i.offsetWidth-e.offsetWidth)/2-d(i,"borderLeftWidth"),l=(i.offsetHeight-e.offsetHeight)/2-d(i,"borderTopWidth");o&&(n.left=s>0?s+"px":"0"),t&&(n.top=l>0?l+"px":"0")}function d(o,t){return parseInt(e.css(o,t),10)||0}e.fn._fadeIn=e.fn.fadeIn;var c=e.noop||function(){},r=/MSIE/.test(navigator.userAgent),u=/MSIE 6.0/.test(navigator.userAgent)&&!/MSIE 8.0/.test(navigator.userAgent);document.documentMode||0;var f=e.isFunction(document.createElement("div").style.setExpression);e.blockUI=function(e){o(window,e)},e.unblockUI=function(e){t(window,e)},e.growlUI=function(o,t,i,n){var s=e('');o&&s.append(""+o+"
"),t&&s.append(""+t+"
"),void 0===i&&(i=3e3);var l=function(o){o=o||{},e.blockUI({message:s,fadeIn:o.fadeIn!==void 0?o.fadeIn:700,fadeOut:o.fadeOut!==void 0?o.fadeOut:1e3,timeout:o.timeout!==void 0?o.timeout:i,centerY:!1,showOverlay:!1,onUnblock:n,css:e.blockUI.defaults.growlCSS})};l(),s.css("opacity"),s.mouseover(function(){l({fadeIn:0,timeout:3e4});var o=e(".blockMsg");o.stop(),o.fadeTo(300,1)}).mouseout(function(){e(".blockMsg").fadeOut(1e3)})},e.fn.block=function(t){if(this[0]===window)return e.blockUI(t),this;var i=e.extend({},e.blockUI.defaults,t||{});return this.each(function(){var o=e(this);i.ignoreIfBlocked&&o.data("blockUI.isBlocked")||o.unblock({fadeOut:0})}),this.each(function(){"static"==e.css(this,"position")&&(this.style.position="relative",e(this).data("blockUI.static",!0)),this.style.zoom=1,o(this,t)})},e.fn.unblock=function(o){return this[0]===window?(e.unblockUI(o),this):this.each(function(){t(this,o)})},e.blockUI.version=2.66,e.blockUI.defaults={message:"Please wait...
",title:null,draggable:!0,theme:!1,css:{padding:0,margin:0,width:"30%",top:"40%",left:"35%",textAlign:"center",color:"#000",border:"3px solid #aaa",backgroundColor:"#fff",cursor:"wait"},themedCSS:{width:"30%",top:"40%",left:"35%"},overlayCSS:{backgroundColor:"#000",opacity:.6,cursor:"wait"},cursorReset:"default",growlCSS:{width:"350px",top:"10px",left:"",right:"10px",border:"none",padding:"5px",opacity:.6,cursor:"default",color:"#fff",backgroundColor:"#000","-webkit-border-radius":"10px","-moz-border-radius":"10px","border-radius":"10px"},iframeSrc:/^https/i.test(window.location.href||"")?"javascript:false":"about:blank",forceIframe:!1,baseZ:1e3,centerX:!0,centerY:!0,allowBodyStretch:!0,bindEvents:!0,constrainTabKey:!0,fadeIn:200,fadeOut:400,timeout:0,showOverlay:!0,focusInput:!0,focusableElements:":input:enabled:visible",onBlock:null,onUnblock:null,onOverlayClick:null,quirksmodeOffsetHack:4,blockMsgClass:"blockMsg",ignoreIfBlocked:!1};var b=null,p=[]}"function"==typeof define&&define.amd&&define.amd.jQuery?define(["jquery"],e):e(jQuery)})();
\ No newline at end of file
diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js
new file mode 100644
index 000000000000..bd2dacb4eeeb
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.cookies.2.2.0.min.js
@@ -0,0 +1,18 @@
+/**
+ * Copyright (c) 2005 - 2010, James Auldridge
+ * All rights reserved.
+ *
+ * Licensed under the BSD, MIT, and GPL (your choice!) Licenses:
+ * http://code.google.com/p/cookies/wiki/License
+ *
+ */
+var jaaulde=window.jaaulde||{};jaaulde.utils=jaaulde.utils||{};jaaulde.utils.cookies=(function(){var resolveOptions,assembleOptionsString,parseCookies,constructor,defaultOptions={expiresAt:null,path:'/',domain:null,secure:false};resolveOptions=function(options){var returnValue,expireDate;if(typeof options!=='object'||options===null){returnValue=defaultOptions;}else
+{returnValue={expiresAt:defaultOptions.expiresAt,path:defaultOptions.path,domain:defaultOptions.domain,secure:defaultOptions.secure};if(typeof options.expiresAt==='object'&&options.expiresAt instanceof Date){returnValue.expiresAt=options.expiresAt;}else if(typeof options.hoursToLive==='number'&&options.hoursToLive!==0){expireDate=new Date();expireDate.setTime(expireDate.getTime()+(options.hoursToLive*60*60*1000));returnValue.expiresAt=expireDate;}if(typeof options.path==='string'&&options.path!==''){returnValue.path=options.path;}if(typeof options.domain==='string'&&options.domain!==''){returnValue.domain=options.domain;}if(options.secure===true){returnValue.secure=options.secure;}}return returnValue;};assembleOptionsString=function(options){options=resolveOptions(options);return((typeof options.expiresAt==='object'&&options.expiresAt instanceof Date?'; expires='+options.expiresAt.toGMTString():'')+'; path='+options.path+(typeof options.domain==='string'?'; domain='+options.domain:'')+(options.secure===true?'; secure':''));};parseCookies=function(){var cookies={},i,pair,name,value,separated=document.cookie.split(';'),unparsedValue;for(i=0;i.sorting_1,table.dataTable.order-column tbody tr>.sorting_2,table.dataTable.order-column tbody tr>.sorting_3,table.dataTable.display tbody tr>.sorting_1,table.dataTable.display tbody tr>.sorting_2,table.dataTable.display tbody tr>.sorting_3{background-color:#f9f9f9}table.dataTable.order-column tbody tr.selected>.sorting_1,table.dataTable.order-column tbody tr.selected>.sorting_2,table.dataTable.order-column tbody tr.selected>.sorting_3,table.dataTable.display tbody tr.selected>.sorting_1,table.dataTable.display tbody tr.selected>.sorting_2,table.dataTable.display tbody tr.selected>.sorting_3{background-color:#acbad4}table.dataTable.display tbody tr.odd>.sorting_1,table.dataTable.order-column.stripe tbody tr.odd>.sorting_1{background-color:#f1f1f1}table.dataTable.display tbody tr.odd>.sorting_2,table.dataTable.order-column.stripe tbody tr.odd>.sorting_2{background-color:#f3f3f3}table.dataTable.display tbody tr.odd>.sorting_3,table.dataTable.order-column.stripe tbody tr.odd>.sorting_3{background-color:#f5f5f5}table.dataTable.display tbody tr.odd.selected>.sorting_1,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_1{background-color:#a6b3cd}table.dataTable.display tbody tr.odd.selected>.sorting_2,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_2{background-color:#a7b5ce}table.dataTable.display tbody tr.odd.selected>.sorting_3,table.dataTable.order-column.stripe tbody tr.odd.selected>.sorting_3{background-color:#a9b6d0}table.dataTable.display tbody tr.even>.sorting_1,table.dataTable.order-column.stripe tbody tr.even>.sorting_1{background-color:#f9f9f9}table.dataTable.display tbody tr.even>.sorting_2,table.dataTable.order-column.stripe tbody tr.even>.sorting_2{background-color:#fbfbfb}table.dataTable.display tbody tr.even>.sorting_3,table.dataTable.order-column.stripe tbody tr.even>.sorting_3{background-color:#fdfdfd}table.dataTable.display tbody tr.even.selected>.sorting_1,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_1{background-color:#acbad4}table.dataTable.display tbody tr.even.selected>.sorting_2,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_2{background-color:#adbbd6}table.dataTable.display tbody tr.even.selected>.sorting_3,table.dataTable.order-column.stripe tbody tr.even.selected>.sorting_3{background-color:#afbdd8}table.dataTable.display tbody 
tr:hover>.sorting_1,table.dataTable.display tbody tr.odd:hover>.sorting_1,table.dataTable.display tbody tr.even:hover>.sorting_1,table.dataTable.order-column.hover tbody tr:hover>.sorting_1,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_1,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_1{background-color:#eaeaea}table.dataTable.display tbody tr:hover>.sorting_2,table.dataTable.display tbody tr.odd:hover>.sorting_2,table.dataTable.display tbody tr.even:hover>.sorting_2,table.dataTable.order-column.hover tbody tr:hover>.sorting_2,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_2,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_2{background-color:#ebebeb}table.dataTable.display tbody tr:hover>.sorting_3,table.dataTable.display tbody tr.odd:hover>.sorting_3,table.dataTable.display tbody tr.even:hover>.sorting_3,table.dataTable.order-column.hover tbody tr:hover>.sorting_3,table.dataTable.order-column.hover tbody tr.odd:hover>.sorting_3,table.dataTable.order-column.hover tbody tr.even:hover>.sorting_3{background-color:#eee}table.dataTable.display tbody tr:hover.selected>.sorting_1,table.dataTable.display tbody tr.odd:hover.selected>.sorting_1,table.dataTable.display tbody tr.even:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_1,table.dataTable.order-column.hover tbody tr.even:hover.selected>.sorting_1{background-color:#a1aec7}table.dataTable.display tbody tr:hover.selected>.sorting_2,table.dataTable.display tbody tr.odd:hover.selected>.sorting_2,table.dataTable.display tbody tr.even:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_2,table.dataTable.order-column.hover tbody tr.even:hover.selected>.sorting_2{background-color:#a2afc8}table.dataTable.display tbody tr:hover.selected>.sorting_3,table.dataTable.display tbody tr.odd:hover.selected>.sorting_3,table.dataTable.display tbody tr.even:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr.odd:hover.selected>.sorting_3,table.dataTable.order-column.hover tbody tr.even:hover.selected>.sorting_3{background-color:#a4b2cb}table.dataTable.no-footer{border-bottom:1px solid #111}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable.compact thead th,table.dataTable.compact thead td{padding:5px 9px}table.dataTable.compact tfoot th,table.dataTable.compact tfoot td{padding:5px 9px 3px 9px}table.dataTable.compact tbody th,table.dataTable.compact tbody td{padding:4px 5px}table.dataTable th.dt-left,table.dataTable td.dt-left{text-align:left}table.dataTable th.dt-center,table.dataTable td.dt-center,table.dataTable td.dataTables_empty{text-align:center}table.dataTable th.dt-right,table.dataTable td.dt-right{text-align:right}table.dataTable th.dt-justify,table.dataTable td.dt-justify{text-align:justify}table.dataTable th.dt-nowrap,table.dataTable td.dt-nowrap{white-space:nowrap}table.dataTable thead th.dt-head-left,table.dataTable thead td.dt-head-left,table.dataTable tfoot th.dt-head-left,table.dataTable tfoot td.dt-head-left{text-align:left}table.dataTable thead th.dt-head-center,table.dataTable thead td.dt-head-center,table.dataTable tfoot th.dt-head-center,table.dataTable tfoot td.dt-head-center{text-align:center}table.dataTable thead 
th.dt-head-right,table.dataTable thead td.dt-head-right,table.dataTable tfoot th.dt-head-right,table.dataTable tfoot td.dt-head-right{text-align:right}table.dataTable thead th.dt-head-justify,table.dataTable thead td.dt-head-justify,table.dataTable tfoot th.dt-head-justify,table.dataTable tfoot td.dt-head-justify{text-align:justify}table.dataTable thead th.dt-head-nowrap,table.dataTable thead td.dt-head-nowrap,table.dataTable tfoot th.dt-head-nowrap,table.dataTable tfoot td.dt-head-nowrap{white-space:nowrap}table.dataTable tbody th.dt-body-left,table.dataTable tbody td.dt-body-left{text-align:left}table.dataTable tbody th.dt-body-center,table.dataTable tbody td.dt-body-center{text-align:center}table.dataTable tbody th.dt-body-right,table.dataTable tbody td.dt-body-right{text-align:right}table.dataTable tbody th.dt-body-justify,table.dataTable tbody td.dt-body-justify{text-align:justify}table.dataTable tbody th.dt-body-nowrap,table.dataTable tbody td.dt-body-nowrap{white-space:nowrap}table.dataTable,table.dataTable th,table.dataTable td{-webkit-box-sizing:content-box;-moz-box-sizing:content-box;box-sizing:content-box}.dataTables_wrapper{position:relative;clear:both;*zoom:1;zoom:1}.dataTables_wrapper .dataTables_length{float:left}.dataTables_wrapper .dataTables_filter{float:right;text-align:right}.dataTables_wrapper .dataTables_filter input{margin-left:0.5em}.dataTables_wrapper .dataTables_info{clear:both;float:left;padding-top:0.755em}.dataTables_wrapper .dataTables_paginate{float:right;text-align:right;padding-top:0.25em}.dataTables_wrapper .dataTables_paginate .paginate_button{box-sizing:border-box;display:inline-block;min-width:1.5em;padding:0.5em 1em;margin-left:2px;text-align:center;text-decoration:none !important;cursor:pointer;*cursor:hand;color:#333 !important;border:1px solid transparent}.dataTables_wrapper .dataTables_paginate .paginate_button.current,.dataTables_wrapper .dataTables_paginate .paginate_button.current:hover{color:#333 !important;border:1px solid #cacaca;background-color:#fff;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #fff), color-stop(100%, #dcdcdc));background:-webkit-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-moz-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-ms-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:-o-linear-gradient(top, #fff 0%, #dcdcdc 100%);background:linear-gradient(to bottom, #fff 0%, #dcdcdc 100%)}.dataTables_wrapper .dataTables_paginate .paginate_button.disabled,.dataTables_wrapper .dataTables_paginate .paginate_button.disabled:hover,.dataTables_wrapper .dataTables_paginate .paginate_button.disabled:active{cursor:default;color:#666 !important;border:1px solid transparent;background:transparent;box-shadow:none}.dataTables_wrapper .dataTables_paginate .paginate_button:hover{color:white !important;border:1px solid #111;background-color:#585858;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #585858), color-stop(100%, #111));background:-webkit-linear-gradient(top, #585858 0%, #111 100%);background:-moz-linear-gradient(top, #585858 0%, #111 100%);background:-ms-linear-gradient(top, #585858 0%, #111 100%);background:-o-linear-gradient(top, #585858 0%, #111 100%);background:linear-gradient(to bottom, #585858 0%, #111 100%)}.dataTables_wrapper .dataTables_paginate .paginate_button:active{outline:none;background-color:#2b2b2b;background:-webkit-gradient(linear, left top, left bottom, color-stop(0%, #2b2b2b), color-stop(100%, 
#0c0c0c));background:-webkit-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-moz-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-ms-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:-o-linear-gradient(top, #2b2b2b 0%, #0c0c0c 100%);background:linear-gradient(to bottom, #2b2b2b 0%, #0c0c0c 100%);box-shadow:inset 0 0 3px #111}.dataTables_wrapper .dataTables_processing{position:absolute;top:50%;left:50%;width:100%;height:40px;margin-left:-50%;margin-top:-25px;padding-top:20px;text-align:center;font-size:1.2em;background-color:white;background:-webkit-gradient(linear, left top, right top, color-stop(0%, rgba(255,255,255,0)), color-stop(25%, rgba(255,255,255,0.9)), color-stop(75%, rgba(255,255,255,0.9)), color-stop(100%, rgba(255,255,255,0)));background:-webkit-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-moz-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-ms-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:-o-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);background:linear-gradient(to right, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%)}.dataTables_wrapper .dataTables_length,.dataTables_wrapper .dataTables_filter,.dataTables_wrapper .dataTables_info,.dataTables_wrapper .dataTables_processing,.dataTables_wrapper .dataTables_paginate{color:#333}.dataTables_wrapper .dataTables_scroll{clear:both}.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody{*margin-top:-1px;-webkit-overflow-scrolling:touch}.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody th>div.dataTables_sizing,.dataTables_wrapper .dataTables_scroll div.dataTables_scrollBody td>div.dataTables_sizing{height:0;overflow:hidden;margin:0 !important;padding:0 !important}.dataTables_wrapper.no-footer .dataTables_scrollBody{border-bottom:1px solid #111}.dataTables_wrapper.no-footer div.dataTables_scrollHead table,.dataTables_wrapper.no-footer div.dataTables_scrollBody table{border-bottom:none}.dataTables_wrapper:after{visibility:hidden;display:block;content:"";clear:both;height:0}@media screen and (max-width: 767px){.dataTables_wrapper .dataTables_info,.dataTables_wrapper .dataTables_paginate{float:none;text-align:center}.dataTables_wrapper .dataTables_paginate{margin-top:0.5em}}@media screen and (max-width: 640px){.dataTables_wrapper .dataTables_length,.dataTables_wrapper .dataTables_filter{float:none;text-align:center}.dataTables_wrapper .dataTables_filter{margin-top:0.5em}}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js
new file mode 100644
index 000000000000..8885017c35d0
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.10.4.min.js
@@ -0,0 +1,157 @@
+/*! DataTables 1.10.4
+ * ©2008-2014 SpryMedia Ltd - datatables.net/license
+ */
+(function(Da,P,l){var O=function(g){function V(a){var b,c,e={};g.each(a,function(d){if((b=d.match(/^([^A-Z]+?)([A-Z])/))&&-1!=="a aa ai ao as b fn i m o s ".indexOf(b[1]+" "))c=d.replace(b[0],b[2].toLowerCase()),e[c]=d,"o"===b[1]&&V(a[d])});a._hungarianMap=e}function G(a,b,c){a._hungarianMap||V(a);var e;g.each(b,function(d){e=a._hungarianMap[d];if(e!==l&&(c||b[e]===l))"o"===e.charAt(0)?(b[e]||(b[e]={}),g.extend(!0,b[e],b[d]),G(a[e],b[e],c)):b[e]=b[d]})}function O(a){var b=p.defaults.oLanguage,c=a.sZeroRecords;
+!a.sEmptyTable&&(c&&"No data available in table"===b.sEmptyTable)&&D(a,a,"sZeroRecords","sEmptyTable");!a.sLoadingRecords&&(c&&"Loading..."===b.sLoadingRecords)&&D(a,a,"sZeroRecords","sLoadingRecords");a.sInfoThousands&&(a.sThousands=a.sInfoThousands);(a=a.sDecimal)&&cb(a)}function db(a){z(a,"ordering","bSort");z(a,"orderMulti","bSortMulti");z(a,"orderClasses","bSortClasses");z(a,"orderCellsTop","bSortCellsTop");z(a,"order","aaSorting");z(a,"orderFixed","aaSortingFixed");z(a,"paging","bPaginate");
+z(a,"pagingType","sPaginationType");z(a,"pageLength","iDisplayLength");z(a,"searching","bFilter");if(a=a.aoSearchCols)for(var b=0,c=a.length;b").css({position:"absolute",top:0,left:0,height:1,width:1,overflow:"hidden"}).append(g("").css({position:"absolute",top:1,left:1,width:100,
+overflow:"scroll"}).append(g('').css({width:"100%",height:10}))).appendTo("body"),c=b.find(".test");a.bScrollOversize=100===c[0].offsetWidth;a.bScrollbarLeft=1!==c.offset().left;b.remove()}function gb(a,b,c,e,d,f){var h,i=!1;c!==l&&(h=c,i=!0);for(;e!==d;)a.hasOwnProperty(e)&&(h=i?b(h,a[e],e,a):a[e],i=!0,e+=f);return h}function Ea(a,b){var c=p.defaults.column,e=a.aoColumns.length,c=g.extend({},p.models.oColumn,c,{nTh:b?b:P.createElement("th"),sTitle:c.sTitle?c.sTitle:b?b.innerHTML:
+"",aDataSort:c.aDataSort?c.aDataSort:[e],mData:c.mData?c.mData:e,idx:e});a.aoColumns.push(c);c=a.aoPreSearchCols;c[e]=g.extend({},p.models.oSearch,c[e]);ja(a,e,null)}function ja(a,b,c){var b=a.aoColumns[b],e=a.oClasses,d=g(b.nTh);if(!b.sWidthOrig){b.sWidthOrig=d.attr("width")||null;var f=(d.attr("style")||"").match(/width:\s*(\d+[pxem%]+)/);f&&(b.sWidthOrig=f[1])}c!==l&&null!==c&&(eb(c),G(p.defaults.column,c),c.mDataProp!==l&&!c.mData&&(c.mData=c.mDataProp),c.sType&&(b._sManualType=c.sType),c.className&&
+!c.sClass&&(c.sClass=c.className),g.extend(b,c),D(b,c,"sWidth","sWidthOrig"),"number"===typeof c.iDataSort&&(b.aDataSort=[c.iDataSort]),D(b,c,"aDataSort"));var h=b.mData,i=W(h),j=b.mRender?W(b.mRender):null,c=function(a){return"string"===typeof a&&-1!==a.indexOf("@")};b._bAttrSrc=g.isPlainObject(h)&&(c(h.sort)||c(h.type)||c(h.filter));b.fnGetData=function(a,b,c){var e=i(a,b,l,c);return j&&b?j(e,b,a,c):e};b.fnSetData=function(a,b,c){return Q(h)(a,b,c)};"number"!==typeof h&&(a._rowReadObject=!0);a.oFeatures.bSort||
+(b.bSortable=!1,d.addClass(e.sSortableNone));a=-1!==g.inArray("asc",b.asSorting);c=-1!==g.inArray("desc",b.asSorting);!b.bSortable||!a&&!c?(b.sSortingClass=e.sSortableNone,b.sSortingClassJUI=""):a&&!c?(b.sSortingClass=e.sSortableAsc,b.sSortingClassJUI=e.sSortJUIAscAllowed):!a&&c?(b.sSortingClass=e.sSortableDesc,b.sSortingClassJUI=e.sSortJUIDescAllowed):(b.sSortingClass=e.sSortable,b.sSortingClassJUI=e.sSortJUI)}function X(a){if(!1!==a.oFeatures.bAutoWidth){var b=a.aoColumns;Fa(a);for(var c=0,e=b.length;c<
+e;c++)b[c].nTh.style.width=b[c].sWidth}b=a.oScroll;(""!==b.sY||""!==b.sX)&&Y(a);u(a,null,"column-sizing",[a])}function ka(a,b){var c=Z(a,"bVisible");return"number"===typeof c[b]?c[b]:null}function $(a,b){var c=Z(a,"bVisible"),c=g.inArray(b,c);return-1!==c?c:null}function aa(a){return Z(a,"bVisible").length}function Z(a,b){var c=[];g.map(a.aoColumns,function(a,d){a[b]&&c.push(d)});return c}function Ga(a){var b=a.aoColumns,c=a.aoData,e=p.ext.type.detect,d,f,h,i,j,g,m,o,k;d=0;for(f=b.length;do[f])e(m.length+o[f],n);else if("string"===typeof o[f]){i=0;for(j=m.length;ib&&a[d]--; -1!=e&&c===l&&
+a.splice(e,1)}function ca(a,b,c,e){var d=a.aoData[b],f,h=function(c,f){for(;c.childNodes.length;)c.removeChild(c.firstChild);c.innerHTML=v(a,b,f,"display")};if("dom"===c||(!c||"auto"===c)&&"dom"===d.src)d._aData=ma(a,d,e,e===l?l:d._aData).data;else{var i=d.anCells;if(i)if(e!==l)h(i[e],e);else{c=0;for(f=i.length;c ").appendTo(h));b=0;for(c=
+m.length;btr").attr("role","row");g(h).find(">tr>th, >tr>td").addClass(n.sHeaderTH);g(i).find(">tr>th, >tr>td").addClass(n.sFooterTH);if(null!==i){a=a.aoFooter[0];b=0;for(c=a.length;b=a.fnRecordsDisplay()?0:h,a.iInitDisplayStart=-1);var h=a._iDisplayStart,n=a.fnDisplayEnd();if(a.bDeferLoading)a.bDeferLoading=
+!1,a.iDraw++,B(a,!1);else if(i){if(!a.bDestroying&&!jb(a))return}else a.iDraw++;if(0!==j.length){f=i?a.aoData.length:n;for(i=i?0:h;i",{"class":d?
+e[0]:""}).append(g(" ",{valign:"top",colSpan:aa(a),"class":a.oClasses.sRowEmpty}).html(c))[0];u(a,"aoHeaderCallback","header",[g(a.nTHead).children("tr")[0],Ka(a),h,n,j]);u(a,"aoFooterCallback","footer",[g(a.nTFoot).children("tr")[0],Ka(a),h,n,j]);e=g(a.nTBody);e.children().detach();e.append(g(b));u(a,"aoDrawCallback","draw",[a]);a.bSorted=!1;a.bFiltered=!1;a.bDrawing=!1}}function M(a,b){var c=a.oFeatures,e=c.bFilter;c.bSort&&kb(a);e?fa(a,a.oPreviousSearch):a.aiDisplay=a.aiDisplayMaster.slice();
+!0!==b&&(a._iDisplayStart=0);a._drawHold=b;L(a);a._drawHold=!1}function lb(a){var b=a.oClasses,c=g(a.nTable),c=g("").insertBefore(c),e=a.oFeatures,d=g("",{id:a.sTableId+"_wrapper","class":b.sWrapper+(a.nTFoot?"":" "+b.sNoFooter)});a.nHolding=c[0];a.nTableWrapper=d[0];a.nTableReinsertBefore=a.nTable.nextSibling;for(var f=a.sDom.split(""),h,i,j,n,m,o,k=0;k ")[0];n=f[k+1];if("'"==n||'"'==n){m="";for(o=2;f[k+o]!=n;)m+=f[k+o],o++;"H"==m?m=b.sJUIHeader:
+"F"==m&&(m=b.sJUIFooter);-1!=m.indexOf(".")?(n=m.split("."),j.id=n[0].substr(1,n[0].length-1),j.className=n[1]):"#"==m.charAt(0)?j.id=m.substr(1,m.length-1):j.className=m;k+=o}d.append(j);d=g(j)}else if(">"==i)d=d.parent();else if("l"==i&&e.bPaginate&&e.bLengthChange)h=mb(a);else if("f"==i&&e.bFilter)h=nb(a);else if("r"==i&&e.bProcessing)h=ob(a);else if("t"==i)h=pb(a);else if("i"==i&&e.bInfo)h=qb(a);else if("p"==i&&e.bPaginate)h=rb(a);else if(0!==p.ext.feature.length){j=p.ext.feature;o=0;for(n=j.length;o<
+n;o++)if(i==j[o].cFeature){h=j[o].fnInit(a);break}}h&&(j=a.aanFeatures,j[i]||(j[i]=[]),j[i].push(h),d.append(h))}c.replaceWith(d)}function da(a,b){var c=g(b).children("tr"),e,d,f,h,i,j,n,m,o,k;a.splice(0,a.length);f=0;for(j=c.length;f ',i=e.sSearch,i=i.match(/_INPUT_/)?i.replace("_INPUT_",h):i+h,b=g("",{id:!f.f?c+"_filter":null,"class":b.sFilter}).append(g("").append(i)),f=function(){var b=!this.value?"":this.value;b!=d.sSearch&&(fa(a,{sSearch:b,bRegex:d.bRegex,bSmart:d.bSmart,bCaseInsensitive:d.bCaseInsensitive}),a._iDisplayStart=0,L(a))},h=null!==a.searchDelay?a.searchDelay:
+"ssp"===A(a)?400:0,j=g("input",b).val(d.sSearch).attr("placeholder",e.sSearchPlaceholder).bind("keyup.DT search.DT input.DT paste.DT cut.DT",h?ta(f,h):f).bind("keypress.DT",function(a){if(13==a.keyCode)return!1}).attr("aria-controls",c);g(a.nTable).on("search.dt.DT",function(b,c){if(a===c)try{j[0]!==P.activeElement&&j.val(d.sSearch)}catch(f){}});return b[0]}function fa(a,b,c){var e=a.oPreviousSearch,d=a.aoPreSearchCols,f=function(a){e.sSearch=a.sSearch;e.bRegex=a.bRegex;e.bSmart=a.bSmart;e.bCaseInsensitive=
+a.bCaseInsensitive};Ga(a);if("ssp"!=A(a)){ub(a,b.sSearch,c,b.bEscapeRegex!==l?!b.bEscapeRegex:b.bRegex,b.bSmart,b.bCaseInsensitive);f(b);for(b=0;b=b.length)a.aiDisplay=f.slice();else{if(h||c||d.length>b.length||0!==b.indexOf(d)||a.bSorted)a.aiDisplay=f.slice();b=a.aiDisplay;for(c=b.length-1;0<=c;c--)e.test(a.aoData[b[c]]._sFilterRow)||
+b.splice(c,1)}}function Pa(a,b,c,e){a=b?a:ua(a);c&&(a="^(?=.*?"+g.map(a.match(/"[^"]+"|[^ ]+/g)||"",function(a){if('"'===a.charAt(0))var b=a.match(/^"(.*)"$/),a=b?b[1]:a;return a.replace('"',"")}).join(")(?=.*?")+").*$");return RegExp(a,e?"i":"")}function ua(a){return a.replace(Xb,"\\$1")}function xb(a){var b=a.aoColumns,c,e,d,f,h,i,g,n,m=p.ext.type.search;c=!1;e=0;for(f=a.aoData.length;e",{"class":a.oClasses.sInfo,id:!c?b+"_info":null});c||(a.aoDrawCallback.push({fn:Ab,sName:"information"}),e.attr("role","status").attr("aria-live","polite"),g(a.nTable).attr("aria-describedby",b+"_info"));return e[0]}function Ab(a){var b=a.aanFeatures.i;if(0!==b.length){var c=a.oLanguage,e=a._iDisplayStart+1,d=a.fnDisplayEnd(),f=a.fnRecordsTotal(),h=a.fnRecordsDisplay(),i=h?c.sInfo:c.sInfoEmpty;h!==f&&(i+=" "+c.sInfoFiltered);i+=c.sInfoPostFix;
+i=Bb(a,i);c=c.fnInfoCallback;null!==c&&(i=c.call(a.oInstance,a,e,d,f,h,i));g(b).html(i)}}function Bb(a,b){var c=a.fnFormatNumber,e=a._iDisplayStart+1,d=a._iDisplayLength,f=a.fnRecordsDisplay(),h=-1===d;return b.replace(/_START_/g,c.call(a,e)).replace(/_END_/g,c.call(a,a.fnDisplayEnd())).replace(/_MAX_/g,c.call(a,a.fnRecordsTotal())).replace(/_TOTAL_/g,c.call(a,f)).replace(/_PAGE_/g,c.call(a,h?1:Math.ceil(e/d))).replace(/_PAGES_/g,c.call(a,h?1:Math.ceil(f/d)))}function ga(a){var b,c,e=a.iInitDisplayStart,
+d=a.aoColumns,f;c=a.oFeatures;if(a.bInitialised){lb(a);ib(a);ea(a,a.aoHeader);ea(a,a.aoFooter);B(a,!0);c.bAutoWidth&&Fa(a);b=0;for(c=d.length;b",{name:c+"_length","aria-controls":c,"class":b.sLengthSelect}),h=0,i=f.length;h").addClass(b.sLength);a.aanFeatures.l||(j[0].id=c+"_length");j.children().append(a.oLanguage.sLengthMenu.replace("_MENU_",d[0].outerHTML));g("select",j).val(a._iDisplayLength).bind("change.DT",
+function(){Qa(a,g(this).val());L(a)});g(a.nTable).bind("length.dt.DT",function(b,c,f){a===c&&g("select",j).val(f)});return j[0]}function rb(a){var b=a.sPaginationType,c=p.ext.pager[b],e="function"===typeof c,d=function(a){L(a)},b=g("").addClass(a.oClasses.sPaging+b)[0],f=a.aanFeatures;e||c.fnInit(a,b,d);f.p||(b.id=a.sTableId+"_paginate",a.aoDrawCallback.push({fn:function(a){if(e){var b=a._iDisplayStart,g=a._iDisplayLength,n=a.fnRecordsDisplay(),m=-1===g,b=m?0:Math.ceil(b/g),g=m?1:Math.ceil(n/
+g),n=c(b,g),o,m=0;for(o=f.p.length;mf&&(e=0)):"first"==b?e=0:"previous"==b?(e=0<=d?e-d:0,0>e&&(e=0)):"next"==b?e+d ",{id:!a.aanFeatures.r?a.sTableId+"_processing":null,"class":a.oClasses.sProcessing}).html(a.oLanguage.sProcessing).insertBefore(a.nTable)[0]}function B(a,b){a.oFeatures.bProcessing&&g(a.aanFeatures.r).css("display",b?"block":"none");u(a,null,"processing",[a,b])}function pb(a){var b=g(a.nTable);b.attr("role","grid");var c=a.oScroll;if(""===c.sX&&""===c.sY)return a.nTable;var e=c.sX,d=c.sY,f=a.oClasses,h=b.children("caption"),i=h.length?h[0]._captionSide:null,
+j=g(b[0].cloneNode(!1)),n=g(b[0].cloneNode(!1)),m=b.children("tfoot");c.sX&&"100%"===b.attr("width")&&b.removeAttr("width");m.length||(m=null);c=g("",{"class":f.sScrollWrapper}).append(g("",{"class":f.sScrollHead}).css({overflow:"hidden",position:"relative",border:0,width:e?!e?null:s(e):"100%"}).append(g("",{"class":f.sScrollHeadInner}).css({"box-sizing":"content-box",width:c.sXInner||"100%"}).append(j.removeAttr("id").css("margin-left",0).append("top"===i?h:null).append(b.children("thead"))))).append(g("",
+{"class":f.sScrollBody}).css({overflow:"auto",height:!d?null:s(d),width:!e?null:s(e)}).append(b));m&&c.append(g("",{"class":f.sScrollFoot}).css({overflow:"hidden",border:0,width:e?!e?null:s(e):"100%"}).append(g("",{"class":f.sScrollFootInner}).append(n.removeAttr("id").css("margin-left",0).append("bottom"===i?h:null).append(b.children("tfoot")))));var b=c.children(),o=b[0],f=b[1],k=m?b[2]:null;e&&g(f).scroll(function(){var a=this.scrollLeft;o.scrollLeft=a;m&&(k.scrollLeft=a)});a.nScrollHead=
+o;a.nScrollBody=f;a.nScrollFoot=k;a.aoDrawCallback.push({fn:Y,sName:"scrolling"});return c[0]}function Y(a){var b=a.oScroll,c=b.sX,e=b.sXInner,d=b.sY,f=b.iBarWidth,h=g(a.nScrollHead),i=h[0].style,j=h.children("div"),n=j[0].style,m=j.children("table"),j=a.nScrollBody,o=g(j),k=j.style,l=g(a.nScrollFoot).children("div"),p=l.children("table"),r=g(a.nTHead),q=g(a.nTable),t=q[0],N=t.style,J=a.nTFoot?g(a.nTFoot):null,u=a.oBrowser,w=u.bScrollOversize,y,v,x,K,z,A=[],B=[],C=[],D,E=function(a){a=a.style;a.paddingTop=
+"0";a.paddingBottom="0";a.borderTopWidth="0";a.borderBottomWidth="0";a.height=0};q.children("thead, tfoot").remove();z=r.clone().prependTo(q);y=r.find("tr");x=z.find("tr");z.find("th, td").removeAttr("tabindex");J&&(K=J.clone().prependTo(q),v=J.find("tr"),K=K.find("tr"));c||(k.width="100%",h[0].style.width="100%");g.each(pa(a,z),function(b,c){D=ka(a,b);c.style.width=a.aoColumns[D].sWidth});J&&F(function(a){a.style.width=""},K);b.bCollapse&&""!==d&&(k.height=o[0].offsetHeight+r[0].offsetHeight+"px");
+h=q.outerWidth();if(""===c){if(N.width="100%",w&&(q.find("tbody").height()>j.offsetHeight||"scroll"==o.css("overflow-y")))N.width=s(q.outerWidth()-f)}else""!==e?N.width=s(e):h==o.width()&&o.height()h-f&&(N.width=s(h))):N.width=s(h);h=q.outerWidth();F(E,x);F(function(a){C.push(a.innerHTML);A.push(s(g(a).css("width")))},x);F(function(a,b){a.style.width=A[b]},y);g(x).height(0);J&&(F(E,K),F(function(a){B.push(s(g(a).css("width")))},K),F(function(a,b){a.style.width=
+B[b]},v),g(K).height(0));F(function(a,b){a.innerHTML='";a.style.width=A[b]},x);J&&F(function(a,b){a.innerHTML="";a.style.width=B[b]},K);if(q.outerWidth()j.offsetHeight||"scroll"==o.css("overflow-y")?h+f:h;if(w&&(j.scrollHeight>j.offsetHeight||"scroll"==o.css("overflow-y")))N.width=s(v-f);(""===c||""!==e)&&R(a,1,"Possible column misalignment",6)}else v="100%";k.width=s(v);i.width=s(v);J&&(a.nScrollFoot.style.width=
+s(v));!d&&w&&(k.height=s(t.offsetHeight+f));d&&b.bCollapse&&(k.height=s(d),b=c&&t.offsetWidth>j.offsetWidth?f:0,t.offsetHeightj.clientHeight||"scroll"==o.css("overflow-y");u="padding"+(u.bScrollbarLeft?"Left":"Right");n[u]=m?f+"px":"0px";J&&(p[0].style.width=s(b),l[0].style.width=s(b),l[0].style[u]=m?f+"px":"0px");o.scroll();if((a.bSorted||a.bFiltered)&&!a._drawHold)j.scrollTop=0}function F(a,
+b,c){for(var e=0,d=0,f=b.length,h,g;d "));i.find("tfoot th, tfoot td").css("width","");var p=i.find("tbody tr"),j=pa(a,i.find("thead")[0]);for(k=0;k").css("width",s(a)).appendTo(b||P.body),e=c[0].offsetWidth;c.remove();return e}function Eb(a,b){var c=a.oScroll;if(c.sX||c.sY)c=!c.sX?c.iBarWidth:0,b.style.width=s(g(b).outerWidth()-c)}function Db(a,b){var c=Fb(a,b);if(0>c)return null;
+var e=a.aoData[c];return!e.nTr?g(" ").html(v(a,c,b,"display"))[0]:e.anCells[b]}function Fb(a,b){for(var c,e=-1,d=-1,f=0,h=a.aoData.length;fe&&(e=c.length,d=f);return d}function s(a){return null===a?"0px":"number"==typeof a?0>a?"0px":a+"px":a.match(/\d$/)?a+"px":a}function Gb(){if(!p.__scrollbarWidth){var a=g("").css({width:"100%",height:200,padding:0})[0],b=g("").css({position:"absolute",top:0,left:0,width:200,height:150,padding:0,
+overflow:"hidden",visibility:"hidden"}).append(a).appendTo("body"),c=a.offsetWidth;b.css("overflow","scroll");a=a.offsetWidth;c===a&&(a=b[0].clientWidth);b.remove();p.__scrollbarWidth=c-a}return p.__scrollbarWidth}function T(a){var b,c,e=[],d=a.aoColumns,f,h,i,j;b=a.aaSortingFixed;c=g.isPlainObject(b);var n=[];f=function(a){a.length&&!g.isArray(a[0])?n.push(a):n.push.apply(n,a)};g.isArray(b)&&f(b);c&&b.pre&&f(b.pre);f(a.aaSorting);c&&b.post&&f(b.post);for(a=0;ad?1:0,0!==c)return"asc"===g.dir?c:-c;c=e[a];d=e[b];return cd?1:0}):j.sort(function(a,b){var c,h,g,i,j=n.length,l=f[a]._aSortData,p=f[b]._aSortData;for(g=0;gh?1:0})}a.bSorted=!0}function Ib(a){for(var b,c,e=a.aoColumns,d=T(a),a=a.oLanguage.oAria,f=0,h=e.length;f/g,"");var j=c.nTh;j.removeAttribute("aria-sort");c.bSortable&&(0d?d+1:3));d=0;for(f=e.length;dd?d+1:3))}a.aLastSort=e}function Hb(a,b){var c=a.aoColumns[b],e=p.ext.order[c.sSortDataType],d;e&&(d=e.call(a.oInstance,a,b,$(a,b)));for(var f,h=p.ext.type.order[c.sType+"-pre"],g=0,j=a.aoData.length;g<
+j;g++)if(c=a.aoData[g],c._aSortData||(c._aSortData=[]),!c._aSortData[b]||e)f=e?d[g]:v(a,g,b,"sort"),c._aSortData[b]=h?h(f):f}function xa(a){if(a.oFeatures.bStateSave&&!a.bDestroying){var b={time:+new Date,start:a._iDisplayStart,length:a._iDisplayLength,order:g.extend(!0,[],a.aaSorting),search:yb(a.oPreviousSearch),columns:g.map(a.aoColumns,function(b,e){return{visible:b.bVisible,search:yb(a.aoPreSearchCols[e])}})};u(a,"aoStateSaveParams","stateSaveParams",[a,b]);a.oSavedState=b;a.fnStateSaveCallback.call(a.oInstance,
+a,b)}}function Jb(a){var b,c,e=a.aoColumns;if(a.oFeatures.bStateSave){var d=a.fnStateLoadCallback.call(a.oInstance,a);if(d&&d.time&&(b=u(a,"aoStateLoadParams","stateLoadParams",[a,d]),-1===g.inArray(!1,b)&&(b=a.iStateDuration,!(0=e.length?[0,c[1]]:c)});g.extend(a.oPreviousSearch,
+zb(d.search));b=0;for(c=d.columns.length;b=c&&(b=c-e);b-=b%e;if(-1===e||0>b)b=0;a._iDisplayStart=b}function Oa(a,b){var c=a.renderer,e=p.ext.renderer[b];return g.isPlainObject(c)&&
+c[b]?e[c[b]]||e._:"string"===typeof c?e[c]||e._:e._}function A(a){return a.oFeatures.bServerSide?"ssp":a.ajax||a.sAjaxSource?"ajax":"dom"}function Va(a,b){var c=[],c=Lb.numbers_length,e=Math.floor(c/2);b<=c?c=U(0,b):a<=e?(c=U(0,c-2),c.push("ellipsis"),c.push(b-1)):(a>=b-1-e?c=U(b-(c-2),b):(c=U(a-1,a+2),c.push("ellipsis"),c.push(b-1)),c.splice(0,0,"ellipsis"),c.splice(0,0,0));c.DT_el="span";return c}function cb(a){g.each({num:function(b){return za(b,a)},"num-fmt":function(b){return za(b,a,Wa)},"html-num":function(b){return za(b,
+a,Aa)},"html-num-fmt":function(b){return za(b,a,Aa,Wa)}},function(b,c){w.type.order[b+a+"-pre"]=c;b.match(/^html\-/)&&(w.type.search[b+a]=w.type.search.html)})}function Mb(a){return function(){var b=[ya(this[p.ext.iApiIndex])].concat(Array.prototype.slice.call(arguments));return p.ext.internal[a].apply(this,b)}}var p,w,q,r,t,Xa={},Nb=/[\r\n]/g,Aa=/<.*?>/g,$b=/^[\w\+\-]/,ac=/[\w\+\-]$/,Xb=RegExp("(\\/|\\.|\\*|\\+|\\?|\\||\\(|\\)|\\[|\\]|\\{|\\}|\\\\|\\$|\\^|\\-)","g"),Wa=/[',$\u00a3\u20ac\u00a5%\u2009\u202F]/g,
+H=function(a){return!a||!0===a||"-"===a?!0:!1},Ob=function(a){var b=parseInt(a,10);return!isNaN(b)&&isFinite(a)?b:null},Pb=function(a,b){Xa[b]||(Xa[b]=RegExp(ua(b),"g"));return"string"===typeof a&&"."!==b?a.replace(/\./g,"").replace(Xa[b],"."):a},Ya=function(a,b,c){var e="string"===typeof a;b&&e&&(a=Pb(a,b));c&&e&&(a=a.replace(Wa,""));return H(a)||!isNaN(parseFloat(a))&&isFinite(a)},Qb=function(a,b,c){return H(a)?!0:!(H(a)||"string"===typeof a)?null:Ya(a.replace(Aa,""),b,c)?!0:null},C=function(a,
+b,c){var e=[],d=0,f=a.length;if(c!==l)for(;d")[0],Yb=va.textContent!==l,Zb=/<.*?>/g;p=function(a){this.$=function(a,b){return this.api(!0).$(a,b)};this._=function(a,b){return this.api(!0).rows(a,b).data()};this.api=function(a){return a?new q(ya(this[w.iApiIndex])):new q(this)};this.fnAddData=function(a,b){var c=this.api(!0),e=g.isArray(a)&&(g.isArray(a[0])||g.isPlainObject(a[0]))?
+c.rows.add(a):c.row.add(a);(b===l||b)&&c.draw();return e.flatten().toArray()};this.fnAdjustColumnSizing=function(a){var b=this.api(!0).columns.adjust(),c=b.settings()[0],e=c.oScroll;a===l||a?b.draw(!1):(""!==e.sX||""!==e.sY)&&Y(c)};this.fnClearTable=function(a){var b=this.api(!0).clear();(a===l||a)&&b.draw()};this.fnClose=function(a){this.api(!0).row(a).child.hide()};this.fnDeleteRow=function(a,b,c){var e=this.api(!0),a=e.rows(a),d=a.settings()[0],g=d.aoData[a[0][0]];a.remove();b&&b.call(this,d,g);
+(c===l||c)&&e.draw();return g};this.fnDestroy=function(a){this.api(!0).destroy(a)};this.fnDraw=function(a){this.api(!0).draw(!a)};this.fnFilter=function(a,b,c,e,d,g){d=this.api(!0);null===b||b===l?d.search(a,c,e,g):d.column(b).search(a,c,e,g);d.draw()};this.fnGetData=function(a,b){var c=this.api(!0);if(a!==l){var e=a.nodeName?a.nodeName.toLowerCase():"";return b!==l||"td"==e||"th"==e?c.cell(a,b).data():c.row(a).data()||null}return c.data().toArray()};this.fnGetNodes=function(a){var b=this.api(!0);
+return a!==l?b.row(a).node():b.rows().nodes().flatten().toArray()};this.fnGetPosition=function(a){var b=this.api(!0),c=a.nodeName.toUpperCase();return"TR"==c?b.row(a).index():"TD"==c||"TH"==c?(a=b.cell(a).index(),[a.row,a.columnVisible,a.column]):null};this.fnIsOpen=function(a){return this.api(!0).row(a).child.isShown()};this.fnOpen=function(a,b,c){return this.api(!0).row(a).child(b,c).show().child()[0]};this.fnPageChange=function(a,b){var c=this.api(!0).page(a);(b===l||b)&&c.draw(!1)};this.fnSetColumnVis=
+function(a,b,c){a=this.api(!0).column(a).visible(b);(c===l||c)&&a.columns.adjust().draw()};this.fnSettings=function(){return ya(this[w.iApiIndex])};this.fnSort=function(a){this.api(!0).order(a).draw()};this.fnSortListener=function(a,b,c){this.api(!0).order.listener(a,b,c)};this.fnUpdate=function(a,b,c,e,d){var g=this.api(!0);c===l||null===c?g.row(b).data(a):g.cell(b,c).data(a);(d===l||d)&&g.columns.adjust();(e===l||e)&&g.draw();return 0};this.fnVersionCheck=w.fnVersionCheck;var b=this,c=a===l,e=this.length;
+c&&(a={});this.oApi=this.internal=w.internal;for(var d in p.ext.internal)d&&(this[d]=Mb(d));this.each(function(){var d={},d=1t<"F"ip>'),k.renderer)?
+g.isPlainObject(k.renderer)&&!k.renderer.header&&(k.renderer.header="jqueryui"):k.renderer="jqueryui":g.extend(j,p.ext.classes,d.oClasses);g(this).addClass(j.sTable);if(""!==k.oScroll.sX||""!==k.oScroll.sY)k.oScroll.iBarWidth=Gb();!0===k.oScroll.sX&&(k.oScroll.sX="100%");k.iInitDisplayStart===l&&(k.iInitDisplayStart=d.iDisplayStart,k._iDisplayStart=d.iDisplayStart);null!==d.iDeferLoading&&(k.bDeferLoading=!0,h=g.isArray(d.iDeferLoading),k._iRecordsDisplay=h?d.iDeferLoading[0]:d.iDeferLoading,k._iRecordsTotal=
+h?d.iDeferLoading[1]:d.iDeferLoading);var r=k.oLanguage;g.extend(!0,r,d.oLanguage);""!==r.sUrl&&(g.ajax({dataType:"json",url:r.sUrl,success:function(a){O(a);G(m.oLanguage,a);g.extend(true,r,a);ga(k)},error:function(){ga(k)}}),n=!0);null===d.asStripeClasses&&(k.asStripeClasses=[j.sStripeOdd,j.sStripeEven]);var h=k.asStripeClasses,q=g("tbody tr:eq(0)",this);-1!==g.inArray(!0,g.map(h,function(a){return q.hasClass(a)}))&&(g("tbody tr",this).removeClass(h.join(" ")),k.asDestroyStripes=h.slice());var o=
+[],s,h=this.getElementsByTagName("thead");0!==h.length&&(da(k.aoHeader,h[0]),o=pa(k));if(null===d.aoColumns){s=[];h=0;for(i=o.length;h").appendTo(this));k.nTHead=i[0];i=g(this).children("tbody");0===i.length&&(i=g("").appendTo(this));k.nTBody=i[0];i=g(this).children("tfoot");if(0===i.length&&0 ").appendTo(this);0===i.length||0===i.children().length?g(this).addClass(j.sNoFooter):
+0a?new q(b[a],this[a]):null},filter:function(a){var b=[];if(y.filter)b=y.filter.call(this,a,this);else for(var c=0,e=this.length;c ").addClass(b);g("td",c).addClass(b).html(a)[0].colSpan=aa(e);d.push(c[0])}};if(g.isArray(a)||a instanceof g)for(var h=0,i=a.length;h=0?b:h.length+b];if(typeof a==="function"){var d=Ba(c,f);return g.map(h,function(b,f){return a(f,Vb(c,f,0,0,d),j[f])?f:null})}var k=typeof a==="string"?a.match(cc):"";if(k)switch(k[2]){case "visIdx":case "visible":b=
+parseInt(k[1],10);if(b<0){var l=g.map(h,function(a,b){return a.bVisible?b:null});return[l[l.length+b]]}return[ka(c,b)];case "name":return g.map(i,function(a,b){return a===k[1]?b:null})}else return g(j).filter(a).map(function(){return g.inArray(this,j)}).toArray()})},1);c.selector.cols=a;c.selector.opts=b;return c});t("columns().header()","column().header()",function(){return this.iterator("column",function(a,b){return a.aoColumns[b].nTh},1)});t("columns().footer()","column().footer()",function(){return this.iterator("column",
+function(a,b){return a.aoColumns[b].nTf},1)});t("columns().data()","column().data()",function(){return this.iterator("column-rows",Vb,1)});t("columns().dataSrc()","column().dataSrc()",function(){return this.iterator("column",function(a,b){return a.aoColumns[b].mData},1)});t("columns().cache()","column().cache()",function(a){return this.iterator("column-rows",function(b,c,e,d,f){return ha(b.aoData,f,"search"===a?"_aFilterData":"_aSortData",c)},1)});t("columns().nodes()","column().nodes()",function(){return this.iterator("column-rows",
+function(a,b,c,e,d){return ha(a.aoData,d,"anCells",b)},1)});t("columns().visible()","column().visible()",function(a,b){return this.iterator("column",function(c,e){if(a===l)return c.aoColumns[e].bVisible;var d=c.aoColumns,f=d[e],h=c.aoData,i,j,n;if(a!==l&&f.bVisible!==a){if(a){var m=g.inArray(!0,C(d,"bVisible"),e+1);i=0;for(j=h.length;ie;return!0};p.isDataTable=p.fnIsDataTable=function(a){var b=g(a).get(0),c=!1;g.each(p.settings,function(a,d){if(d.nTable===b||d.nScrollHead===b||d.nScrollFoot===b)c=!0});return c};p.tables=p.fnTables=function(a){return g.map(p.settings,function(b){if(!a||a&&g(b.nTable).is(":visible"))return b.nTable})};p.util={throttle:ta,escapeRegex:ua};
+p.camelToHungarian=G;r("$()",function(a,b){var c=this.rows(b).nodes(),c=g(c);return g([].concat(c.filter(a).toArray(),c.find(a).toArray()))});g.each(["on","one","off"],function(a,b){r(b+"()",function(){var a=Array.prototype.slice.call(arguments);a[0].match(/\.dt\b/)||(a[0]+=".dt");var e=g(this.tables().nodes());e[b].apply(e,a);return this})});r("clear()",function(){return this.iterator("table",function(a){na(a)})});r("settings()",function(){return new q(this.context,this.context)});r("data()",function(){return this.iterator("table",
+function(a){return C(a.aoData,"_aData")}).flatten()});r("destroy()",function(a){a=a||!1;return this.iterator("table",function(b){var c=b.nTableWrapper.parentNode,e=b.oClasses,d=b.nTable,f=b.nTBody,h=b.nTHead,i=b.nTFoot,j=g(d),f=g(f),l=g(b.nTableWrapper),m=g.map(b.aoData,function(a){return a.nTr}),o;b.bDestroying=!0;u(b,"aoDestroyCallback","destroy",[b]);a||(new q(b)).columns().visible(!0);l.unbind(".DT").find(":not(tbody *)").unbind(".DT");g(Da).unbind(".DT-"+b.sInstance);d!=h.parentNode&&(j.children("thead").detach(),
+j.append(h));i&&d!=i.parentNode&&(j.children("tfoot").detach(),j.append(i));j.detach();l.detach();b.aaSorting=[];b.aaSortingFixed=[];wa(b);g(m).removeClass(b.asStripeClasses.join(" "));g("th, td",h).removeClass(e.sSortable+" "+e.sSortableAsc+" "+e.sSortableDesc+" "+e.sSortableNone);b.bJUI&&(g("th span."+e.sSortIcon+", td span."+e.sSortIcon,h).detach(),g("th, td",h).each(function(){var a=g("div."+e.sSortJUIWrapper,this);g(this).append(a.contents());a.detach()}));!a&&c&&c.insertBefore(d,b.nTableReinsertBefore);
+f.children().detach();f.append(m);j.css("width",b.sDestroyWidth).removeClass(e.sTable);(o=b.asDestroyStripes.length)&&f.children().each(function(a){g(this).addClass(b.asDestroyStripes[a%o])});c=g.inArray(b,p.settings);-1!==c&&p.settings.splice(c,1)})});p.version="1.10.4";p.settings=[];p.models={};p.models.oSearch={bCaseInsensitive:!0,sSearch:"",bRegex:!1,bSmart:!0};p.models.oRow={nTr:null,anCells:null,_aData:[],_aSortData:null,_aFilterData:null,_sFilterRow:null,_sRowStripe:"",src:null};p.models.oColumn=
+{idx:null,aDataSort:null,asSorting:null,bSearchable:null,bSortable:null,bVisible:null,_sManualType:null,_bAttrSrc:!1,fnCreatedCell:null,fnGetData:null,fnSetData:null,mData:null,mRender:null,nTh:null,nTf:null,sClass:null,sContentPadding:null,sDefaultContent:null,sName:null,sSortDataType:"std",sSortingClass:null,sSortingClassJUI:null,sTitle:null,sType:null,sWidth:null,sWidthOrig:null};p.defaults={aaData:null,aaSorting:[[0,"asc"]],aaSortingFixed:[],ajax:null,aLengthMenu:[10,25,50,100],aoColumns:null,
+aoColumnDefs:null,aoSearchCols:[],asStripeClasses:null,bAutoWidth:!0,bDeferRender:!1,bDestroy:!1,bFilter:!0,bInfo:!0,bJQueryUI:!1,bLengthChange:!0,bPaginate:!0,bProcessing:!1,bRetrieve:!1,bScrollCollapse:!1,bServerSide:!1,bSort:!0,bSortMulti:!0,bSortCellsTop:!1,bSortClasses:!0,bStateSave:!1,fnCreatedRow:null,fnDrawCallback:null,fnFooterCallback:null,fnFormatNumber:function(a){return a.toString().replace(/\B(?=(\d{3})+(?!\d))/g,this.oLanguage.sThousands)},fnHeaderCallback:null,fnInfoCallback:null,
+fnInitComplete:null,fnPreDrawCallback:null,fnRowCallback:null,fnServerData:null,fnServerParams:null,fnStateLoadCallback:function(a){try{return JSON.parse((-1===a.iStateDuration?sessionStorage:localStorage).getItem("DataTables_"+a.sInstance+"_"+location.pathname))}catch(b){}},fnStateLoadParams:null,fnStateLoaded:null,fnStateSaveCallback:function(a,b){try{(-1===a.iStateDuration?sessionStorage:localStorage).setItem("DataTables_"+a.sInstance+"_"+location.pathname,JSON.stringify(b))}catch(c){}},fnStateSaveParams:null,
+iStateDuration:7200,iDeferLoading:null,iDisplayLength:10,iDisplayStart:0,iTabIndex:0,oClasses:{},oLanguage:{oAria:{sSortAscending:": activate to sort column ascending",sSortDescending:": activate to sort column descending"},oPaginate:{sFirst:"First",sLast:"Last",sNext:"Next",sPrevious:"Previous"},sEmptyTable:"No data available in table",sInfo:"Showing _START_ to _END_ of _TOTAL_ entries",sInfoEmpty:"Showing 0 to 0 of 0 entries",sInfoFiltered:"(filtered from _MAX_ total entries)",sInfoPostFix:"",sDecimal:"",
+sThousands:",",sLengthMenu:"Show _MENU_ entries",sLoadingRecords:"Loading...",sProcessing:"Processing...",sSearch:"Search:",sSearchPlaceholder:"",sUrl:"",sZeroRecords:"No matching records found"},oSearch:g.extend({},p.models.oSearch),sAjaxDataProp:"data",sAjaxSource:null,sDom:"lfrtip",searchDelay:null,sPaginationType:"simple_numbers",sScrollX:"",sScrollXInner:"",sScrollY:"",sServerMethod:"GET",renderer:null};V(p.defaults);p.defaults.column={aDataSort:null,iDataSort:-1,asSorting:["asc","desc"],bSearchable:!0,
+bSortable:!0,bVisible:!0,fnCreatedCell:null,mData:null,mRender:null,sCellType:"td",sClass:"",sContentPadding:"",sDefaultContent:null,sName:"",sSortDataType:"std",sTitle:null,sType:null,sWidth:null};V(p.defaults.column);p.models.oSettings={oFeatures:{bAutoWidth:null,bDeferRender:null,bFilter:null,bInfo:null,bLengthChange:null,bPaginate:null,bProcessing:null,bServerSide:null,bSort:null,bSortMulti:null,bSortClasses:null,bStateSave:null},oScroll:{bCollapse:null,iBarWidth:0,sX:null,sXInner:null,sY:null},
+oLanguage:{fnInfoCallback:null},oBrowser:{bScrollOversize:!1,bScrollbarLeft:!1},ajax:null,aanFeatures:[],aoData:[],aiDisplay:[],aiDisplayMaster:[],aoColumns:[],aoHeader:[],aoFooter:[],oPreviousSearch:{},aoPreSearchCols:[],aaSorting:null,aaSortingFixed:[],asStripeClasses:null,asDestroyStripes:[],sDestroyWidth:0,aoRowCallback:[],aoHeaderCallback:[],aoFooterCallback:[],aoDrawCallback:[],aoRowCreatedCallback:[],aoPreDrawCallback:[],aoInitComplete:[],aoStateSaveParams:[],aoStateLoadParams:[],aoStateLoaded:[],
+sTableId:"",nTable:null,nTHead:null,nTFoot:null,nTBody:null,nTableWrapper:null,bDeferLoading:!1,bInitialised:!1,aoOpenRows:[],sDom:null,searchDelay:null,sPaginationType:"two_button",iStateDuration:0,aoStateSave:[],aoStateLoad:[],oSavedState:null,oLoadedState:null,sAjaxSource:null,sAjaxDataProp:null,bAjaxDataGet:!0,jqXHR:null,json:l,oAjaxData:l,fnServerData:null,aoServerParams:[],sServerMethod:null,fnFormatNumber:null,aLengthMenu:null,iDraw:0,bDrawing:!1,iDrawError:-1,_iDisplayLength:10,_iDisplayStart:0,
+_iRecordsTotal:0,_iRecordsDisplay:0,bJUI:null,oClasses:{},bFiltered:!1,bSorted:!1,bSortCellsTop:null,oInit:null,aoDestroyCallback:[],fnRecordsTotal:function(){return"ssp"==A(this)?1*this._iRecordsTotal:this.aiDisplayMaster.length},fnRecordsDisplay:function(){return"ssp"==A(this)?1*this._iRecordsDisplay:this.aiDisplay.length},fnDisplayEnd:function(){var a=this._iDisplayLength,b=this._iDisplayStart,c=b+a,e=this.aiDisplay.length,d=this.oFeatures,f=d.bPaginate;return d.bServerSide?!1===f||-1===a?b+e:
+Math.min(b+a,this._iRecordsDisplay):!f||c>e||-1===a?e:c},oInstance:null,sInstance:null,iTabIndex:0,nScrollHead:null,nScrollFoot:null,aLastSort:[],oPlugins:{}};p.ext=w={classes:{},errMode:"alert",feature:[],search:[],internal:{},legacy:{ajax:null},pager:{},renderer:{pageButton:{},header:{}},order:{},type:{detect:[],search:{},order:{}},_unique:0,fnVersionCheck:p.fnVersionCheck,iApiIndex:0,oJUIClasses:{},sVersion:p.version};g.extend(w,{afnFiltering:w.search,aTypes:w.type.detect,ofnSearch:w.type.search,
+oSort:w.type.order,afnSortData:w.order,aoFeatures:w.feature,oApi:w.internal,oStdClasses:w.classes,oPagination:w.pager});g.extend(p.ext.classes,{sTable:"dataTable",sNoFooter:"no-footer",sPageButton:"paginate_button",sPageButtonActive:"current",sPageButtonDisabled:"disabled",sStripeOdd:"odd",sStripeEven:"even",sRowEmpty:"dataTables_empty",sWrapper:"dataTables_wrapper",sFilter:"dataTables_filter",sInfo:"dataTables_info",sPaging:"dataTables_paginate paging_",sLength:"dataTables_length",sProcessing:"dataTables_processing",
+sSortAsc:"sorting_asc",sSortDesc:"sorting_desc",sSortable:"sorting",sSortableAsc:"sorting_asc_disabled",sSortableDesc:"sorting_desc_disabled",sSortableNone:"sorting_disabled",sSortColumn:"sorting_",sFilterInput:"",sLengthSelect:"",sScrollWrapper:"dataTables_scroll",sScrollHead:"dataTables_scrollHead",sScrollHeadInner:"dataTables_scrollHeadInner",sScrollBody:"dataTables_scrollBody",sScrollFoot:"dataTables_scrollFoot",sScrollFootInner:"dataTables_scrollFootInner",sHeaderTH:"",sFooterTH:"",sSortJUIAsc:"",
+sSortJUIDesc:"",sSortJUI:"",sSortJUIAscAllowed:"",sSortJUIDescAllowed:"",sSortJUIWrapper:"",sSortIcon:"",sJUIHeader:"",sJUIFooter:""});var Ca="",Ca="",E=Ca+"ui-state-default",ia=Ca+"css_right ui-icon ui-icon-",Wb=Ca+"fg-toolbar ui-toolbar ui-widget-header ui-helper-clearfix";g.extend(p.ext.oJUIClasses,p.ext.classes,{sPageButton:"fg-button ui-button "+E,sPageButtonActive:"ui-state-disabled",sPageButtonDisabled:"ui-state-disabled",sPaging:"dataTables_paginate fg-buttonset ui-buttonset fg-buttonset-multi ui-buttonset-multi paging_",
+sSortAsc:E+" sorting_asc",sSortDesc:E+" sorting_desc",sSortable:E+" sorting",sSortableAsc:E+" sorting_asc_disabled",sSortableDesc:E+" sorting_desc_disabled",sSortableNone:E+" sorting_disabled",sSortJUIAsc:ia+"triangle-1-n",sSortJUIDesc:ia+"triangle-1-s",sSortJUI:ia+"carat-2-n-s",sSortJUIAscAllowed:ia+"carat-1-n",sSortJUIDescAllowed:ia+"carat-1-s",sSortJUIWrapper:"DataTables_sort_wrapper",sSortIcon:"DataTables_sort_icon",sScrollHead:"dataTables_scrollHead "+E,sScrollFoot:"dataTables_scrollFoot "+E,
+sHeaderTH:E,sFooterTH:E,sJUIHeader:Wb+" ui-corner-tl ui-corner-tr",sJUIFooter:Wb+" ui-corner-bl ui-corner-br"});var Lb=p.ext.pager;g.extend(Lb,{simple:function(){return["previous","next"]},full:function(){return["first","previous","next","last"]},simple_numbers:function(a,b){return["previous",Va(a,b),"next"]},full_numbers:function(a,b){return["first","previous",Va(a,b),"next","last"]},_numbers:Va,numbers_length:7});g.extend(!0,p.ext.renderer,{pageButton:{_:function(a,b,c,e,d,f){var h=a.oClasses,i=
+a.oLanguage.oPaginate,j,l,m=0,o=function(b,e){var k,p,r,q,s=function(b){Sa(a,b.data.action,true)};k=0;for(p=e.length;k").appendTo(b);o(r,q)}else{l=j="";switch(q){case "ellipsis":b.append("…");break;case "first":j=i.sFirst;l=q+(d>0?"":" "+h.sPageButtonDisabled);break;case "previous":j=i.sPrevious;l=q+(d>0?"":" "+h.sPageButtonDisabled);break;case "next":j=i.sNext;l=q+(d",{"class":h.sPageButton+" "+l,"aria-controls":a.sTableId,"data-dt-idx":m,tabindex:a.iTabIndex,id:c===0&&typeof q==="string"?a.sTableId+"_"+q:null}).html(j).appendTo(b);Ua(r,{action:q},s);m++}}}};try{var k=g(P.activeElement).data("dt-idx");o(g(b).empty(),e);k!==null&&g(b).find("[data-dt-idx="+k+"]").focus()}catch(p){}}}});g.extend(p.ext.type.detect,[function(a,b){var c=b.oLanguage.sDecimal;
+return Ya(a,c)?"num"+c:null},function(a){if(a&&!(a instanceof Date)&&(!$b.test(a)||!ac.test(a)))return null;var b=Date.parse(a);return null!==b&&!isNaN(b)||H(a)?"date":null},function(a,b){var c=b.oLanguage.sDecimal;return Ya(a,c,!0)?"num-fmt"+c:null},function(a,b){var c=b.oLanguage.sDecimal;return Qb(a,c)?"html-num"+c:null},function(a,b){var c=b.oLanguage.sDecimal;return Qb(a,c,!0)?"html-num-fmt"+c:null},function(a){return H(a)||"string"===typeof a&&-1!==a.indexOf("<")?"html":null}]);g.extend(p.ext.type.search,
+{html:function(a){return H(a)?a:"string"===typeof a?a.replace(Nb," ").replace(Aa,""):""},string:function(a){return H(a)?a:"string"===typeof a?a.replace(Nb," "):a}});var za=function(a,b,c,e){if(0!==a&&(!a||"-"===a))return-Infinity;b&&(a=Pb(a,b));a.replace&&(c&&(a=a.replace(c,"")),e&&(a=a.replace(e,"")));return 1*a};g.extend(w.type.order,{"date-pre":function(a){return Date.parse(a)||0},"html-pre":function(a){return H(a)?"":a.replace?a.replace(/<.*?>/g,"").toLowerCase():a+""},"string-pre":function(a){return H(a)?
+"":"string"===typeof a?a.toLowerCase():!a.toString?"":a.toString()},"string-asc":function(a,b){return ab?1:0},"string-desc":function(a,b){return ab?-1:0}});cb("");g.extend(!0,p.ext.renderer,{header:{_:function(a,b,c,e){g(a.nTable).on("order.dt.DT",function(d,f,h,g){if(a===f){d=c.idx;b.removeClass(c.sSortingClass+" "+e.sSortAsc+" "+e.sSortDesc).addClass(g[d]=="asc"?e.sSortAsc:g[d]=="desc"?e.sSortDesc:c.sSortingClass)}})},jqueryui:function(a,b,c,e){g("").addClass(e.sSortJUIWrapper).append(b.contents()).append(g("").addClass(e.sSortIcon+
+" "+c.sSortingClassJUI)).appendTo(b);g(a.nTable).on("order.dt.DT",function(d,f,g,i){if(a===f){d=c.idx;b.removeClass(e.sSortAsc+" "+e.sSortDesc).addClass(i[d]=="asc"?e.sSortAsc:i[d]=="desc"?e.sSortDesc:c.sSortingClass);b.find("span."+e.sSortIcon).removeClass(e.sSortJUIAsc+" "+e.sSortJUIDesc+" "+e.sSortJUI+" "+e.sSortJUIAscAllowed+" "+e.sSortJUIDescAllowed).addClass(i[d]=="asc"?e.sSortJUIAsc:i[d]=="desc"?e.sSortJUIDesc:c.sSortingClassJUI)}})}}});p.render={number:function(a,b,c,e){return{display:function(d){var f=
+0>d?"-":"",d=Math.abs(parseFloat(d)),g=parseInt(d,10),d=c?b+(d-g).toFixed(c).substring(2):"";return f+(e||"")+g.toString().replace(/\B(?=(\d{3})+(?!\d))/g,a)+d}}}};g.extend(p.ext.internal,{_fnExternApiFunc:Mb,_fnBuildAjax:qa,_fnAjaxUpdate:jb,_fnAjaxParameters:sb,_fnAjaxUpdateDraw:tb,_fnAjaxDataSrc:ra,_fnAddColumn:Ea,_fnColumnOptions:ja,_fnAdjustColumnSizing:X,_fnVisibleToColumnIndex:ka,_fnColumnIndexToVisible:$,_fnVisbleColumns:aa,_fnGetColumns:Z,_fnColumnTypes:Ga,_fnApplyColumnDefs:hb,_fnHungarianMap:V,
+_fnCamelToHungarian:G,_fnLanguageCompat:O,_fnBrowserDetect:fb,_fnAddData:I,_fnAddTr:la,_fnNodeToDataIndex:function(a,b){return b._DT_RowIndex!==l?b._DT_RowIndex:null},_fnNodeToColumnIndex:function(a,b,c){return g.inArray(c,a.aoData[b].anCells)},_fnGetCellData:v,_fnSetCellData:Ha,_fnSplitObjNotation:Ja,_fnGetObjectDataFn:W,_fnSetObjectDataFn:Q,_fnGetDataMaster:Ka,_fnClearTable:na,_fnDeleteIndex:oa,_fnInvalidate:ca,_fnGetRowElements:ma,_fnCreateTr:Ia,_fnBuildHead:ib,_fnDrawHead:ea,_fnDraw:L,_fnReDraw:M,
+_fnAddOptionsHtml:lb,_fnDetectHeader:da,_fnGetUniqueThs:pa,_fnFeatureHtmlFilter:nb,_fnFilterComplete:fa,_fnFilterCustom:wb,_fnFilterColumn:vb,_fnFilter:ub,_fnFilterCreateSearch:Pa,_fnEscapeRegex:ua,_fnFilterData:xb,_fnFeatureHtmlInfo:qb,_fnUpdateInfo:Ab,_fnInfoMacros:Bb,_fnInitialise:ga,_fnInitComplete:sa,_fnLengthChange:Qa,_fnFeatureHtmlLength:mb,_fnFeatureHtmlPaginate:rb,_fnPageChange:Sa,_fnFeatureHtmlProcessing:ob,_fnProcessingDisplay:B,_fnFeatureHtmlTable:pb,_fnScrollDraw:Y,_fnApplyToChildren:F,
+_fnCalculateColumnWidths:Fa,_fnThrottle:ta,_fnConvertToWidth:Cb,_fnScrollingWidthAdjust:Eb,_fnGetWidestNode:Db,_fnGetMaxLenString:Fb,_fnStringToCss:s,_fnScrollBarWidth:Gb,_fnSortFlatten:T,_fnSort:kb,_fnSortAria:Ib,_fnSortListener:Ta,_fnSortAttachListener:Na,_fnSortingClasses:wa,_fnSortData:Hb,_fnSaveState:xa,_fnLoadState:Jb,_fnSettingsFromNode:ya,_fnLog:R,_fnMap:D,_fnBindAction:Ua,_fnCallbackReg:x,_fnCallbackFire:u,_fnLengthOverflow:Ra,_fnRenderer:Oa,_fnDataSource:A,_fnRowAttributes:La,_fnCalculateEnd:function(){}});
+g.fn.dataTable=p;g.fn.dataTableSettings=p.settings;g.fn.dataTableExt=p.ext;g.fn.DataTable=function(a){return g(this).dataTable(a).api()};g.each(p,function(a,b){g.fn.DataTable[a]=b});return g.fn.dataTable};"function"===typeof define&&define.amd?define("datatables",["jquery"],O):"object"===typeof exports?O(require("jquery")):jQuery&&!jQuery.fn.dataTable&&O(jQuery)})(window,document);
diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js
new file mode 100644
index 000000000000..14925bf93d0f
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/jquery.mustache.js
@@ -0,0 +1,592 @@
+/*
+Shameless port of a shameless port
+@defunkt => @janl => @aq
+
+See http://github.com/defunkt/mustache for more info.
+*/
+
+;(function($) {
+
+/*!
+ * mustache.js - Logic-less {{mustache}} templates with JavaScript
+ * http://github.com/janl/mustache.js
+ */
+
+/*global define: false*/
+
+(function (root, factory) {
+ if (typeof exports === "object" && exports) {
+ factory(exports); // CommonJS
+ } else {
+ var mustache = {};
+ factory(mustache);
+ if (typeof define === "function" && define.amd) {
+ define(mustache); // AMD
+ } else {
+ root.Mustache = mustache; // <script>
+
+
+
+ {providerConfig.map { case (k, v) => - {k}: {v}
}}
+
+ {
+ if (eventLogsUnderProcessCount > 0) {
+ There are {eventLogsUnderProcessCount} event log(s) currently being
+ processed which may result in additional applications getting listed on this page.
+ Refresh the page to view updates.
+ }
+ }
+
+ {
+ if (lastUpdatedTime > 0) {
+ Last updated: {lastUpdatedTime}
+ }
+ }
-
- Showing {actualFirst + 1}-{last + 1} of {allAppsSize}
- {if (requestedIncomplete) "(Incomplete applications)"}
-
- {
- if (actualPage > 1) {
- <
- 1
- }
- }
- {if (actualPage - plusOrMinus > secondPageFromLeft) " ... "}
- {leftSideIndices}
- {actualPage}
- {rightSideIndices}
- {if (actualPage + plusOrMinus < secondPageFromRight) " ... "}
- {
- if (actualPage < pageCount) {
- {pageCount}
- >
- }
- }
-
-
++
- appTable
+ {
+ if (allAppsSize > 0) {
+ ++
+ ++
+ ++
+ ++
+
} else if (requestedIncomplete) {
No incomplete applications found!
+ } else if (eventLogsUnderProcessCount > 0) {
+ No completed applications found!
} else {
- No completed applications found!
++
- Did you specify the correct logging directory?
- Please verify your setting of
- spark.history.fs.logDirectory and whether you have the permissions to
- access it.
It is also possible that your application did not run to
- completion or did not stop the SparkContext.
-
+ No completed applications found!
++ parent.emptyListingHtml
}
- }
-
- {
+ }
+
+
+ {
if (requestedIncomplete) {
"Back to completed applications"
} else {
"Show incomplete applications"
}
- }
-
-
+ }
+
+
- UIUtils.basicSparkPage(content, "History Server")
- }
-
- private val appHeader = Seq(
- "App ID",
- "App Name",
- "Started",
- "Completed",
- "Duration",
- "Spark User",
- "Last Updated")
-
- private val appWithAttemptHeader = Seq(
- "App ID",
- "App Name",
- "Attempt ID",
- "Started",
- "Completed",
- "Duration",
- "Spark User",
- "Last Updated")
-
- private def rangeIndices(
- range: Seq[Int],
- condition: Int => Boolean,
- showIncomplete: Boolean): Seq[Node] = {
- range.filter(condition).map(nextPage =>
- {nextPage} )
- }
-
- private def attemptRow(
- renderAttemptIdColumn: Boolean,
- info: ApplicationHistoryInfo,
- attempt: ApplicationAttemptInfo,
- isFirst: Boolean): Seq[Node] = {
- val uiAddress = UIUtils.prependBaseUri(HistoryServer.getAttemptURI(info.id, attempt.attemptId))
- val startTime = UIUtils.formatDate(attempt.startTime)
- val endTime = if (attempt.endTime > 0) UIUtils.formatDate(attempt.endTime) else "-"
- val duration =
- if (attempt.endTime > 0) {
- UIUtils.formatDuration(attempt.endTime - attempt.startTime)
- } else {
- "-"
- }
- val lastUpdated = UIUtils.formatDate(attempt.lastUpdated)
-
- {
- if (isFirst) {
- if (info.attempts.size > 1 || renderAttemptIdColumn) {
-
- {info.id}
-
- {info.name}
- } else {
- {info.id}
- {info.name}
- }
- } else {
- Nil
- }
- }
- {
- if (renderAttemptIdColumn) {
- if (info.attempts.size > 1 && attempt.attemptId.isDefined) {
- {attempt.attemptId.get}
- } else {
-
- }
- } else {
- Nil
- }
- }
- {startTime}
- {endTime}
-
- {duration}
- {attempt.sparkUser}
- {lastUpdated}
-
- }
-
- private def appRow(info: ApplicationHistoryInfo): Seq[Node] = {
- attemptRow(false, info, info.attempts.head, true)
- }
-
- private def appWithAttemptRow(info: ApplicationHistoryInfo): Seq[Node] = {
- attemptRow(true, info, info.attempts.head, true) ++
- info.attempts.drop(1).flatMap(attemptRow(true, info, _, false))
+ UIUtils.basicSparkPage(content, "History Server", true)
}
- private def makePageLink(linkPage: Int, showIncomplete: Boolean): String = {
- UIUtils.prependBaseUri("/?" + Array(
- "page=" + linkPage,
- "showIncomplete=" + showIncomplete
- ).mkString("&"))
+ private def makePageLink(showIncomplete: Boolean): String = {
+ UIUtils.prependBaseUri("/?" + "showIncomplete=" + showIncomplete)
}
}
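Note on the simplified makePageLink above: with paging now handled client-side, the helper only carries the showIncomplete flag. A minimal standalone sketch of what the link looks like (plain Scala; the stubbed prependBaseUri with an empty base URI and the object name are assumptions standing in for UIUtils.prependBaseUri):

    // Standalone sketch of the simplified link helper; prependBaseUri is stubbed out.
    object PageLinkSketch {
      private def prependBaseUri(path: String): String = "" + path  // assume an empty base URI

      def makePageLink(showIncomplete: Boolean): String =
        prependBaseUri("/?" + "showIncomplete=" + showIncomplete)

      def main(args: Array[String]): Unit = {
        println(makePageLink(showIncomplete = false))  // prints /?showIncomplete=false
        println(makePageLink(showIncomplete = true))   // prints /?showIncomplete=true
      }
    }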
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index d4f327cc588f..967cf14ad353 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -21,16 +21,19 @@ import java.util.NoSuchElementException
import java.util.zip.ZipOutputStream
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
-import com.google.common.cache._
+import scala.util.control.NonFatal
+import scala.xml.Node
+
import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource,
- UIRoot}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, UIRoot}
import org.apache.spark.ui.{SparkUI, UIUtils, WebUI}
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils}
+import org.apache.spark.util.{ShutdownHookManager, SystemClock, Utils}
/**
* A web server that renders SparkUIs of completed applications.
@@ -48,31 +51,20 @@ class HistoryServer(
provider: ApplicationHistoryProvider,
securityManager: SecurityManager,
port: Int)
- extends WebUI(securityManager, port, conf) with Logging with UIRoot {
+ extends WebUI(securityManager, securityManager.getSSLOptions("historyServer"), port, conf)
+ with Logging with UIRoot with ApplicationCacheOperations {
// How many applications to retain
private val retainedApplications = conf.getInt("spark.history.retainedApplications", 50)
- private val appLoader = new CacheLoader[String, SparkUI] {
- override def load(key: String): SparkUI = {
- val parts = key.split("/")
- require(parts.length == 1 || parts.length == 2, s"Invalid app key $key")
- val ui = provider
- .getAppUI(parts(0), if (parts.length > 1) Some(parts(1)) else None)
- .getOrElse(throw new NoSuchElementException(s"no app with key $key"))
- attachSparkUI(ui)
- ui
- }
- }
+ // How many applications the summary UI displays
+ private[history] val maxApplications = conf.get(HISTORY_UI_MAX_APPS)
- private val appCache = CacheBuilder.newBuilder()
- .maximumSize(retainedApplications)
- .removalListener(new RemovalListener[String, SparkUI] {
- override def onRemoval(rm: RemovalNotification[String, SparkUI]): Unit = {
- detachSparkUI(rm.getValue())
- }
- })
- .build(appLoader)
+ // Cache of loaded application UIs, keyed by application/attempt ID
+ private val appCache = new ApplicationCache(this, retainedApplications, new SystemClock())
+
+ // and its metrics, for testing as well as monitoring
+ val cacheMetrics = appCache.metrics
private val loaderServlet = new HttpServlet {
protected override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = {
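The Guava LoadingCache/RemovalListener pair is replaced by the dedicated ApplicationCache, which attaches a SparkUI when an entry is loaded and detaches it when the entry is evicted. A simplified, self-contained sketch of that load/evict contract (a LinkedHashMap-based LRU and plain callbacks stand in for the real cache and for attachSparkUI/detachSparkUI; UiCacheSketch is an invented name):

    import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap}

    // Simplified sketch of a size-bounded cache that runs attach/detach hooks,
    // standing in for the ApplicationCache introduced above.
    class UiCacheSketch[K, V](maxSize: Int, load: K => V, onEvict: (K, V) => Unit) {
      private val cache = new JLinkedHashMap[K, V](16, 0.75f, true) {
        override def removeEldestEntry(eldest: JMap.Entry[K, V]): Boolean = {
          val evict = size() > maxSize
          if (evict) onEvict(eldest.getKey, eldest.getValue)  // detach the evicted UI
          evict
        }
      }

      def get(key: K): V = {
        val cached = cache.get(key)
        if (cached != null) cached
        else { val v = load(key); cache.put(key, v); v }      // load + attach on a miss
      }
    }

    object UiCacheSketchDemo {
      def main(args: Array[String]): Unit = {
        val cache = new UiCacheSketch[String, String](
          maxSize = 2,
          load = id => { println(s"attach $id"); s"UI($id)" },
          onEvict = (id, _) => println(s"detach $id"))
        Seq("app-1", "app-2", "app-3").foreach(cache.get)     // app-1 is evicted when app-3 loads
      }
    }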
@@ -103,7 +95,9 @@ class HistoryServer(
// Note we don't use the UI retrieved from the cache; the cache loader above will register
// the app's UI, and all we need to do is redirect the user to the same URI that was
// requested, and the proper data should be served at that point.
- res.sendRedirect(res.encodeRedirectURL(req.getRequestURI()))
+ // Also, make sure that the redirect url contains the query string present in the request.
+ val requestURI = req.getRequestURI + Option(req.getQueryString).map("?" + _).getOrElse("")
+ res.sendRedirect(res.encodeRedirectURL(requestURI))
}
// SPARK-5983 ensure TRACE is not supported
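The redirect fix above preserves the query string, which getQueryString reports as null when absent. A minimal sketch of the same construction outside the servlet container (the method and object names are invented; the two plain strings model getRequestURI and getQueryString):

    // Sketch of the redirect-URL construction above, outside of the servlet container.
    object RedirectUriSketch {
      def redirectUri(requestUri: String, queryString: String): String =
        requestUri + Option(queryString).map("?" + _).getOrElse("")

      def main(args: Array[String]): Unit = {
        println(redirectUri("/history/app-1234", null))                   // /history/app-1234
        println(redirectUri("/history/app-1234", "showIncomplete=true"))  // /history/app-1234?showIncomplete=true
      }
    }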
@@ -113,7 +107,7 @@ class HistoryServer(
}
def getSparkUI(appKey: String): Option[SparkUI] = {
- Option(appCache.get(appKey))
+ appCache.getSparkUI(appKey)
}
initialize()
@@ -146,32 +140,58 @@ class HistoryServer(
override def stop() {
super.stop()
provider.stop()
+ appCache.stop()
}
/** Attach a reconstructed UI to this server. Only valid after bind(). */
- private def attachSparkUI(ui: SparkUI) {
+ override def attachSparkUI(
+ appId: String,
+ attemptId: Option[String],
+ ui: SparkUI,
+ completed: Boolean) {
assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs")
ui.getHandlers.foreach(attachHandler)
- addFilters(ui.getHandlers, conf)
}
/** Detach a reconstructed UI from this server. Only valid after bind(). */
- private def detachSparkUI(ui: SparkUI) {
+ override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = {
assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs")
ui.getHandlers.foreach(detachHandler)
}
+ /**
+ * Get the application UI and whether or not it is completed
+ * @param appId application ID
+ * @param attemptId attempt ID
+ * @return If found, the Spark UI and any history information to be used in the cache
+ */
+ override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = {
+ provider.getAppUI(appId, attemptId)
+ }
+
/**
* Returns a list of available applications, in descending order according to their end time.
*
* @return List of all known applications.
*/
- def getApplicationList(): Iterable[ApplicationHistoryInfo] = {
+ def getApplicationList(): Iterator[ApplicationHistoryInfo] = {
provider.getListing()
}
+ def getEventLogsUnderProcess(): Int = {
+ provider.getEventLogsUnderProcess()
+ }
+
+ def getLastUpdatedTime(): Long = {
+ provider.getLastUpdatedTime()
+ }
+
def getApplicationInfoList: Iterator[ApplicationInfo] = {
- getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo)
+ getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo)
+ }
+
+ def getApplicationInfo(appId: String): Option[ApplicationInfo] = {
+ provider.getApplicationInfo(appId).map(ApplicationsListResource.appHistoryInfoToPublicAppInfo)
}
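getApplicationList now hands back an Iterator instead of an Iterable, so downstream transformations such as getApplicationInfoList stay lazy. A small standalone sketch of that consumption pattern (the ApplicationInfo case class and the sample data are stand-ins, not the real API types):

    // Sketch of consuming the Iterator-based listing lazily.
    // ApplicationInfo here is a stand-in for the real public API type.
    case class ApplicationInfo(id: String, name: String)

    object ListingSketch {
      def getApplicationList(): Iterator[ApplicationInfo] =
        Iterator(ApplicationInfo("app-1", "PageRank"), ApplicationInfo("app-2", "WordCount"))

      def main(args: Array[String]): Unit = {
        // Because the listing is an Iterator, the map below is evaluated one element
        // at a time rather than materializing the whole list up front.
        val names = getApplicationList().map(_.name)
        names.foreach(println)
      }
    }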
override def writeEventLogs(
@@ -181,6 +201,13 @@ class HistoryServer(
provider.writeEventLogs(appId, attemptId, zipStream)
}
+ /**
+ * @return html text to display when the application list is empty
+ */
+ def emptyListingHtml(): Seq[Node] = {
+ provider.getEmptyListingHtml()
+ }
+
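emptyListingHtml defers to the provider, so each history provider can render its own hint when no applications are listed. A sketch of the kind of fragment a provider might return (assuming the scala-xml module on the classpath, as Spark has; the wording echoes the hint the old page rendered inline and is illustrative only):

    import scala.xml.Node

    // Sketch of a provider-supplied "empty listing" fragment, as returned by getEmptyListingHtml().
    object EmptyListingSketch {
      def getEmptyListingHtml(): Seq[Node] =
        <p>
          Did you specify the correct logging directory? Please verify your setting of
          <span style="font-style:italic">spark.history.fs.logDirectory</span>
          and whether you have the permissions to access it.
        </p>

      def main(args: Array[String]): Unit =
        println(getEmptyListingHtml().mkString)
    }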
/**
* Returns the provider configuration to show in the listing page.
*
@@ -188,12 +215,18 @@ class HistoryServer(
*/
def getProviderConfig(): Map[String, String] = provider.getConfig()
+ /**
+ * Load an application UI and attach it to the web server.
+ * @param appId application ID
+ * @param attemptId optional attempt ID
+ * @return true if the application was found and loaded.
+ */
private def loadAppUi(appId: String, attemptId: Option[String]): Boolean = {
try {
- appCache.get(appId + attemptId.map { id => s"/$id" }.getOrElse(""))
+ appCache.get(appId, attemptId)
true
} catch {
- case e: Exception => e.getCause() match {
+ case NonFatal(e) => e.getCause() match {
case nsee: NoSuchElementException =>
false
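loadAppUi now catches only non-fatal errors and keys off getCause to distinguish an unknown application from a real failure. A compact standalone sketch of that pattern (the lookup function and the rethrow branch are assumptions of the sketch, not the exact body of the real method):

    import java.util.NoSuchElementException
    import scala.util.control.NonFatal

    // Sketch of the NonFatal/getCause handling in loadAppUi above.
    // The cache lookup is faked with a function that may throw a wrapped exception.
    object LoadAppUiSketch {
      def loadAppUi(lookup: () => Unit): Boolean =
        try {
          lookup()
          true
        } catch {
          case NonFatal(e) => e.getCause match {
            case _: NoSuchElementException => false  // unknown app: report "not found"
            case _ => throw e                        // anything else is propagated
          }
        }

      def main(args: Array[String]): Unit = {
        val missing = new RuntimeException(new NoSuchElementException("no app with key app-1"))
        println(loadAppUi(() => throw missing))  // prints: false
        println(loadAppUi(() => ()))             // prints: true
      }
    }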
@@ -202,6 +235,17 @@ class HistoryServer(
}
}
+ /**
+ * String value for diagnostics.
+ * @return a multi-line description of the server state.
+ */
+ override def toString: String = {
+ s"""
+ | History Server;
+ | provider = $provider
+ | cache = $appCache
+ """.stripMargin
+ }
}
/**
@@ -220,11 +264,11 @@ object HistoryServer extends Logging {
val UI_PATH_PREFIX = "/history"
- def main(argStrings: Array[String]) {
- SignalLogger.register(log)
+ def main(argStrings: Array[String]): Unit = {
+ Utils.initDaemon(log)
new HistoryServerArguments(conf, argStrings)
initSecurity()
- val securityManager = new SecurityManager(conf)
+ val securityManager = createSecurityManager(conf)
val providerName = conf.getOption("spark.history.provider")
.getOrElse(classOf[FsHistoryProvider].getName())
@@ -244,6 +288,29 @@ object HistoryServer extends Logging {
while(true) { Thread.sleep(Int.MaxValue) }
}
+ /**
+ * Create a security manager.
+ * This turns off security in the SecurityManager, so that the History Server can start
+ * in a Spark cluster where security is enabled.
+ * @param config configuration for the SecurityManager constructor
+ * @return the security manager for use in constructing the History Server.
+ */
+ private[history] def createSecurityManager(config: SparkConf): SecurityManager = {
+ if (config.getBoolean(SecurityManager.SPARK_AUTH_CONF, false)) {
+ logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}")
+ config.set(SecurityManager.SPARK_AUTH_CONF, "false")
+ }
+
+ if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) {
+ logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " +
+ "only using spark.history.ui.acl.enable")
+ config.set("spark.acls.enable", "false")
+ config.set("spark.ui.acls.enable", "false")
+ }
+
+ new SecurityManager(config)
+ }
+
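createSecurityManager above deliberately downgrades cluster security settings before building the SecurityManager, so a History Server can start inside a secured cluster while access to history UIs stays governed by the history server's own ACL setting. A sketch of that precedence with a plain mutable map standing in for SparkConf (the scrub function and the defaults used here are assumptions of the sketch):

    import scala.collection.mutable

    // Sketch of the config scrubbing in createSecurityManager above, with a mutable
    // map standing in for SparkConf. The keys mirror the ones used in the diff.
    object ScrubSecurityConfSketch {
      def scrub(conf: mutable.Map[String, String]): Unit = {
        if (conf.getOrElse("spark.authenticate", "false").toBoolean) {
          conf("spark.authenticate") = "false"  // RPC auth is meaningless for replayed apps
        }
        val aclsOn = conf.getOrElse("spark.acls.enable",
          conf.getOrElse("spark.ui.acls.enable", "false")).toBoolean
        if (aclsOn) {
          // Only the history server's own ACL setting should control access to history UIs.
          conf("spark.acls.enable") = "false"
          conf("spark.ui.acls.enable") = "false"
        }
      }

      def main(args: Array[String]): Unit = {
        val conf = mutable.Map("spark.authenticate" -> "true", "spark.ui.acls.enable" -> "true")
        scrub(conf)
        println(conf)  // both flags are now "false"
      }
    }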
def initSecurity() {
// If we are accessing HDFS and it has security enabled (Kerberos), we have to login
// from a keytab file so that we can access HDFS beyond the kerberos ticket expiration.
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index d03bab3820bb..080ba12c2f0d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -17,11 +17,14 @@
package org.apache.spark.deploy.history
-import org.apache.spark.{Logging, SparkConf}
+import scala.annotation.tailrec
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
/**
- * Command-line parser for the master.
+ * Command-line parser for the [[HistoryServer]].
*/
private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String])
extends Logging {
@@ -29,6 +32,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin
parse(args.toList)
+ @tailrec
private def parse(args: List[String]): Unit = {
if (args.length == 1) {
setLogDirectory(args.head)
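The @tailrec annotation added above only asks the compiler to verify that parse compiles down to a loop; it does not change behaviour. A minimal standalone sketch of the same list-based, tail-recursive argument parsing (the --dir/--port options and the settings map are invented for illustration):

    import scala.annotation.tailrec

    // Sketch of the tail-recursive argument-parsing pattern used by HistoryServerArguments.
    object ArgParseSketch {
      @tailrec
      def parse(args: List[String], settings: Map[String, String] = Map.empty): Map[String, String] =
        args match {
          case "--dir" :: value :: tail  => parse(tail, settings + ("logDirectory" -> value))
          case "--port" :: value :: tail => parse(tail, settings + ("port" -> value))
          case Nil                       => settings
          case other :: _                => sys.error(s"Unrecognized option: $other")
        }

      def main(args: Array[String]): Unit =
        println(parse(List("--dir", "/tmp/spark-events", "--port", "18080")))
    }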
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index ac553b71115d..53564d0e9515 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -41,7 +41,6 @@ private[spark] class ApplicationInfo(
@transient var coresGranted: Int = _
@transient var endTime: Long = _
@transient var appSource: ApplicationSource = _
- @transient @volatile var appUIUrlAtHistoryServer: Option[String] = None
// A cap on the number of executors this application can have at any given time.
// By default, this is infinite. Only after the first allocation request is issued by the
@@ -65,7 +64,7 @@ private[spark] class ApplicationInfo(
appSource = new ApplicationSource(this)
nextExecutorId = 0
removedExecutors = new ArrayBuffer[ExecutorDesc]
- executorLimit = Integer.MAX_VALUE
+ executorLimit = desc.initialExecutorLimit.getOrElse(Integer.MAX_VALUE)
}
private def newExecutorId(useID: Option[Int] = None): Int = {
@@ -135,11 +134,4 @@ private[spark] class ApplicationInfo(
System.currentTimeMillis() - startTime
}
}
-
- /**
- * Returns the original application UI url unless there is its address at history server
- * is defined
- */
- def curAppUIUrl: String = appUIUrlAtHistoryServer.getOrElse(desc.appUiUrl)
-
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
index 37bfcdfdf477..097728c82157 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -22,6 +22,4 @@ private[master] object ApplicationState extends Enumeration {
type ApplicationState = Value
val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value
-
- val MAX_NUM_RETRY = 10
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
index b197dbcbfe29..8d5edae0501e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy.master
import java.util.Date
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.DriverDescription
import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index 1aa8cd5013b4..f2b5ea7e23ec 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -21,7 +21,7 @@ import java.io._
import scala.reflect.ClassTag
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer}
import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index 70f21fbe0de8..52e2854961ed 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -32,8 +32,8 @@ trait LeaderElectionAgent {
@DeveloperApi
trait LeaderElectable {
- def electedLeader()
- def revokedLeadership()
+ def electedLeader(): Unit
+ def revokedLeadership(): Unit
}
/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
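The explicit ": Unit" result types above are the idiomatic way to declare abstract side-effecting methods. A tiny self-contained sketch of the trait plus a single-node agent in that style (Electable and MonarchyAgent are simplified stand-ins for the real LeaderElectable and single-node agent):

    // Sketch of the elect/revoke contract, simplified from the real traits.
    trait Electable {
      def electedLeader(): Unit
      def revokedLeadership(): Unit
    }

    // A single-node "election": the sole participant is immediately and permanently the leader.
    class MonarchyAgent(candidate: Electable) {
      candidate.electedLeader()
    }

    object ElectionSketch {
      def main(args: Array[String]): Unit = {
        val master = new Electable {
          def electedLeader(): Unit = println("I have been elected leader!")
          def revokedLeadership(): Unit = println("Leadership revoked")
        }
        new MonarchyAgent(master)
      }
    }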
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index b25a487806c7..cbc5aae0b334 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -17,33 +17,26 @@
package org.apache.spark.deploy.master
-import java.io.FileNotFoundException
-import java.net.URLEncoder
import java.text.SimpleDateFormat
-import java.util.Date
+import java.util.{Date, Locale}
import java.util.concurrent.{ScheduledFuture, TimeUnit}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.language.postfixOps
import scala.util.Random
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark.rpc._
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
ExecutorState, SparkHadoopUtil}
import org.apache.spark.deploy.DeployMessages._
-import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.deploy.rest.StandaloneRestServer
+import org.apache.spark.internal.Logging
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
+import org.apache.spark.rpc._
import org.apache.spark.serializer.{JavaSerializer, Serializer}
-import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
+import org.apache.spark.util.{ThreadUtils, Utils}
private[deploy] class Master(
override val rpcEnv: RpcEnv,
@@ -58,17 +51,19 @@ private[deploy] class Master(
private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
- private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
+ // For application IDs
+ private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000
private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200)
private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200)
private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15)
private val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE")
+ private val MAX_EXECUTOR_RETRIES = conf.getInt("spark.deploy.maxExecutorRetries", 10)
val workers = new HashSet[WorkerInfo]
val idToApp = new HashMap[String, ApplicationInfo]
- val waitingApps = new ArrayBuffer[ApplicationInfo]
+ private val waitingApps = new ArrayBuffer[ApplicationInfo]
val apps = new HashSet[ApplicationInfo]
private val idToWorker = new HashMap[String, WorkerInfo]
@@ -78,7 +73,6 @@ private[deploy] class Master(
private val addressToApp = new HashMap[RpcAddress, ApplicationInfo]
private val completedApps = new ArrayBuffer[ApplicationInfo]
private var nextAppNumber = 0
- private val appIdToUI = new HashMap[String, SparkUI]
private val drivers = new HashSet[DriverInfo]
private val completedDrivers = new ArrayBuffer[DriverInfo]
@@ -121,6 +115,7 @@ private[deploy] class Master(
// Default maxCores for applications that don't specify it (i.e. pass Int.MaxValue)
private val defaultCores = conf.getInt("spark.deploy.defaultCores", Int.MaxValue)
+ val reverseProxy = conf.getBoolean("spark.ui.reverseProxy", false)
if (defaultCores < 1) {
throw new SparkException("spark.deploy.defaultCores must be positive")
}
@@ -136,6 +131,11 @@ private[deploy] class Master(
webUi = new MasterWebUI(this, webUiPort)
webUi.bind()
masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort
+ if (reverseProxy) {
+ masterWebUiUrl = conf.get("spark.ui.reverseProxyUrl", masterWebUiUrl)
+ logInfo(s"Spark Master is acting as a reverse proxy. Master, Workers and " +
+ s"Applications UIs are available at $masterWebUiUrl")
+ }
checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
self.send(CheckForWorkerTimeOut)
@@ -208,7 +208,7 @@ private[deploy] class Master(
}
override def receive: PartialFunction[Any, Unit] = {
- case ElectedLeader => {
+ case ElectedLeader =>
val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv)
state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) {
RecoveryState.ALIVE
@@ -224,16 +224,38 @@ private[deploy] class Master(
}
}, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS)
}
- }
case CompleteRecovery => completeRecovery()
- case RevokedLeadership => {
+ case RevokedLeadership =>
logError("Leadership has been revoked -- master shutting down.")
System.exit(0)
- }
- case RegisterApplication(description, driver) => {
+ case RegisterWorker(
+ id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress) =>
+ logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
+ workerHost, workerPort, cores, Utils.megabytesToString(memory)))
+ if (state == RecoveryState.STANDBY) {
+ workerRef.send(MasterInStandby)
+ } else if (idToWorker.contains(id)) {
+ workerRef.send(RegisterWorkerFailed("Duplicate worker ID"))
+ } else {
+ val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
+ workerRef, workerWebUiUrl)
+ if (registerWorker(worker)) {
+ persistenceEngine.addWorker(worker)
+ workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress))
+ schedule()
+ } else {
+ val workerAddress = worker.endpoint.address
+ logWarning("Worker registration failed. Attempted to re-register worker at same " +
+ "address: " + workerAddress)
+ workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: "
+ + workerAddress))
+ }
+ }
+
+ case RegisterApplication(description, driver) =>
// TODO Prevent repeated registrations from some driver
if (state == RecoveryState.STANDBY) {
// ignore, don't send response
@@ -246,16 +268,23 @@ private[deploy] class Master(
driver.send(RegisteredApplication(app.id, self))
schedule()
}
- }
- case ExecutorStateChanged(appId, execId, state, message, exitStatus) => {
+ case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId))
execOption match {
- case Some(exec) => {
+ case Some(exec) =>
val appInfo = idToApp(appId)
+ val oldState = exec.state
exec.state = state
- if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() }
- exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus))
+
+ if (state == ExecutorState.RUNNING) {
+ assert(oldState == ExecutorState.LAUNCHING,
+ s"executor $execId state transfer from $oldState to RUNNING is illegal")
+ appInfo.resetRetryCount()
+ }
+
+ exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus, false))
+
if (ExecutorState.isFinished(state)) {
// Remove this executor from the worker and app
logInfo(s"Removing executor ${exec.fullId} because it is $state")
@@ -268,35 +297,33 @@ private[deploy] class Master(
val normalExit = exitStatus == Some(0)
// Only retry certain number of times so we don't go into an infinite loop.
- if (!normalExit) {
- if (appInfo.incrementRetryCount() < ApplicationState.MAX_NUM_RETRY) {
- schedule()
- } else {
- val execs = appInfo.executors.values
- if (!execs.exists(_.state == ExecutorState.RUNNING)) {
- logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " +
- s"${appInfo.retryCount} times; removing it")
- removeApplication(appInfo, ApplicationState.FAILED)
- }
+ // Important note: this code path is not exercised by tests, so be very careful when
+ // changing this `if` condition.
+ if (!normalExit
+ && appInfo.incrementRetryCount() >= MAX_EXECUTOR_RETRIES
+ && MAX_EXECUTOR_RETRIES >= 0) { // < 0 disables this application-killing path
+ val execs = appInfo.executors.values
+ if (!execs.exists(_.state == ExecutorState.RUNNING)) {
+ logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " +
+ s"${appInfo.retryCount} times; removing it")
+ removeApplication(appInfo, ApplicationState.FAILED)
}
}
}
- }
+ schedule()
case None =>
logWarning(s"Got status update for unknown executor $appId/$execId")
}
- }
- case DriverStateChanged(driverId, state, exception) => {
+ case DriverStateChanged(driverId, state, exception) =>
state match {
case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED =>
removeDriver(driverId, state, exception)
case _ =>
throw new Exception(s"Received unexpected state update for driver $driverId: $state")
}
- }
- case Heartbeat(workerId, worker) => {
+ case Heartbeat(workerId, worker) =>
idToWorker.get(workerId) match {
case Some(workerInfo) =>
workerInfo.lastHeartbeat = System.currentTimeMillis()
@@ -310,9 +337,8 @@ private[deploy] class Master(
" This worker was never registered, so ignoring the heartbeat.")
}
}
- }
- case MasterChangeAcknowledged(appId) => {
+ case MasterChangeAcknowledged(appId) =>
idToApp.get(appId) match {
case Some(app) =>
logInfo("Application has been re-registered: " + appId)
@@ -322,9 +348,8 @@ private[deploy] class Master(
}
if (canCompleteRecovery) { completeRecovery() }
- }
- case WorkerSchedulerStateResponse(workerId, executors, driverIds) => {
+ case WorkerSchedulerStateResponse(workerId, executors, driverIds) =>
idToWorker.get(workerId) match {
case Some(worker) =>
logInfo("Worker has been re-registered: " + workerId)
@@ -342,7 +367,7 @@ private[deploy] class Master(
drivers.find(_.id == driverId).foreach { driver =>
driver.worker = Some(worker)
driver.state = DriverState.RUNNING
- worker.drivers(driverId) = driver
+ worker.addDriver(driver)
}
}
case None =>
@@ -350,44 +375,42 @@ private[deploy] class Master(
}
if (canCompleteRecovery) { completeRecovery() }
- }
+
+ case WorkerLatestState(workerId, executors, driverIds) =>
+ idToWorker.get(workerId) match {
+ case Some(worker) =>
+ for (exec <- executors) {
+ val executorMatches = worker.executors.exists {
+ case (_, e) => e.application.id == exec.appId && e.id == exec.execId
+ }
+ if (!executorMatches) {
+ // master doesn't recognize this executor. So just tell worker to kill it.
+ worker.endpoint.send(KillExecutor(masterUrl, exec.appId, exec.execId))
+ }
+ }
+
+ for (driverId <- driverIds) {
+ val driverMatches = worker.drivers.exists { case (id, _) => id == driverId }
+ if (!driverMatches) {
+ // master doesn't recognize this driver. So just tell worker to kill it.
+ worker.endpoint.send(KillDriver(driverId))
+ }
+ }
+ case None =>
+ logWarning("Worker state from unknown worker: " + workerId)
+ }
case UnregisterApplication(applicationId) =>
logInfo(s"Received unregister request from application $applicationId")
idToApp.get(applicationId).foreach(finishApplication)
- case CheckForWorkerTimeOut => {
+ case CheckForWorkerTimeOut =>
timeOutDeadWorkers()
- }
+
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case RegisterWorker(
- id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => {
- logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
- workerHost, workerPort, cores, Utils.megabytesToString(memory)))
- if (state == RecoveryState.STANDBY) {
- context.reply(MasterInStandby)
- } else if (idToWorker.contains(id)) {
- context.reply(RegisterWorkerFailed("Duplicate worker ID"))
- } else {
- val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
- workerRef, workerUiPort, publicAddress)
- if (registerWorker(worker)) {
- persistenceEngine.addWorker(worker)
- context.reply(RegisteredWorker(self, masterWebUiUrl))
- schedule()
- } else {
- val workerAddress = worker.endpoint.address
- logWarning("Worker registration failed. Attempted to re-register worker at same " +
- "address: " + workerAddress)
- context.reply(RegisterWorkerFailed("Attempted to re-register worker at same address: "
- + workerAddress))
- }
- }
- }
-
- case RequestSubmitDriver(description) => {
+ case RequestSubmitDriver(description) =>
if (state != RecoveryState.ALIVE) {
val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
"Can only accept driver submissions in ALIVE state."
@@ -406,9 +429,8 @@ private[deploy] class Master(
context.reply(SubmitDriverResponse(self, true, Some(driver.id),
s"Driver successfully submitted as ${driver.id}"))
}
- }
- case RequestKillDriver(driverId) => {
+ case RequestKillDriver(driverId) =>
if (state != RecoveryState.ALIVE) {
val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
s"Can only kill drivers in ALIVE state."
@@ -439,9 +461,8 @@ private[deploy] class Master(
context.reply(KillDriverResponse(self, driverId, success = false, msg))
}
}
- }
- case RequestDriverStatus(driverId) => {
+ case RequestDriverStatus(driverId) =>
if (state != RecoveryState.ALIVE) {
val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
"Can only request driver status in ALIVE state."
@@ -456,18 +477,15 @@ private[deploy] class Master(
context.reply(DriverStatusResponse(found = false, None, None, None, None))
}
}
- }
- case RequestMasterState => {
+ case RequestMasterState =>
context.reply(MasterStateResponse(
address.host, address.port, restServerBoundPort,
workers.toArray, apps.toArray, completedApps.toArray,
drivers.toArray, completedDrivers.toArray, state))
- }
- case BoundPortsRequest => {
+ case BoundPortsRequest =>
context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort))
- }
case RequestExecutors(appId, requestedTotal) =>
context.reply(handleRequestExecutors(appId, requestedTotal))
@@ -529,6 +547,9 @@ private[deploy] class Master(
workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker)
apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)
+ // Update the state of recovered apps to RUNNING
+ apps.filter(_.state == ApplicationState.WAITING).foreach(_.state = ApplicationState.RUNNING)
+
// Reschedule drivers which were not claimed by any workers
drivers.filter(_.worker.isEmpty).foreach { d =>
logWarning(s"Driver ${d.id} was not found after master recovery")
@@ -683,15 +704,28 @@ private[deploy] class Master(
* every time a new app joins or resource availability changes.
*/
private def schedule(): Unit = {
- if (state != RecoveryState.ALIVE) { return }
+ if (state != RecoveryState.ALIVE) {
+ return
+ }
// Drivers take strict precedence over executors
- val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers
- for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) {
- for (driver <- waitingDrivers) {
+ val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE))
+ val numWorkersAlive = shuffledAliveWorkers.size
+ var curPos = 0
+ for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers
+ // We assign workers to each waiting driver in a round-robin fashion. For each driver, we
+ // start from the last worker that was assigned a driver, and continue onwards until we have
+ // explored all alive workers.
+ var launched = false
+ var numWorkersVisited = 0
+ while (numWorkersVisited < numWorkersAlive && !launched) {
+ val worker = shuffledAliveWorkers(curPos)
+ numWorkersVisited += 1
if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) {
launchDriver(worker, driver)
waitingDrivers -= driver
+ launched = true
}
+ curPos = (curPos + 1) % numWorkersAlive
}
}
startExecutorsOnWorkers()
@@ -702,8 +736,8 @@ private[deploy] class Master(
worker.addExecutor(exec)
worker.endpoint.send(LaunchExecutor(masterUrl,
exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory))
- exec.application.driver.send(ExecutorAdded(
- exec.id, worker.id, worker.hostPort, exec.cores, exec.memory))
+ exec.application.driver.send(
+ ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory))
}
private def registerWorker(worker: WorkerInfo): Boolean = {
@@ -731,6 +765,9 @@ private[deploy] class Master(
workers += worker
idToWorker(worker.id) = worker
addressToWorker(workerAddress) = worker
+ if (reverseProxy) {
+ webUi.addProxyTargets(worker.id, worker.webUiAddress)
+ }
true
}
@@ -739,10 +776,14 @@ private[deploy] class Master(
worker.setState(WorkerState.DEAD)
idToWorker -= worker.id
addressToWorker -= worker.endpoint.address
+ if (reverseProxy) {
+ webUi.removeProxyTargets(worker.id)
+ }
for (exec <- worker.executors.values) {
logInfo("Telling app of lost executor: " + exec.id)
exec.application.driver.send(ExecutorUpdated(
- exec.id, ExecutorState.LOST, Some("worker lost"), None))
+ exec.id, ExecutorState.LOST, Some("worker lost"), None, workerLost = true))
+ exec.state = ExecutorState.LOST
exec.application.removeExecutor(exec)
}
for (driver <- worker.drivers.values) {
@@ -785,6 +826,9 @@ private[deploy] class Master(
endpointToApp(app.driver) = app
addressToApp(appAddress) = app
waitingApps += app
+ if (reverseProxy) {
+ webUi.addProxyTargets(app.id, app.desc.appUiUrl)
+ }
}
private def finishApplication(app: ApplicationInfo) {
@@ -798,20 +842,19 @@ private[deploy] class Master(
idToApp -= app.id
endpointToApp -= app.driver
addressToApp -= app.driver.address
+ if (reverseProxy) {
+ webUi.removeProxyTargets(app.id)
+ }
if (completedApps.size >= RETAINED_APPLICATIONS) {
val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1)
- completedApps.take(toRemove).foreach( a => {
- appIdToUI.remove(a.id).foreach { ui => webUi.detachSparkUI(ui) }
+ completedApps.take(toRemove).foreach { a =>
applicationMetricsSystem.removeSource(a.appSource)
- })
+ }
completedApps.trimStart(toRemove)
}
completedApps += app // Remember it in our history
waitingApps -= app
- // If application events are logged, use them to rebuild the UI
- rebuildSparkUI(app)
-
for (exec <- app.executors.values) {
killExecutor(exec)
}
@@ -910,77 +953,7 @@ private[deploy] class Master(
exec.state = ExecutorState.KILLED
}
- /**
- * Rebuild a new SparkUI from the given application's event logs.
- * Return the UI if successful, else None
- */
- private[master] def rebuildSparkUI(app: ApplicationInfo): Option[SparkUI] = {
- val appName = app.desc.name
- val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found"
- try {
- val eventLogDir = app.desc.eventLogDir
- .getOrElse {
- // Event logging is not enabled for this application
- app.appUIUrlAtHistoryServer = Some(notFoundBasePath)
- return None
- }
-
- val eventLogFilePrefix = EventLoggingListener.getLogPath(
- eventLogDir, app.id, app.desc.eventLogCodec)
- val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf)
- val inProgressExists = fs.exists(new Path(eventLogFilePrefix +
- EventLoggingListener.IN_PROGRESS))
-
- if (inProgressExists) {
- // Event logging is enabled for this application, but the application is still in progress
- logWarning(s"Application $appName is still in progress, it may be terminated abnormally.")
- }
-
- val (eventLogFile, status) = if (inProgressExists) {
- (eventLogFilePrefix + EventLoggingListener.IN_PROGRESS, " (in progress)")
- } else {
- (eventLogFilePrefix, " (completed)")
- }
-
- val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs)
- val replayBus = new ReplayListenerBus()
- val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf),
- appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime)
- val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS)
- try {
- replayBus.replay(logInput, eventLogFile, maybeTruncated)
- } finally {
- logInput.close()
- }
- appIdToUI(app.id) = ui
- webUi.attachSparkUI(ui)
- // Application UI is successfully rebuilt, so link the Master UI to it
- app.appUIUrlAtHistoryServer = Some(ui.basePath)
- Some(ui)
- } catch {
- case fnf: FileNotFoundException =>
- // Event logging is enabled for this application, but no event logs are found
- val title = s"Application history not found (${app.id})"
- var msg = s"No event logs found for application $appName in ${app.desc.eventLogDir.get}."
- logWarning(msg)
- msg += " Did you specify the correct logging directory?"
- msg = URLEncoder.encode(msg, "UTF-8")
- app.appUIUrlAtHistoryServer = Some(notFoundBasePath + s"?msg=$msg&title=$title")
- None
- case e: Exception =>
- // Relay exception message to application UI page
- val title = s"Application history load error (${app.id})"
- val exception = URLEncoder.encode(Utils.exceptionString(e), "UTF-8")
- var msg = s"Exception in replaying log for application $appName!"
- logError(msg, e)
- msg = URLEncoder.encode(msg, "UTF-8")
- app.appUIUrlAtHistoryServer =
- Some(notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title")
- None
- }
- }
-
- /** Generate a new app ID given a app's submission date */
+ /** Generate a new app ID given an app's submission date */
private def newApplicationId(submitDate: Date): String = {
val appId = "app-%s-%04d".format(createDateFormat.format(submitDate), nextAppNumber)
nextAppNumber += 1
@@ -1054,7 +1027,7 @@ private[deploy] object Master extends Logging {
val ENDPOINT_NAME = "Master"
def main(argStrings: Array[String]) {
- SignalLogger.register(log)
+ Utils.initDaemon(log)
val conf = new SparkConf
val args = new MasterArguments(argStrings, conf)
val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf)
@@ -1076,7 +1049,7 @@ private[deploy] object Master extends Logging {
val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr)
val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME,
new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf))
- val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest)
+ val portsResponse = masterEndpoint.askSync[BoundPortsResponse](BoundPortsRequest)
(rpcEnv, portsResponse.webUIPort, portsResponse.restPort)
}
}
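
The schedule() hunk above places each waiting driver on at most one alive worker, visiting workers round-robin from the position after the previous assignment. A minimal, self-contained sketch of that placement loop, using simplified stand-in types (Driver and Worker here are not the Spark classes, and launching is reduced to decrementing two counters):

    import scala.util.Random

    final case class Driver(id: String, mem: Int, cores: Int)
    final case class Worker(id: String, var memFree: Int, var coresFree: Int)

    // Round-robin placement: for each waiting driver, start from the worker after the
    // last assignment and stop once a worker with enough memory and cores is found.
    def assignDrivers(waiting: Seq[Driver], workers: Seq[Worker]): Map[String, String] = {
      val shuffledAlive = Random.shuffle(workers)   // shuffling balances drivers across workers
      var curPos = 0
      val assignments = Map.newBuilder[String, String]
      for (driver <- waiting) {
        var launched = false
        var visited = 0
        while (visited < shuffledAlive.size && !launched) {
          val worker = shuffledAlive(curPos)
          visited += 1
          if (worker.memFree >= driver.mem && worker.coresFree >= driver.cores) {
            worker.memFree -= driver.mem            // "launch" the driver on this worker
            worker.coresFree -= driver.cores
            assignments += driver.id -> worker.id
            launched = true
          }
          curPos = (curPos + 1) % shuffledAlive.size
        }
      }
      assignments.result()
    }

Drivers that fit on no worker are simply left unassigned, mirroring how waitingDrivers keeps them around for the next scheduling round.
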
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
index 44cefbc77f08..c63793c16dce 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
@@ -17,19 +17,27 @@
package org.apache.spark.deploy.master
+import scala.annotation.tailrec
+
import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
import org.apache.spark.util.{IntParam, Utils}
/**
* Command-line parser for the master.
*/
-private[master] class MasterArguments(args: Array[String], conf: SparkConf) {
+private[master] class MasterArguments(args: Array[String], conf: SparkConf) extends Logging {
var host = Utils.localHostName()
var port = 7077
var webUiPort = 8080
var propertiesFile: String = null
// Check for settings in environment variables
+ if (System.getenv("SPARK_MASTER_IP") != null) {
+ logWarning("SPARK_MASTER_IP is deprecated, please use SPARK_MASTER_HOST")
+ host = System.getenv("SPARK_MASTER_IP")
+ }
+
if (System.getenv("SPARK_MASTER_HOST") != null) {
host = System.getenv("SPARK_MASTER_HOST")
}
@@ -49,6 +57,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) {
webUiPort = conf.get("spark.master.ui.port").toInt
}
+ @tailrec
private def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -75,7 +84,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) {
case ("--help") :: tail =>
printUsageAndExit(0)
- case Nil => {}
+ case Nil => // No-op
case _ =>
printUsageAndExit(1)
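
The parse method above now carries @tailrec, which the compiler only accepts when every recursive call is in tail position, guaranteeing constant stack usage however long the argument list is. A self-contained sketch of the same pattern with hypothetical flags (these are not the real Master options):

    import scala.annotation.tailrec

    object ArgParse {
      final case class Opts(host: String = "localhost", port: Int = 7077, webUiPort: Int = 8080)

      @tailrec
      def parse(args: List[String], opts: Opts = Opts()): Opts = args match {
        case ("--host" | "-h") :: value :: tail => parse(tail, opts.copy(host = value))
        case ("--port" | "-p") :: value :: tail => parse(tail, opts.copy(port = value.toInt))
        case "--webui-port" :: value :: tail    => parse(tail, opts.copy(webUiPort = value.toInt))
        case Nil                                => opts                 // all arguments consumed
        case other :: _                         => sys.error(s"Unrecognized option: $other")
      }
    }

    // ArgParse.parse(List("--host", "example.com", "-p", "7000"))
    //   == ArgParse.Opts("example.com", 7000, 8080)
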
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
index 66a9ff38678c..fb07c39dd02e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
@@ -32,7 +32,7 @@ private[spark] class MasterSource(val master: Master) extends Source {
// Gauge for alive worker numbers in cluster
metricRegistry.register(MetricRegistry.name("aliveWorkers"), new Gauge[Int]{
- override def getValue: Int = master.workers.filter(_.state == WorkerState.ALIVE).size
+ override def getValue: Int = master.workers.count(_.state == WorkerState.ALIVE)
})
// Gauge for application numbers in cluster
@@ -42,6 +42,6 @@ private[spark] class MasterSource(val master: Master) extends Source {
// Gauge for waiting application numbers in cluster
metricRegistry.register(MetricRegistry.name("waitingApps"), new Gauge[Int] {
- override def getValue: Int = master.waitingApps.size
+ override def getValue: Int = master.apps.count(_.state == ApplicationState.WAITING)
})
}
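
Both gauge changes are intentionally small: count(p) returns the same number as filter(p).size without building the intermediate collection, and counting applications in the WAITING state replaces reading the now-private waitingApps buffer. A tiny sketch of the equivalence with a stand-in state type (not Spark's WorkerState/ApplicationState):

    sealed trait State
    case object Alive extends State
    case object Dead extends State

    val workers: Seq[State] = Seq(Alive, Dead, Alive, Alive)

    val viaFilter = workers.filter(_ == Alive).size   // materializes a temporary Seq
    val viaCount  = workers.count(_ == Alive)         // single pass, no allocation
    assert(viaFilter == viaCount && viaCount == 3)
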
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index 58a00bceee6a..b30bc821b732 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -17,11 +17,11 @@
package org.apache.spark.deploy.master
+import scala.reflect.ClassTag
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rpc.RpcEnv
-import scala.reflect.ClassTag
-
/**
* Allows Master to persist any state that is necessary in order to recover from a failure.
* The following semantics are required:
@@ -40,12 +40,12 @@ abstract class PersistenceEngine {
* Defines how the object is serialized and persisted. Implementation will
* depend on the store used.
*/
- def persist(name: String, obj: Object)
+ def persist(name: String, obj: Object): Unit
/**
* Defines how the object referred by its name is removed from the store.
*/
- def unpersist(name: String)
+ def unpersist(name: String): Unit
/**
* Gives all objects, matching a prefix. This defines how objects are
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
index c4c3283fb73f..ffdd635be4f5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
@@ -17,8 +17,9 @@
package org.apache.spark.deploy.master
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
import org.apache.spark.serializer.Serializer
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index f75196660520..4e20c10fd142 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -29,8 +29,7 @@ private[spark] class WorkerInfo(
val cores: Int,
val memory: Int,
val endpoint: RpcEndpointRef,
- val webUiPort: Int,
- val publicAddress: String)
+ val webUiAddress: String)
extends Serializable {
Utils.checkHost(host, "Expected hostname")
@@ -98,10 +97,6 @@ private[spark] class WorkerInfo(
coresUsed -= driver.desc.cores
}
- def webUiAddress : String = {
- "http://" + this.publicAddress + ":" + this.webUiPort
- }
-
def setState(state: WorkerState.Value): Unit = {
this.state = state
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index d317206a614f..1e8dabfbe6b0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -17,10 +17,12 @@
package org.apache.spark.deploy.master
-import org.apache.spark.{Logging, SparkConf}
import org.apache.curator.framework.CuratorFramework
-import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch}
+import org.apache.curator.framework.recipes.leader.{LeaderLatch, LeaderLatchListener}
+
+import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkCuratorUtil
+import org.apache.spark.internal.Logging
private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderElectable,
conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 540e802420ce..af850e4871e5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -25,8 +25,9 @@ import scala.reflect.ClassTag
import org.apache.curator.framework.CuratorFramework
import org.apache.zookeeper.CreateMode
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkCuratorUtil
+import org.apache.spark.internal.Logging
import org.apache.spark.serializer.Serializer
@@ -50,7 +51,7 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer
override def read[T: ClassTag](prefix: String): Seq[T] = {
zk.getChildren.forPath(WORKING_DIR).asScala
- .filter(_.startsWith(prefix)).map(deserializeFromFile[T]).flatten
+ .filter(_.startsWith(prefix)).flatMap(deserializeFromFile[T])
}
override def close() {
@@ -69,11 +70,10 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer
try {
Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData)))
} catch {
- case e: Exception => {
+ case e: Exception =>
logWarning("Exception while reading persisted file, deleting", e)
zk.delete().forPath(WORKING_DIR + "/" + filename)
None
- }
}
}
}
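
The read() change is behavior-preserving: deserializeFromFile returns an Option, and flatMap(f) over a collection of Option results is equivalent to map(f).flatten, silently dropping the Nones (here, entries whose files could not be deserialized and were deleted). A sketch with deserialization stubbed out (deserialize below is a placeholder, not the ZooKeeper code):

    // Placeholder for deserializeFromFile[T]: unreadable entries come back as None.
    def deserialize(name: String): Option[String] =
      if (name.endsWith(".bad")) None else Some(s"payload-of-$name")

    val children = Seq("app_1", "worker_2.bad", "driver_3")

    val oldStyle = children.map(deserialize).flatten
    val newStyle = children.flatMap(deserialize)
    assert(oldStyle == newStyle)   // Seq("payload-of-app_1", "payload-of-driver_3")
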
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index f405aa2bdc8b..94ff81c1a68e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -21,10 +21,10 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import org.apache.spark.deploy.ExecutorState
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
+import org.apache.spark.deploy.ExecutorState
import org.apache.spark.deploy.master.ExecutorDesc
-import org.apache.spark.ui.{UIUtils, WebUIPage}
+import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage}
import org.apache.spark.util.Utils
private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") {
@@ -33,11 +33,11 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
- val appId = request.getParameter("appId")
- val state = master.askWithRetry[MasterStateResponse](RequestMasterState)
- val app = state.activeApps.find(_.id == appId).getOrElse({
- state.completedApps.find(_.id == appId).getOrElse(null)
- })
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val appId = UIUtils.stripXSS(request.getParameter("appId"))
+ val state = master.askSync[MasterStateResponse](RequestMasterState)
+ val app = state.activeApps.find(_.id == appId)
+ .getOrElse(state.completedApps.find(_.id == appId).orNull)
if (app == null) {
val msg = No running application with ID {appId}
return UIUtils.basicSparkPage(msg, "Not Found")
@@ -70,13 +70,30 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
}
}
+
+
+ Executor Limit:
+ {
+ if (app.executorLimit == Int.MaxValue) "Unlimited" else app.executorLimit
+ }
+ ({app.executors.size} granted)
+
+
Executor Memory:
{Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
- Submit Date: {app.submitDate}
+ Submit Date: {UIUtils.formatDate(app.submitDate)}
State: {app.state}
- Application Detail UI
+ {
+ if (!app.isFinished) {
+
+ Application Detail UI
+
+ }
+ }
@@ -97,19 +114,21 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
}
private def executorRow(executor: ExecutorDesc): Seq[Node] = {
+ val workerUrlRef = UIUtils.makeHref(parent.master.reverseProxy,
+ executor.worker.id, executor.worker.webUiAddress)
{executor.id}
- {executor.worker.id}
+ {executor.worker.id}
{executor.cores}
{executor.memory}
{executor.state}
stdout
+ .format(workerUrlRef, executor.application.id, executor.id)}>stdout
stderr
+ .format(workerUrlRef, executor.application.id, executor.id)}>stderr
}
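
ApplicationPage now passes request parameters through UIUtils.stripXSS before using them. That helper's implementation is not part of this patch; purely as a hypothetical illustration of the idea (sanitize the raw parameter before it is echoed back into HTML, and keep null meaning "absent"), it could look roughly like this:

    // Hypothetical sanitizer -- NOT Spark's UIUtils.stripXSS, whose exact rules are not shown here.
    // Removes characters commonly used to inject markup or script via a reflected parameter.
    def stripSuspiciousChars(requestParameter: String): String =
      if (requestParameter == null) null
      else requestParameter.replaceAll("[<>\"'%;()&+]", "")

    val appId = stripSuspiciousChars("""app-2016<script>alert('x')</script>""")
    // appId == "app-2016scriptalertx/script" -- the markup-significant characters are gone
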
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala
deleted file mode 100644
index e021f1eef794..000000000000
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.deploy.master.ui
-
-import java.net.URLDecoder
-import javax.servlet.http.HttpServletRequest
-
-import scala.xml.Node
-
-import org.apache.spark.ui.{UIUtils, WebUIPage}
-
-private[ui] class HistoryNotFoundPage(parent: MasterWebUI)
- extends WebUIPage("history/not-found") {
-
- /**
- * Render a page that conveys failure in loading application history.
- *
- * This accepts 3 HTTP parameters:
- * msg = message to display to the user
- * title = title of the page
- * exception = detailed description of the exception in loading application history (if any)
- *
- * Parameters "msg" and "exception" are assumed to be UTF-8 encoded.
- */
- def render(request: HttpServletRequest): Seq[Node] = {
- val titleParam = request.getParameter("title")
- val msgParam = request.getParameter("msg")
- val exceptionParam = request.getParameter("exception")
-
- // If no parameters are specified, assume the user did not enable event logging
- val defaultTitle = "Event logging is not enabled"
- val defaultContent =
-
-
- No event logs were found for this application! To
- enable event logging,
- set spark.eventLog.enabled to true and
- spark.eventLog.dir to the directory to which your
- event logs are written.
-
-
-
- val title = Option(titleParam).getOrElse(defaultTitle)
- val content = Option(msgParam)
- .map { msg => URLDecoder.decode(msg, "UTF-8") }
- .map { msg =>
-
- {msg}
- ++
- Option(exceptionParam)
- .map { e => URLDecoder.decode(e, "UTF-8") }
- .map { e => {e} }
- .getOrElse(Seq.empty)
- }.getOrElse(defaultContent)
-
- UIUtils.basicSparkPage(content, title)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index ee539dd1f511..ce71300e9097 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -23,17 +23,17 @@ import scala.xml.Node
import org.json4s.JValue
+import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, MasterStateResponse, RequestKillDriver, RequestMasterState}
import org.apache.spark.deploy.JsonProtocol
-import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master._
-import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.ui.{UIUtils, WebUIPage}
import org.apache.spark.util.Utils
private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private val master = parent.masterEndpointRef
def getMasterState: MasterStateResponse = {
- master.askWithRetry[MasterStateResponse](RequestMasterState)
+ master.askSync[MasterStateResponse](RequestMasterState)
}
override def renderJson(request: HttpServletRequest): JValue = {
@@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
if (parent.killEnabled &&
parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
- val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
- val id = Option(request.getParameter("id"))
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val killFlag =
+ Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean
+ val id = Option(UIUtils.stripXSS(request.getParameter("id")))
if (id.isDefined && killFlag) {
action(id.get)
}
@@ -76,7 +78,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE)
val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
- val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time",
+ val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Executor", "Submitted Time",
"User", "State", "Duration")
val activeApps = state.activeApps.sortBy(_.startTime).reverse
val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps)
@@ -107,18 +109,18 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}.getOrElse { Seq.empty }
}
- Alive Workers: {aliveWorkers.size}
+ Alive Workers: {aliveWorkers.length}
Cores in use: {aliveWorkers.map(_.cores).sum} Total,
{aliveWorkers.map(_.coresUsed).sum} Used
Memory in use:
{Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total,
{Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used
Applications:
- {state.activeApps.size} Running,
- {state.completedApps.size} Completed
+ {state.activeApps.length} Running,
+ {state.completedApps.length} Completed
Drivers:
- {state.activeDrivers.size} Running,
- {state.completedDrivers.size} Completed
+ {state.activeDrivers.length} Running,
+ {state.completedDrivers.length} Completed
Status: {state.status}
@@ -133,7 +135,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
- Running Applications
+ Running Applications
{activeAppsTable}
@@ -152,7 +154,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
- Completed Applications
+ Completed Applications
{completedAppsTable}
@@ -176,7 +178,15 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private def workerRow(worker: WorkerInfo): Seq[Node] = {
- {worker.id}
+ {
+ if (worker.isAlive()) {
+
+ {worker.id}
+
+ } else {
+ worker.id
+ }
+ }
{worker.host}:{worker.port}
{worker.state}
@@ -206,7 +216,14 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
{killLink}
- {app.desc.name}
+ {
+ if (app.isFinished) {
+ app.desc.name
+ } else {
+ {app.desc.name}
+ }
+ }
{app.coresGranted}
@@ -237,8 +254,15 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
{driver.id} {killLink}
- {driver.submitDate}
- {driver.worker.map(w => {w.id.toString}).getOrElse("None")}
+ {UIUtils.formatDate(driver.submitDate)}
+ {driver.worker.map(w =>
+ if (w.isAlive()) {
+
+ {w.id.toString}
+
+ } else {
+ w.id.toString
+ }).getOrElse("None")}
{driver.state}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 6174fc11f83d..8cfd0f682932 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -17,10 +17,12 @@
package org.apache.spark.deploy.master.ui
-import org.apache.spark.Logging
+import scala.collection.mutable.HashMap
+
+import org.eclipse.jetty.servlet.ServletContextHandler
+
import org.apache.spark.deploy.master.Master
-import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource, ApplicationInfo,
- UIRoot}
+import org.apache.spark.internal.Logging
import org.apache.spark.ui.{SparkUI, WebUI}
import org.apache.spark.ui.JettyUtils._
@@ -28,14 +30,15 @@ import org.apache.spark.ui.JettyUtils._
* Web UI server for the standalone master.
*/
private[master]
-class MasterWebUI(val master: Master, requestedPort: Int)
- extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging
- with UIRoot {
+class MasterWebUI(
+ val master: Master,
+ requestedPort: Int)
+ extends WebUI(master.securityMgr, master.securityMgr.getSSLOptions("standalone"),
+ requestedPort, master.conf, name = "MasterUI") with Logging {
val masterEndpointRef = master.self
val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true)
-
- val masterPage = new MasterPage(this)
+ private val proxyHandlers = new HashMap[String, ServletContextHandler]
initialize()
@@ -43,43 +46,23 @@ class MasterWebUI(val master: Master, requestedPort: Int)
def initialize() {
val masterPage = new MasterPage(this)
attachPage(new ApplicationPage(this))
- attachPage(new HistoryNotFoundPage(this))
attachPage(masterPage)
attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static"))
- attachHandler(ApiRootResource.getServletHandler(this))
attachHandler(createRedirectHandler(
"/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST")))
attachHandler(createRedirectHandler(
"/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethods = Set("POST")))
}
- /** Attach a reconstructed UI to this Master UI. Only valid after bind(). */
- def attachSparkUI(ui: SparkUI) {
- assert(serverInfo.isDefined, "Master UI must be bound to a server before attaching SparkUIs")
- ui.getHandlers.foreach(attachHandler)
- }
-
- /** Detach a reconstructed UI from this Master UI. Only valid after bind(). */
- def detachSparkUI(ui: SparkUI) {
- assert(serverInfo.isDefined, "Master UI must be bound to a server before detaching SparkUIs")
- ui.getHandlers.foreach(detachHandler)
- }
-
- def getApplicationInfoList: Iterator[ApplicationInfo] = {
- val state = masterPage.getMasterState
- val activeApps = state.activeApps.sortBy(_.startTime).reverse
- val completedApps = state.completedApps.sortBy(_.endTime).reverse
- activeApps.iterator.map { ApplicationsListResource.convertApplicationInfo(_, false) } ++
- completedApps.iterator.map { ApplicationsListResource.convertApplicationInfo(_, true) }
+ def addProxyTargets(id: String, target: String): Unit = {
+ var endTarget = target.stripSuffix("/")
+ val handler = createProxyHandler("/proxy/" + id, endTarget)
+ attachHandler(handler)
+ proxyHandlers(id) = handler
}
- def getSparkUI(appId: String): Option[SparkUI] = {
- val state = masterPage.getMasterState
- val activeApps = state.activeApps.sortBy(_.startTime).reverse
- val completedApps = state.completedApps.sortBy(_.endTime).reverse
- (activeApps ++ completedApps).find { _.id == appId }.flatMap {
- master.rebuildSparkUI
- }
+ def removeProxyTargets(id: String): Unit = {
+ proxyHandlers.remove(id).foreach(detachHandler)
}
}
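
addProxyTargets/removeProxyTargets keep one proxy handler per worker or application id, attached when the worker/app registers with the master and detached when it goes away. A minimal sketch of just that bookkeeping, with the Jetty pieces (createProxyHandler, attachHandler, detachHandler) replaced by stubs so it runs standalone; the worker id and URL are example values:

    import scala.collection.mutable.HashMap

    object ReverseProxyRegistry {
      // Stand-ins for the Jetty plumbing: a "handler" here is just a descriptive string.
      private def createProxyHandler(prefix: String, target: String): String = s"$prefix -> $target"
      private def attachHandler(handler: String): Unit = println(s"attached: $handler")
      private def detachHandler(handler: String): Unit = println(s"detached: $handler")

      private val proxyHandlers = new HashMap[String, String]

      def addProxyTarget(id: String, target: String): Unit = {
        val endTarget = target.stripSuffix("/")
        val handler = createProxyHandler("/proxy/" + id, endTarget)
        attachHandler(handler)
        proxyHandlers(id) = handler
      }

      def removeProxyTarget(id: String): Unit =
        proxyHandlers.remove(id).foreach(detachHandler)
    }

    ReverseProxyRegistry.addProxyTarget("worker-20160101000000-10.0.0.5-38211", "http://10.0.0.5:8081/")
    ReverseProxyRegistry.removeProxyTarget("worker-20160101000000-10.0.0.5-38211")
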
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
deleted file mode 100644
index 5accaf78d0a5..000000000000
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.deploy.mesos
-
-import org.apache.spark.SparkConf
-import org.apache.spark.util.{IntParam, Utils}
-
-
-private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) {
- var host = Utils.localHostName()
- var port = 7077
- var name = "Spark Cluster"
- var webUiPort = 8081
- var masterUrl: String = _
- var zookeeperUrl: Option[String] = None
- var propertiesFile: String = _
-
- parse(args.toList)
-
- propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile)
-
- private def parse(args: List[String]): Unit = args match {
- case ("--host" | "-h") :: value :: tail =>
- Utils.checkHost(value, "Please use hostname " + value)
- host = value
- parse(tail)
-
- case ("--port" | "-p") :: IntParam(value) :: tail =>
- port = value
- parse(tail)
-
- case ("--webui-port" | "-p") :: IntParam(value) :: tail =>
- webUiPort = value
- parse(tail)
-
- case ("--zk" | "-z") :: value :: tail =>
- zookeeperUrl = Some(value)
- parse(tail)
-
- case ("--master" | "-m") :: value :: tail =>
- if (!value.startsWith("mesos://")) {
- // scalastyle:off println
- System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)")
- // scalastyle:on println
- System.exit(1)
- }
- masterUrl = value.stripPrefix("mesos://")
- parse(tail)
-
- case ("--name") :: value :: tail =>
- name = value
- parse(tail)
-
- case ("--properties-file") :: value :: tail =>
- propertiesFile = value
- parse(tail)
-
- case ("--help") :: tail =>
- printUsageAndExit(0)
-
- case Nil => {
- if (masterUrl == null) {
- // scalastyle:off println
- System.err.println("--master is required")
- // scalastyle:on println
- printUsageAndExit(1)
- }
- }
-
- case _ =>
- printUsageAndExit(1)
- }
-
- private def printUsageAndExit(exitCode: Int): Unit = {
- // scalastyle:off println
- System.err.println(
- "Usage: MesosClusterDispatcher [options]\n" +
- "\n" +
- "Options:\n" +
- " -h HOST, --host HOST Hostname to listen on\n" +
- " -p PORT, --port PORT Port to listen on (default: 7077)\n" +
- " --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" +
- " --name NAME Framework name to show in Mesos UI\n" +
- " -m --master MASTER URI for connecting to Mesos master\n" +
- " -z --zk ZOOKEEPER Comma delimited URLs for connecting to \n" +
- " Zookeeper for persistence\n" +
- " --properties-file FILE Path to a custom Spark properties file.\n" +
- " Default is conf/spark-defaults.conf.")
- // scalastyle:on println
- System.exit(exitCode)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
deleted file mode 100644
index 12337a940a41..000000000000
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.deploy.mesos
-
-import java.net.SocketAddress
-
-import scala.collection.mutable
-
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
-import org.apache.spark.deploy.ExternalShuffleService
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
-import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
-import org.apache.spark.network.shuffle.protocol.BlockTransferMessage
-import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver
-import org.apache.spark.network.util.TransportConf
-
-/**
- * An RPC endpoint that receives registration requests from Spark drivers running on Mesos.
- * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]].
- */
-private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf)
- extends ExternalShuffleBlockHandler(transportConf, null) with Logging {
-
- // Stores a map of driver socket addresses to app ids
- private val connectedApps = new mutable.HashMap[SocketAddress, String]
-
- protected override def handleMessage(
- message: BlockTransferMessage,
- client: TransportClient,
- callback: RpcResponseCallback): Unit = {
- message match {
- case RegisterDriverParam(appId) =>
- val address = client.getSocketAddress
- logDebug(s"Received registration request from app $appId (remote address $address).")
- if (connectedApps.contains(address)) {
- val existingAppId = connectedApps(address)
- if (!existingAppId.equals(appId)) {
- logError(s"A new app '$appId' has connected to existing address $address, " +
- s"removing previously registered app '$existingAppId'.")
- applicationRemoved(existingAppId, true)
- }
- }
- connectedApps(address) = appId
- callback.onSuccess(new Array[Byte](0))
- case _ => super.handleMessage(message, client, callback)
- }
- }
-
- /**
- * On connection termination, clean up shuffle files written by the associated application.
- */
- override def connectionTerminated(client: TransportClient): Unit = {
- val address = client.getSocketAddress
- if (connectedApps.contains(address)) {
- val appId = connectedApps(address)
- logInfo(s"Application $appId disconnected (address was $address).")
- applicationRemoved(appId, true /* cleanupLocalDirs */)
- connectedApps.remove(address)
- } else {
- logWarning(s"Unknown $address disconnected.")
- }
- }
-
- /** An extractor object for matching [[RegisterDriver]] message. */
- private object RegisterDriverParam {
- def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId)
- }
-}
-
-/**
- * A wrapper of [[ExternalShuffleService]] that provides an additional endpoint for drivers
- * to associate with. This allows the shuffle service to detect when a driver is terminated
- * and can clean up the associated shuffle files.
- */
-private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManager: SecurityManager)
- extends ExternalShuffleService(conf, securityManager) {
-
- protected override def newShuffleBlockHandler(
- conf: TransportConf): ExternalShuffleBlockHandler = {
- new MesosExternalShuffleBlockHandler(conf)
- }
-}
-
-private[spark] object MesosExternalShuffleService extends Logging {
-
- def main(args: Array[String]): Unit = {
- ExternalShuffleService.main(args,
- (conf: SparkConf, sm: SecurityManager) => new MesosExternalShuffleService(conf, sm))
- }
-}
-
-
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala
index 957a928bc402..21cb94142b15 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala
@@ -19,15 +19,20 @@ package org.apache.spark.deploy.rest
import java.io.{DataOutputStream, FileNotFoundException}
import java.net.{ConnectException, HttpURLConnection, SocketException, URL}
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.TimeoutException
import javax.servlet.http.HttpServletResponse
import scala.collection.mutable
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration._
import scala.io.Source
+import scala.util.control.NonFatal
import com.fasterxml.jackson.core.JsonProcessingException
-import com.google.common.base.Charsets
-import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
+import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf, SparkException}
+import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
/**
@@ -208,7 +213,7 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
try {
val out = new DataOutputStream(conn.getOutputStream)
Utils.tryWithSafeFinally {
- out.write(json.getBytes(Charsets.UTF_8))
+ out.write(json.getBytes(StandardCharsets.UTF_8))
} {
out.close()
}
@@ -225,7 +230,8 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
* Exposed for testing.
*/
private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
- try {
+ import scala.concurrent.ExecutionContext.Implicits.global
+ val responseFuture = Future {
val dataStream =
if (connection.getResponseCode == HttpServletResponse.SC_OK) {
connection.getInputStream
@@ -251,11 +257,19 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
throw new SubmitRestProtocolException(
s"Message received from server was not a response:\n${unexpected.toJson}")
}
- } catch {
+ }
+
+ // scalastyle:off awaitresult
+ try { Await.result(responseFuture, 10.seconds) } catch {
+ // scalastyle:on awaitresult
case unreachable @ (_: FileNotFoundException | _: SocketException) =>
throw new SubmitRestConnectionException("Unable to connect to server", unreachable)
case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) =>
throw new SubmitRestProtocolException("Malformed response received from server", malformed)
+ case timeout: TimeoutException =>
+ throw new SubmitRestConnectionException("No response from server", timeout)
+ case NonFatal(t) =>
+ throw new SparkException("Exception while waiting for response", t)
}
}
@@ -374,7 +388,7 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
logWarning(s"Unable to connect to server ${masterUrl}.")
lostMasters += masterUrl
}
- lostMasters.size >= masters.size
+ lostMasters.size >= masters.length
}
}
@@ -404,13 +418,13 @@ private[spark] object RestSubmissionClient {
}
def main(args: Array[String]): Unit = {
- if (args.size < 2) {
+ if (args.length < 2) {
sys.error("Usage: RestSubmissionClient [app resource] [main class] [app args*]")
sys.exit(1)
}
val appResource = args(0)
val mainClass = args(1)
- val appArgs = args.slice(2, args.size)
+ val appArgs = args.slice(2, args.length)
val conf = new SparkConf
val env = filterSystemEnvironment(sys.env)
run(appResource, mainClass, appArgs, conf, env)
@@ -420,8 +434,10 @@ private[spark] object RestSubmissionClient {
* Filter non-spark environment variables from any environment.
*/
private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = {
- env.filter { case (k, _) =>
- (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED") || k.startsWith("MESOS_")
+ env.filterKeys { k =>
+ // SPARK_HOME is filtered out because it is usually wrong on the remote machine (SPARK-12345)
+ (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED" && k != "SPARK_HOME") ||
+ k.startsWith("MESOS_")
}
}
}
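
readResponse now runs the blocking read inside a Future and bounds it with Await.result, so a server that accepts the connection but never replies surfaces as a TimeoutException instead of hanging the client. A self-contained sketch of that pattern; the blocking call is simulated with Thread.sleep, and 10.seconds matches the timeout used above:

    import java.util.concurrent.TimeoutException

    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration._
    import scala.util.control.NonFatal

    // Stand-in for reading the HTTP response: a server that never answers in time.
    def readSlowResponse(): String = { Thread.sleep(30000); "ok" }

    val responseFuture = Future { readSlowResponse() }

    val outcome =
      try {
        Await.result(responseFuture, 10.seconds)
      } catch {
        case _: TimeoutException => "no response from server within 10 seconds"
        case NonFatal(t)         => s"exception while waiting for response: ${t.getMessage}"
      }

    println(outcome)   // prints the timeout message after roughly 10 seconds
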
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
index 2e78d03e5c0c..b30c980e95a9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
@@ -17,18 +17,19 @@
package org.apache.spark.deploy.rest
-import java.net.InetSocketAddress
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import scala.io.Source
+
import com.fasterxml.jackson.core.JsonProcessingException
-import org.eclipse.jetty.server.Server
-import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
-import org.eclipse.jetty.util.thread.QueuedThreadPool
+import org.eclipse.jetty.server.{HttpConnectionFactory, Server, ServerConnector}
+import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
+import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler}
import org.json4s._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
+import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
+import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
/**
@@ -78,18 +79,32 @@ private[spark] abstract class RestSubmissionServer(
* Return a 2-tuple of the started server and the bound port.
*/
private def doStart(startPort: Int): (Server, Int) = {
- val server = new Server(new InetSocketAddress(host, startPort))
val threadPool = new QueuedThreadPool
threadPool.setDaemon(true)
- server.setThreadPool(threadPool)
+ val server = new Server(threadPool)
+
+ val connector = new ServerConnector(
+ server,
+ null,
+ // Call this full constructor to set this, which forces daemon threads:
+ new ScheduledExecutorScheduler("RestSubmissionServer-JettyScheduler", true),
+ null,
+ -1,
+ -1,
+ new HttpConnectionFactory())
+ connector.setHost(host)
+ connector.setPort(startPort)
+ server.addConnector(connector)
+
val mainHandler = new ServletContextHandler
+ mainHandler.setServer(server)
mainHandler.setContextPath("/")
contextToServlet.foreach { case (prefix, servlet) =>
mainHandler.addServlet(new ServletHolder(servlet), prefix)
}
server.setHandler(mainHandler)
server.start()
- val boundPort = server.getConnectors()(0).getLocalPort
+ val boundPort = connector.getLocalPort
(server, boundPort)
}
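
doStart now hands the daemon QueuedThreadPool to the Server constructor and wires up a ServerConnector explicitly, instead of the previous Server(new InetSocketAddress(...)) plus setThreadPool combination. Condensed into a standalone sketch that binds an ephemeral port and serves one trivial servlet; it assumes Jetty 9.x and the javax.servlet API on the classpath, and the scheduler name, context path, and servlet are placeholders:

    import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

    import org.eclipse.jetty.server.{HttpConnectionFactory, Server, ServerConnector}
    import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
    import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler}

    val threadPool = new QueuedThreadPool
    threadPool.setDaemon(true)                   // server threads must not keep the JVM alive
    val server = new Server(threadPool)

    val connector = new ServerConnector(
      server,
      null,
      // The full constructor is used so the scheduler can also be forced to daemon threads.
      new ScheduledExecutorScheduler("RestServer-JettyScheduler", true),
      null,
      -1,
      -1,
      new HttpConnectionFactory())
    connector.setHost("localhost")
    connector.setPort(0)                         // 0 means pick any free port
    server.addConnector(connector)

    val mainHandler = new ServletContextHandler
    mainHandler.setServer(server)
    mainHandler.setContextPath("/")
    mainHandler.addServlet(new ServletHolder(new HttpServlet {
      override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit =
        resp.getWriter.println("""{"serverSparkVersion": "sketch"}""")
    }), "/v1/submissions/*")
    server.setHandler(mainHandler)

    server.start()
    println(s"bound to port ${connector.getLocalPort}")
    server.stop()
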
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
index d5b9bcab1423..56620064c57f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -20,11 +20,11 @@ package org.apache.spark.deploy.rest
import java.io.File
import javax.servlet.http.HttpServletResponse
-import org.apache.spark.deploy.ClientArguments._
+import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription}
+import org.apache.spark.deploy.ClientArguments._
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
-import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
/**
* A server that responds to requests submitted by the [[RestSubmissionClient]].
@@ -71,7 +71,7 @@ private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef,
extends KillRequestServlet {
protected def handleKill(submissionId: String): KillSubmissionResponse = {
- val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse](
+ val response = masterEndpoint.askSync[DeployMessages.KillDriverResponse](
DeployMessages.RequestKillDriver(submissionId))
val k = new KillSubmissionResponse
k.serverSparkVersion = sparkVersion
@@ -89,7 +89,7 @@ private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRe
extends StatusRequestServlet {
protected def handleStatus(submissionId: String): SubmissionStatusResponse = {
- val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse](
+ val response = masterEndpoint.askSync[DeployMessages.DriverStatusResponse](
DeployMessages.RequestDriverStatus(submissionId))
val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) }
val d = new SubmissionStatusResponse
@@ -174,7 +174,7 @@ private[rest] class StandaloneSubmitRequestServlet(
requestMessage match {
case submitRequest: CreateSubmissionRequest =>
val driverDescription = buildDriverDescription(submitRequest)
- val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse](
+ val response = masterEndpoint.askSync[DeployMessages.SubmitDriverResponse](
DeployMessages.RequestSubmitDriver(driverDescription))
val submitResponse = new CreateSubmissionResponse
submitResponse.serverSparkVersion = sparkVersion
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala
index 0d50a768942e..86ddf954ca12 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala
@@ -46,6 +46,8 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest {
super.doValidate()
assert(sparkProperties != null, "No Spark properties set!")
assertFieldIsSet(appResource, "appResource")
+ assertFieldIsSet(appArgs, "appArgs")
+ assertFieldIsSet(environmentVariables, "environmentVariables")
assertPropertyIsSet("spark.app.name")
assertPropertyIsBoolean("spark.driver.supervise")
assertPropertyIsNumeric("spark.driver.cores")
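Because doValidate() now also asserts appArgs and environmentVariables, clients must send both fields, even when they are empty. Below is a hedged client-side sketch of a submission request, assuming the usual v1/submissions/create endpoint of the REST submission gateway; the master URL, jar path and main class are placeholders only.

import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets

object RestSubmitSketch {
  val body: String =
    """{
      |  "action": "CreateSubmissionRequest",
      |  "clientSparkVersion": "2.1.0",
      |  "appResource": "hdfs://example/app.jar",
      |  "mainClass": "com.example.Main",
      |  "appArgs": [],
      |  "environmentVariables": {},
      |  "sparkProperties": {
      |    "spark.app.name": "example",
      |    "spark.master": "spark://master:6066",
      |    "spark.driver.supervise": "false"
      |  }
      |}""".stripMargin

  def main(args: Array[String]): Unit = {
    val conn = new URL("http://master:6066/v1/submissions/create")
      .openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("POST")
    conn.setRequestProperty("Content-Type", "application/json; charset=UTF-8")
    conn.setDoOutput(true)
    conn.getOutputStream.write(body.getBytes(StandardCharsets.UTF_8))
    println(s"HTTP ${conn.getResponseCode}")   // 400 if a required field is missing
  }
}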
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index ce02ee203a4b..cba4aaffe2ca 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -22,14 +22,14 @@ import java.io.{File, FileOutputStream, InputStream, IOException}
import scala.collection.JavaConverters._
import scala.collection.Map
-import org.apache.spark.Logging
import org.apache.spark.SecurityManager
import org.apache.spark.deploy.Command
+import org.apache.spark.internal.Logging
import org.apache.spark.launcher.WorkerCommandBuilder
import org.apache.spark.util.Utils
/**
- ** Utilities for running commands with the spark classpath.
+ * Utilities for running commands with the spark classpath.
*/
private[deploy]
object CommandUtils extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 89159ff5e2b3..58a181128eb4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -18,20 +18,21 @@
package org.apache.spark.deploy.worker
import java.io._
+import java.net.URI
+import java.nio.charset.StandardCharsets
import scala.collection.JavaConverters._
-import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
-import org.apache.hadoop.fs.Path
-import org.apache.spark.{Logging, SparkConf, SecurityManager}
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil}
import org.apache.spark.deploy.DeployMessages.DriverStateChanged
import org.apache.spark.deploy.master.DriverState
import org.apache.spark.deploy.master.DriverState.DriverState
+import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.util.{Utils, Clock, SystemClock}
+import org.apache.spark.util.{Clock, ShutdownHookManager, SystemClock, Utils}
/**
* Manages the execution of one driver, including automatically restarting the driver on failure.
@@ -52,9 +53,12 @@ private[deploy] class DriverRunner(
@volatile private var killed = false
// Populated once finished
- private[worker] var finalState: Option[DriverState] = None
- private[worker] var finalException: Option[Exception] = None
- private var finalExitCode: Option[Int] = None
+ @volatile private[worker] var finalState: Option[DriverState] = None
+ @volatile private[worker] var finalException: Option[Exception] = None
+
+ // Timeout to wait for when trying to terminate a driver.
+ private val DRIVER_TERMINATE_TIMEOUT_MS =
+ conf.getTimeAsMs("spark.worker.driverTerminateTimeout", "10s")
// Decoupled for testing
def setClock(_clock: Clock): Unit = {
@@ -67,56 +71,63 @@ private[deploy] class DriverRunner(
private var clock: Clock = new SystemClock()
private var sleeper = new Sleeper {
- def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed})
+ def sleep(seconds: Int): Unit = (0 until seconds).takeWhile { _ =>
+ Thread.sleep(1000)
+ !killed
+ }
}
/** Starts a thread to run and manage the driver. */
private[worker] def start() = {
new Thread("DriverRunner for " + driverId) {
override def run() {
+ var shutdownHook: AnyRef = null
try {
- val driverDir = createWorkingDirectory()
- val localJarFilename = downloadUserJar(driverDir)
-
- def substituteVariables(argument: String): String = argument match {
- case "{{WORKER_URL}}" => workerUrl
- case "{{USER_JAR}}" => localJarFilename
- case other => other
+ shutdownHook = ShutdownHookManager.addShutdownHook { () =>
+ logInfo(s"Worker shutting down, killing driver $driverId")
+ kill()
}
- // TODO: If we add ability to submit multiple jars they should also be added here
- val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager,
- driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables)
- launchDriver(builder, driverDir, driverDesc.supervise)
- }
- catch {
- case e: Exception => finalException = Some(e)
- }
+ // prepare driver jars and run driver
+ val exitCode = prepareAndRunDriver()
- val state =
- if (killed) {
- DriverState.KILLED
- } else if (finalException.isDefined) {
- DriverState.ERROR
+ // set final state depending on if forcibly killed and process exit code
+ finalState = if (exitCode == 0) {
+ Some(DriverState.FINISHED)
+ } else if (killed) {
+ Some(DriverState.KILLED)
} else {
- finalExitCode match {
- case Some(0) => DriverState.FINISHED
- case _ => DriverState.FAILED
- }
+ Some(DriverState.FAILED)
}
+ } catch {
+ case e: Exception =>
+ kill()
+ finalState = Some(DriverState.ERROR)
+ finalException = Some(e)
+ } finally {
+ if (shutdownHook != null) {
+ ShutdownHookManager.removeShutdownHook(shutdownHook)
+ }
+ }
- finalState = Some(state)
-
- worker.send(DriverStateChanged(driverId, state, finalException))
+ // notify worker of final driver state, possible exception
+ worker.send(DriverStateChanged(driverId, finalState.get, finalException))
}
}.start()
}
/** Terminate this driver (or prevent it from ever starting if not yet started) */
- private[worker] def kill() {
+ private[worker] def kill(): Unit = {
+ logInfo("Killing driver process!")
+ killed = true
synchronized {
- process.foreach(p => p.destroy())
- killed = true
+ process.foreach { p =>
+ val exitCode = Utils.terminateProcess(p, DRIVER_TERMINATE_TIMEOUT_MS)
+ if (exitCode.isEmpty) {
+ logWarning("Failed to terminate driver process: " + p +
+ ". This process will likely be orphaned.")
+ }
+ }
}
}
@@ -137,34 +148,44 @@ private[deploy] class DriverRunner(
* Will throw an exception if there are errors downloading the jar.
*/
private def downloadUserJar(driverDir: File): String = {
- val jarPath = new Path(driverDesc.jarUrl)
-
- val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
- val destPath = new File(driverDir.getAbsolutePath, jarPath.getName)
- val jarFileName = jarPath.getName
+ val jarFileName = new URI(driverDesc.jarUrl).getPath.split("/").last
val localJarFile = new File(driverDir, jarFileName)
- val localJarFilename = localJarFile.getAbsolutePath
-
if (!localJarFile.exists()) { // May already exist if running multiple workers on one node
- logInfo(s"Copying user jar $jarPath to $destPath")
+ logInfo(s"Copying user jar ${driverDesc.jarUrl} to $localJarFile")
Utils.fetchFile(
driverDesc.jarUrl,
driverDir,
conf,
securityManager,
- hadoopConf,
+ SparkHadoopUtil.get.newConfiguration(conf),
System.currentTimeMillis(),
useCache = false)
+ if (!localJarFile.exists()) { // Verify copy succeeded
+ throw new IOException(
+ s"Can not find expected jar $jarFileName which should have been loaded in $driverDir")
+ }
}
+ localJarFile.getAbsolutePath
+ }
+
+ private[worker] def prepareAndRunDriver(): Int = {
+ val driverDir = createWorkingDirectory()
+ val localJarFilename = downloadUserJar(driverDir)
- if (!localJarFile.exists()) { // Verify copy succeeded
- throw new Exception(s"Did not see expected jar $jarFileName in $driverDir")
+ def substituteVariables(argument: String): String = argument match {
+ case "{{WORKER_URL}}" => workerUrl
+ case "{{USER_JAR}}" => localJarFilename
+ case other => other
}
- localJarFilename
+ // TODO: If we add ability to submit multiple jars they should also be added here
+ val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager,
+ driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables)
+
+ runDriver(builder, driverDir, driverDesc.supervise)
}
- private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) {
+ private def runDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean): Int = {
builder.directory(baseDir)
def initialize(process: Process): Unit = {
// Redirect stdout and stderr to files
@@ -174,50 +195,51 @@ private[deploy] class DriverRunner(
val stderr = new File(baseDir, "stderr")
val formattedCommand = builder.command.asScala.mkString("\"", "\" \"", "\"")
val header = "Launch Command: %s\n%s\n\n".format(formattedCommand, "=" * 40)
- Files.append(header, stderr, UTF_8)
+ Files.append(header, stderr, StandardCharsets.UTF_8)
CommandUtils.redirectStream(process.getErrorStream, stderr)
}
runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
}
- def runCommandWithRetry(
- command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Unit = {
+ private[worker] def runCommandWithRetry(
+ command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Int = {
+ var exitCode = -1
// Time to wait between submission retries.
var waitSeconds = 1
// A run of this many seconds resets the exponential back-off.
val successfulRunDuration = 5
-
var keepTrying = !killed
while (keepTrying) {
logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", "\""))
synchronized {
- if (killed) { return }
+ if (killed) { return exitCode }
process = Some(command.start())
initialize(process.get)
}
val processStart = clock.getTimeMillis()
- val exitCode = process.get.waitFor()
- if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) {
- waitSeconds = 1
- }
+ exitCode = process.get.waitFor()
- if (supervise && exitCode != 0 && !killed) {
+ // check if attempting another run
+ keepTrying = supervise && exitCode != 0 && !killed
+ if (keepTrying) {
+ if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) {
+ waitSeconds = 1
+ }
logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.")
sleeper.sleep(waitSeconds)
waitSeconds = waitSeconds * 2 // exponential back-off
}
-
- keepTrying = supervise && exitCode != 0 && !killed
- finalExitCode = Some(exitCode)
}
+
+ exitCode
}
}
private[deploy] trait Sleeper {
- def sleep(seconds: Int)
+ def sleep(seconds: Int): Unit
}
// Needed because ProcessBuilder is a final class and cannot be mocked
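A self-contained sketch (not the patch itself) of the retry policy that runCommandWithRetry now implements: re-launch only while supervising and the last run failed, reset the exponential back-off after a sufficiently long run, and return the last exit code instead of stashing it in a field. The `true` binary and the callbacks are stand-ins.

object RetryLoopSketch {
  def runWithRetry(start: () => Process, supervise: Boolean, killed: () => Boolean): Int = {
    val successfulRunMs = 5000L       // a run longer than this resets the back-off
    var waitSeconds = 1
    var exitCode = -1
    var keepTrying = !killed()
    while (keepTrying) {
      val startedAt = System.currentTimeMillis()
      exitCode = start().waitFor()
      keepTrying = supervise && exitCode != 0 && !killed()
      if (keepTrying) {
        if (System.currentTimeMillis() - startedAt > successfulRunMs) {
          waitSeconds = 1
        }
        Thread.sleep(waitSeconds * 1000L)
        waitSeconds *= 2              // exponential back-off between restarts
      }
    }
    exitCode
  }

  def main(args: Array[String]): Unit = {
    val code = runWithRetry(() => new ProcessBuilder("true").start(), supervise = true, () => false)
    println(s"driver exited with $code")
  }
}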
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 3aef0515cbf6..d4d8521cc820 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -18,16 +18,17 @@
package org.apache.spark.deploy.worker
import java.io._
+import java.nio.charset.StandardCharsets
import scala.collection.JavaConverters._
-import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
-import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.{SecurityManager, SparkConf, Logging}
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.{ShutdownHookManager, Utils}
import org.apache.spark.util.logging.FileAppender
@@ -60,6 +61,9 @@ private[deploy] class ExecutorRunner(
private var stdoutAppender: FileAppender = null
private var stderrAppender: FileAppender = null
+ // Timeout to wait for when trying to terminate an executor.
+ private val EXECUTOR_TERMINATE_TIMEOUT_MS = 10 * 1000
+
// NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might
// make sense to remove this in the future.
private var shutdownHook: AnyRef = null
@@ -71,6 +75,11 @@ private[deploy] class ExecutorRunner(
workerThread.start()
// Shutdown hook that kills actors on shutdown.
shutdownHook = ShutdownHookManager.addShutdownHook { () =>
+ // It's possible that we arrive here before calling `fetchAndRunExecutor`, then `state` will
+ // be `ExecutorState.RUNNING`. In this case, we should set `state` to `FAILED`.
+ if (state == ExecutorState.RUNNING) {
+ state = ExecutorState.FAILED
+ }
killProcess(Some("Worker shutting down")) }
}
@@ -89,10 +98,17 @@ private[deploy] class ExecutorRunner(
if (stderrAppender != null) {
stderrAppender.stop()
}
- process.destroy()
- exitCode = Some(process.waitFor())
+ exitCode = Utils.terminateProcess(process, EXECUTOR_TERMINATE_TIMEOUT_MS)
+ if (exitCode.isEmpty) {
+ logWarning("Failed to terminate process: " + process +
+ ". This process will likely be orphaned.")
+ }
+ }
+ try {
+ worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
+ } catch {
+ case e: IllegalStateException => logWarning(e.getMessage(), e)
}
- worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
}
/** Stop this executor runner, including killing the process it launched */
@@ -140,7 +156,11 @@ private[deploy] class ExecutorRunner(
// Add webUI log urls
val baseUrl =
- s"http://$publicAddress:$webUiPort/logPage/?appId=$appId&executorId=$execId&logType="
+ if (conf.getBoolean("spark.ui.reverseProxy", false)) {
+ s"/proxy/$workerId/logPage/?appId=$appId&executorId=$execId&logType="
+ } else {
+ s"http://$publicAddress:$webUiPort/logPage/?appId=$appId&executorId=$execId&logType="
+ }
builder.environment.put("SPARK_LOG_URL_STDERR", s"${baseUrl}stderr")
builder.environment.put("SPARK_LOG_URL_STDOUT", s"${baseUrl}stdout")
@@ -153,7 +173,7 @@ private[deploy] class ExecutorRunner(
stdoutAppender = FileAppender(process.getInputStream, stdout, conf)
val stderr = new File(executorDir, "stderr")
- Files.write(header, stderr, UTF_8)
+ Files.write(header, stderr, StandardCharsets.UTF_8)
stderrAppender = FileAppender(process.getErrorStream, stderr, conf)
// Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown)
@@ -163,16 +183,14 @@ private[deploy] class ExecutorRunner(
val message = "Command exited with code " + exitCode
worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)))
} catch {
- case interrupted: InterruptedException => {
+ case interrupted: InterruptedException =>
logInfo("Runner thread for executor " + fullId + " interrupted")
state = ExecutorState.KILLED
killProcess(None)
- }
- case e: Exception => {
+ case e: Exception =>
logError("Error running executor", e)
state = ExecutorState.FAILED
killProcess(Some(e.toString))
- }
}
}
}
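The body of Utils.terminateProcess is not shown in this patch; a plausible sketch of what "terminate with a timeout" means here, using only the Java 8 Process API, is the following: ask politely, wait, force-kill if needed, and report the exit code only if the process actually died.

import java.util.concurrent.TimeUnit

object TerminateSketch {
  def terminate(process: Process, timeoutMs: Long): Option[Int] = {
    process.destroy()                                          // polite termination request
    if (process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) {
      Some(process.exitValue())
    } else {
      process.destroyForcibly()                                // escalate to a forced kill
      if (process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) {
        Some(process.exitValue())
      } else {
        None                                                   // process is likely orphaned
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val p = new ProcessBuilder("sleep", "60").start()
    println(terminate(p, 2000))                                // e.g. Some(143) on Linux
  }
}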
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index a45867e7680e..ca9243e39c0a 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -20,24 +20,25 @@ package org.apache.spark.deploy.worker
import java.io.File
import java.io.IOException
import java.text.SimpleDateFormat
-import java.util.{UUID, Date}
+import java.util.{Date, Locale, UUID}
import java.util.concurrent._
import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture}
import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap}
import scala.concurrent.ExecutionContext
-import scala.util.{Failure, Random, Success}
+import scala.util.Random
import scala.util.control.NonFatal
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.ExternalShuffleService
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.deploy.worker.ui.WorkerWebUI
+import org.apache.spark.internal.Logging
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rpc._
-import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
+import org.apache.spark.util.{ThreadUtils, Utils}
private[deploy] class Worker(
override val rpcEnv: RpcEnv,
@@ -45,7 +46,6 @@ private[deploy] class Worker(
cores: Int,
memory: Int,
masterRpcAddresses: Array[RpcAddress],
- systemName: String,
endpointName: String,
workDirPath: String = null,
val conf: SparkConf,
@@ -62,13 +62,13 @@ private[deploy] class Worker(
private val forwordMessageScheduler =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler")
- // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future`
- // methods.
+ // A dedicated thread to clean up the workDir and the directories of finished applications.

+ // Used to provide the implicit parameter of `Future` methods.
private val cleanupThreadExecutor = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread"))
// For worker and executor IDs
- private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4
@@ -99,9 +99,24 @@ private[deploy] class Worker(
private val testing: Boolean = sys.props.contains("spark.testing")
private var master: Option[RpcEndpointRef] = None
+
+ /**
+ * Whether to use the master address in `masterRpcAddresses` if possible. If it's disabled, Worker
+ * will just use the address received from Master.
+ */
+ private val preferConfiguredMasterAddress =
+ conf.getBoolean("spark.worker.preferConfiguredMasterAddress", false)
+ /**
+ * The master address to connect to in case of failure. When the connection is broken, the worker
+ * will use this address to reconnect. This is usually just one of `masterRpcAddresses`. However,
+ * when a master is restarted or takes over leadership, it will be an address sent from the
+ * master, which may not be in `masterRpcAddresses`.
+ */
+ private var masterAddressToConnect: Option[RpcAddress] = None
private var activeMasterUrl: String = ""
private[worker] var activeMasterWebUiUrl : String = ""
- private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName)
+ private var workerWebUiUrl: String = ""
+ private val workerUri = RpcEndpointAddress(rpcEnv.address, endpointName).toString
private var registered = false
private var connected = false
private val workerId = generateWorkerId()
@@ -146,12 +161,10 @@ private[deploy] class Worker(
// A thread pool for registering with masters. Because registering with a master is a blocking
// action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same
// time so that we can register with all masters.
- private val registerMasterThreadPool = new ThreadPoolExecutor(
- 0,
- masterRpcAddresses.size, // Make sure we can register with all masters at the same time
- 60L, TimeUnit.SECONDS,
- new SynchronousQueue[Runnable](),
- ThreadUtils.namedThreadFactory("worker-register-master-threadpool"))
+ private val registerMasterThreadPool = ThreadUtils.newDaemonCachedThreadPool(
+ "worker-register-master-threadpool",
+ masterRpcAddresses.length // Make sure we can register with all masters at the same time
+ )
var coresUsed = 0
var memoryUsed = 0
@@ -187,6 +200,8 @@ private[deploy] class Worker(
shuffleService.startIfEnabled()
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
+
+ workerWebUiUrl = s"http://$publicAddress:${webUi.boundPort}"
registerWithMaster()
metricsSystem.registerSource(workerSource)
@@ -195,12 +210,24 @@ private[deploy] class Worker(
metricsSystem.getServletHandlers.foreach(webUi.attachHandler)
}
- private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) {
+ /**
+ * Change to use the new master.
+ *
+ * @param masterRef the new master ref
+ * @param uiUrl the new master Web UI address
+ * @param masterAddress the new master address which the worker should use to connect in case of
+ * failure
+ */
+ private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String, masterAddress: RpcAddress) {
// activeMasterUrl it's a valid Spark url since we receive it from master.
activeMasterUrl = masterRef.address.toSparkURL
activeMasterWebUiUrl = uiUrl
+ masterAddressToConnect = Some(masterAddress)
master = Some(masterRef)
connected = true
+ if (conf.getBoolean("spark.ui.reverseProxy", false)) {
+ logInfo(s"WorkerWebUI is available at $activeMasterWebUiUrl/proxy/$workerId")
+ }
// Cancel any outstanding re-registration attempts because we found a new master
cancelLastRegistrationRetry()
}
@@ -211,9 +238,8 @@ private[deploy] class Worker(
override def run(): Unit = {
try {
logInfo("Connecting to master " + masterAddress + "...")
- val masterEndpoint =
- rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
- registerWithMaster(masterEndpoint)
+ val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME)
+ sendRegisterMessageToMaster(masterEndpoint)
} catch {
case ie: InterruptedException => // Cancelled
case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
@@ -263,14 +289,14 @@ private[deploy] class Worker(
if (registerMasterFutures != null) {
registerMasterFutures.foreach(_.cancel(true))
}
- val masterAddress = masterRef.address
+ val masterAddress =
+ if (preferConfiguredMasterAddress) masterAddressToConnect.get else masterRef.address
registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable {
override def run(): Unit = {
try {
logInfo("Connecting to master " + masterAddress + "...")
- val masterEndpoint =
- rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
- registerWithMaster(masterEndpoint)
+ val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME)
+ sendRegisterMessageToMaster(masterEndpoint)
} catch {
case ie: InterruptedException => // Cancelled
case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
@@ -339,27 +365,28 @@ private[deploy] class Worker(
}
}
- private def registerWithMaster(masterEndpoint: RpcEndpointRef): Unit = {
- masterEndpoint.ask[RegisterWorkerResponse](RegisterWorker(
- workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress))
- .onComplete {
- // This is a very fast action so we can use "ThreadUtils.sameThread"
- case Success(msg) =>
- Utils.tryLogNonFatalError {
- handleRegisterResponse(msg)
- }
- case Failure(e) =>
- logError(s"Cannot register with master: ${masterEndpoint.address}", e)
- System.exit(1)
- }(ThreadUtils.sameThread)
+ private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = {
+ masterEndpoint.send(RegisterWorker(
+ workerId,
+ host,
+ port,
+ self,
+ cores,
+ memory,
+ workerWebUiUrl,
+ masterEndpoint.address))
}
private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized {
msg match {
- case RegisteredWorker(masterRef, masterWebUiUrl) =>
- logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
+ case RegisteredWorker(masterRef, masterWebUiUrl, masterAddress) =>
+ if (preferConfiguredMasterAddress) {
+ logInfo("Successfully registered with master " + masterAddress.toSparkURL)
+ } else {
+ logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
+ }
registered = true
- changeMaster(masterRef, masterWebUiUrl)
+ changeMaster(masterRef, masterWebUiUrl, masterAddress)
forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
self.send(SendHeartbeat)
@@ -375,6 +402,11 @@ private[deploy] class Worker(
}, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
}
+ val execs = executors.values.map { e =>
+ new ExecutorDescription(e.appId, e.execId, e.cores, e.state)
+ }
+ masterRef.send(WorkerLatestState(workerId, execs.toList, drivers.keys.toSeq))
+
case RegisterWorkerFailed(message) =>
if (!registered) {
logError("Worker registration failed: " + message)
@@ -387,6 +419,9 @@ private[deploy] class Worker(
}
override def receive: PartialFunction[Any, Unit] = synchronized {
+ case msg: RegisterWorkerResponse =>
+ handleRegisterResponse(msg)
+
case SendHeartbeat =>
if (connected) { sendToMaster(Heartbeat(workerId, self)) }
@@ -395,7 +430,7 @@ private[deploy] class Worker(
// rpcEndpoint.
// Copy ids so that it can be used in the cleanup thread.
val appIds = executors.values.map(_.appId).toSet
- val cleanupFuture = concurrent.future {
+ val cleanupFuture = concurrent.Future {
val appDirs = workDir.listFiles()
if (appDirs == null) {
throw new IOException("ERROR: Failed to list files in " + appDirs)
@@ -420,7 +455,7 @@ private[deploy] class Worker(
case MasterChanged(masterRef, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL)
- changeMaster(masterRef, masterWebUiUrl)
+ changeMaster(masterRef, masterWebUiUrl, masterRef.address)
val execs = executors.values.
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
@@ -446,13 +481,25 @@ private[deploy] class Worker(
// Create local dirs for the executor. These are passed to the executor via the
// SPARK_EXECUTOR_DIRS environment variable, and deleted by the Worker when the
// application finishes.
- val appLocalDirs = appDirectories.get(appId).getOrElse {
- Utils.getOrCreateLocalRootDirs(conf).map { dir =>
- val appDir = Utils.createDirectory(dir, namePrefix = "executor")
- Utils.chmod700(appDir)
- appDir.getAbsolutePath()
+ val appLocalDirs = appDirectories.getOrElse(appId, {
+ val localRootDirs = Utils.getOrCreateLocalRootDirs(conf)
+ val dirs = localRootDirs.flatMap { dir =>
+ try {
+ val appDir = Utils.createDirectory(dir, namePrefix = "executor")
+ Utils.chmod700(appDir)
+ Some(appDir.getAbsolutePath())
+ } catch {
+ case e: IOException =>
+ logWarning(s"${e.getMessage}. Ignoring this directory.")
+ None
+ }
}.toSeq
- }
+ if (dirs.isEmpty) {
+ throw new IOException("No subfolder can be created in " +
+ s"${localRootDirs.mkString(",")}.")
+ }
+ dirs
+ })
appDirectories(appId) = appLocalDirs
val manager = new ExecutorRunner(
appId,
@@ -469,14 +516,14 @@ private[deploy] class Worker(
executorDir,
workerUri,
conf,
- appLocalDirs, ExecutorState.LOADING)
+ appLocalDirs, ExecutorState.RUNNING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
memoryUsed += memory_
sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None))
} catch {
- case e: Exception => {
+ case e: Exception =>
logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e)
if (executors.contains(appId + "/" + execId)) {
executors(appId + "/" + execId).kill()
@@ -484,7 +531,6 @@ private[deploy] class Worker(
}
sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED,
Some(e.toString), None))
- }
}
}
@@ -493,7 +539,7 @@ private[deploy] class Worker(
case KillExecutor(masterUrl, appId, execId) =>
if (masterUrl != activeMasterUrl) {
- logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor " + execId)
+ logWarning("Invalid Master (" + masterUrl + ") attempted to kill executor " + execId)
} else {
val fullId = appId + "/" + execId
executors.get(fullId) match {
@@ -505,7 +551,7 @@ private[deploy] class Worker(
}
}
- case LaunchDriver(driverId, driverDesc) => {
+ case LaunchDriver(driverId, driverDesc) =>
logInfo(s"Asked to launch driver $driverId")
val driver = new DriverRunner(
conf,
@@ -521,9 +567,8 @@ private[deploy] class Worker(
coresUsed += driverDesc.cores
memoryUsed += driverDesc.mem
- }
- case KillDriver(driverId) => {
+ case KillDriver(driverId) =>
logInfo(s"Asked to kill driver $driverId")
drivers.get(driverId) match {
case Some(runner) =>
@@ -531,11 +576,9 @@ private[deploy] class Worker(
case None =>
logError(s"Asked to kill unknown driver $driverId")
}
- }
- case driverStateChanged @ DriverStateChanged(driverId, state, exception) => {
+ case driverStateChanged @ DriverStateChanged(driverId, state, exception) =>
handleDriverStateChanged(driverStateChanged)
- }
case ReregisterWithMaster =>
reregisterWithMaster()
@@ -554,7 +597,8 @@ private[deploy] class Worker(
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
- if (master.exists(_.address == remoteAddress)) {
+ if (master.exists(_.address == remoteAddress) ||
+ masterAddressToConnect.exists(_ == remoteAddress)) {
logInfo(s"$remoteAddress Disassociated !")
masterDisconnected()
}
@@ -571,10 +615,15 @@ private[deploy] class Worker(
if (shouldCleanup) {
finishedApps -= id
appDirectories.remove(id).foreach { dirList =>
- logInfo(s"Cleaning up local directories for application $id")
- dirList.foreach { dir =>
- Utils.deleteRecursively(new File(dir))
- }
+ concurrent.Future {
+ logInfo(s"Cleaning up local directories for application $id")
+ dirList.foreach { dir =>
+ Utils.deleteRecursively(new File(dir))
+ }
+ }(cleanupThreadExecutor).onFailure {
+ case e: Throwable =>
+ logError(s"Clean up app dir $dirList failed: ${e.getMessage}", e)
+ }(cleanupThreadExecutor)
}
shuffleService.applicationRemoved(id)
}
@@ -688,11 +737,11 @@ private[deploy] object Worker extends Logging {
val ENDPOINT_NAME = "Worker"
def main(argStrings: Array[String]) {
- SignalLogger.register(log)
+ Utils.initDaemon(log)
val conf = new SparkConf
val args = new WorkerArguments(argStrings, conf)
val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores,
- args.memory, args.masters, args.workDir)
+ args.memory, args.masters, args.workDir, conf = conf)
rpcEnv.awaitTermination()
}
@@ -713,7 +762,7 @@ private[deploy] object Worker extends Logging {
val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr)
val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_))
rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory,
- masterAddresses, systemName, ENDPOINT_NAME, workDir, conf, securityMgr))
+ masterAddresses, ENDPOINT_NAME, workDir, conf, securityMgr))
rpcEnv
}
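A small sketch of the cleanup pattern the Worker now uses in both places: hand the blocking file deletion to a dedicated single-thread ExecutionContext and log failures there, instead of doing the I/O inside the RPC message handler. The object name and the println-based logging are illustrative only.

import java.io.File
import java.util.concurrent.Executors

import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

object AsyncCleanupSketch {
  private val cleanupEc = ExecutionContext.fromExecutorService(
    Executors.newSingleThreadExecutor())

  def cleanUp(dirs: Seq[File]): Unit = {
    Future {
      dirs.foreach { dir =>
        println(s"Cleaning up $dir")
        // the real code calls Utils.deleteRecursively(dir) here
      }
    }(cleanupEc).onComplete {
      case Success(_) => println("cleanup done")
      case Failure(e) => println(s"cleanup failed: ${e.getMessage}")
    }(cleanupEc)
  }

  def main(args: Array[String]): Unit = {
    cleanUp(Seq(new File("/tmp/does-not-matter")))
    Thread.sleep(500)       // let the background deletion run
    cleanupEc.shutdown()
  }
}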
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 5181142c5f80..777020d4d5c8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -19,6 +19,8 @@ package org.apache.spark.deploy.worker
import java.lang.management.ManagementFactory
+import scala.annotation.tailrec
+
import org.apache.spark.util.{IntParam, MemoryParam, Utils}
import org.apache.spark.SparkConf
@@ -63,6 +65,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
checkWorkerMemory()
+ @tailrec
private def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -162,12 +165,11 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
}
// scalastyle:on classforname
} catch {
- case e: Exception => {
+ case e: Exception =>
totalMb = 2*1024
// scalastyle:off println
System.out.println("Failed to get total physical memory. Using " + totalMb + " MB")
// scalastyle:on println
- }
}
// Leave out 1 GB for the operating system, but don't return a negative memory size
math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB)
@@ -175,7 +177,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
def checkWorkerMemory(): Unit = {
if (memory <= 0) {
- val message = "Memory can't be 0, missing a M or G on the end of the memory specification?"
+ val message = "Memory is below 1MB, or missing a M/G at the end of the memory specification?"
throw new IllegalStateException(message)
}
}
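The @tailrec annotation added to parse is purely a compile-time guard: the compiler now rejects any change that would break the tail call and reintroduce stack growth on long argument lists. A tiny standalone illustration (the option names here are made up):

import scala.annotation.tailrec

object TailrecParseSketch {
  @tailrec
  def parse(args: List[String], acc: Map[String, String] = Map.empty): Map[String, String] =
    args match {
      case ("--host" | "-h") :: value :: tail => parse(tail, acc + ("host" -> value))
      case "--port" :: value :: tail          => parse(tail, acc + ("port" -> value))
      case Nil                                => acc
      case other :: _                         => sys.error(s"Unknown option: $other")
    }

  def main(args: Array[String]): Unit =
    println(parse(List("--host", "worker-1", "--port", "7078")))
}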
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index ab56fde938ba..23efcab6caad 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -17,11 +17,12 @@
package org.apache.spark.deploy.worker
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
import org.apache.spark.rpc._
/**
- * Actor which connects to a worker process and terminates the JVM if the connection is severed.
+ * Endpoint which connects to a worker process and terminates the JVM if the
+ * connection is severed.
* Provides fate sharing between a worker and its associated child processes.
*/
private[spark] class WorkerWatcher(
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
index 5a1d06eb87db..2f5a5642d3ca 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
@@ -18,36 +18,37 @@
package org.apache.spark.deploy.worker.ui
import java.io.File
-import java.net.URI
import javax.servlet.http.HttpServletRequest
-import scala.xml.Node
+import scala.xml.{Node, Unparsed}
-import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.internal.Logging
+import org.apache.spark.ui.{UIUtils, WebUIPage}
import org.apache.spark.util.Utils
-import org.apache.spark.Logging
import org.apache.spark.util.logging.RollingFileAppender
private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging {
private val worker = parent.worker
- private val workDir = parent.workDir
+ private val workDir = new File(parent.workDir.toURI.normalize().getPath)
private val supportedLogTypes = Set("stderr", "stdout")
+ private val defaultBytes = 100 * 1024
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
def renderLog(request: HttpServletRequest): String = {
- val defaultBytes = 100 * 1024
-
- val appId = Option(request.getParameter("appId"))
- val executorId = Option(request.getParameter("executorId"))
- val driverId = Option(request.getParameter("driverId"))
- val logType = request.getParameter("logType")
- val offset = Option(request.getParameter("offset")).map(_.toLong)
- val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
+ val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
+ val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
+ val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
+ val logType = UIUtils.stripXSS(request.getParameter("logType"))
+ val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
+ val byteLength =
+ Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
+ .getOrElse(defaultBytes)
val logDir = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
- s"${workDir.getPath}/$appId/$executorId/"
+ s"${workDir.getPath}/$a/$e/"
case (None, None, Some(d)) =>
- s"${workDir.getPath}/$driverId/"
+ s"${workDir.getPath}/$d/"
case _ =>
throw new Exception("Request must specify either application or driver identifiers")
}
@@ -57,14 +58,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
pre + logText
}
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
- val defaultBytes = 100 * 1024
- val appId = Option(request.getParameter("appId"))
- val executorId = Option(request.getParameter("executorId"))
- val driverId = Option(request.getParameter("driverId"))
- val logType = request.getParameter("logType")
- val offset = Option(request.getParameter("offset")).map(_.toLong)
- val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
+ val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
+ val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
+ val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
+ val logType = UIUtils.stripXSS(request.getParameter("logType"))
+ val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
+ val byteLength =
+ Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
+ .getOrElse(defaultBytes)
val (logDir, params, pageName) = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
@@ -77,51 +80,44 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
val (logText, startByte, endByte, logLength) = getLog(logDir, logType, offset, byteLength)
val linkToMaster = <p><a href={worker.activeMasterWebUiUrl}>Back to Master</a></p>
-    val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span>
-
-    val backButton =
-      if (startByte > 0) {
-        <a href={"?%s&logType=%s&offset=%s&byteLength=%s"
-          .format(params, logType, math.max(startByte - byteLength, 0), byteLength)}>
-          <button type="button" class="btn btn-default">
-            Previous {Utils.bytesToString(math.min(byteLength, startByte))}
-          </button>
-        </a>
-      } else {
-        <button type="button" class="btn btn-default" disabled="disabled">
-          Previous 0 B
-        </button>
-      }
-
-    val nextButton =
-      if (endByte < logLength) {
-        <a href={"?%s&logType=%s&offset=%s&byteLength=%s".
-          format(params, logType, endByte, byteLength)}>
-          <button type="button" class="btn btn-default">
-            Next {Utils.bytesToString(math.min(byteLength, logLength - endByte))}
-          </button>
-        </a>
-      } else {
-        <button type="button" class="btn btn-default" disabled="disabled">
-          Next 0 B
-        </button>
-      }
+    val curLogLength = endByte - startByte
+    val range =
+      <span id="log-data">
+        Showing {curLogLength} Bytes: {startByte.toString} - {endByte.toString} of {logLength}
+      </span>
+
+    val moreButton =
+      <button type="button" onclick={"loadMore()"} class="log-more-btn btn btn-default">
+        Load More
+      </button>
+
+    val newButton =
+      <button type="button" onclick={"loadNew()"} class="log-new-btn btn btn-default">
+        Load New
+      </button>
+
+    val alert =
+      <div class="no-new-alert alert alert-info" style="display: none;">
+        End of Log
+      </div>
+
+    val logParams = "?%s&logType=%s".format(params, logType)
+    val jsOnload = "window.onload = " +
+      s"initLogPage('$logParams', $curLogLength, $startByte, $endByte, $logLength, $byteLength);"
val content =
-      <html>
-        <body>
-          {linkToMaster}
-          <div>
-            <div style="float:left; margin-right:10px">{backButton}</div>
-            <div style="float:left;">{range}</div>
-            <div style="float:right; margin-left:10px">{nextButton}</div>
-          </div>
-          <br />
-          <div style="height:500px; overflow:auto; padding:5px;">
-            <pre>{logText}</pre>
-          </div>
-        </body>
-      </html>
+      <div>
+        {linkToMaster}
+        {range}
+        <div class="log-content" style="height:80vh; overflow:auto; padding:5px;">
+          <div>{moreButton}</div>
+          <pre>{logText}</pre>
+          {alert}
+          <div>{newButton}</div>
+        </div>
+        <script>{Unparsed(jsOnload)}</script>
+      </div>
UIUtils.basicSparkPage(content, logType + " log page for " + pageName)
}
@@ -138,7 +134,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
}
// Verify that the normalized path of the log directory is in the working directory
- val normalizedUri = new URI(logDirectory).normalize()
+ val normalizedUri = new File(logDirectory).toURI.normalize()
val normalizedLogDir = new File(normalizedUri.getPath)
if (!Utils.isInDirectory(workDir, normalizedLogDir)) {
return ("Error: invalid log directory " + logDirectory, 0, 0, 0)
@@ -148,7 +144,8 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType)
logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}")
- val totalLength = files.map { _.length }.sum
+ val fileLengths: Seq[Long] = files.map(Utils.getFileLength(_, worker.conf))
+ val totalLength = fileLengths.sum
val offset = offsetOption.getOrElse(totalLength - byteLength)
val startIndex = {
if (offset < 0) {
@@ -161,7 +158,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
}
val endIndex = math.min(startIndex + byteLength, totalLength)
logDebug(s"Getting log from $startIndex to $endIndex")
- val logText = Utils.offsetBytes(files, startIndex, endIndex)
+ val logText = Utils.offsetBytes(files, fileLengths, startIndex, endIndex)
logDebug(s"Got log of length ${logText.length} bytes")
(logText, startIndex, endIndex, totalLength)
} catch {
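A sketch of the byte-range arithmetic getLog performs over the rolled-over files: a missing offset means "tail the log", a negative or too-large offset is clamped, and the window never extends past the total length. File lengths are taken as given here; the patch switches the real code to Utils.getFileLength so compressed rolled files report their uncompressed size.

object LogWindowSketch {
  def window(fileLengths: Seq[Long], offsetOption: Option[Long], byteLength: Int): (Long, Long) = {
    val totalLength = fileLengths.sum
    val offset = offsetOption.getOrElse(totalLength - byteLength)   // default: tail
    val startIndex =
      if (offset < 0) 0L
      else if (offset > totalLength) totalLength
      else offset
    val endIndex = math.min(startIndex + byteLength, totalLength)
    (startIndex, endIndex)
  }

  def main(args: Array[String]): Unit = {
    // Tail the last 100 bytes of three rolled files totalling 250 bytes.
    println(window(Seq(100L, 100L, 50L), None, 100))        // (150,250)
    // An offset past the end collapses to an empty window at the end.
    println(window(Seq(100L, 100L, 50L), Some(400L), 100))  // (250,250)
  }
}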
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
index fd905feb97e9..1ad973122b60 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
@@ -17,28 +17,29 @@
package org.apache.spark.deploy.worker.ui
+import javax.servlet.http.HttpServletRequest
+
import scala.xml.Node
-import javax.servlet.http.HttpServletRequest
import org.json4s.JValue
-import org.apache.spark.deploy.JsonProtocol
import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse}
+import org.apache.spark.deploy.JsonProtocol
import org.apache.spark.deploy.master.DriverState
import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner}
-import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.ui.{UIUtils, WebUIPage}
import org.apache.spark.util.Utils
private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
private val workerEndpoint = parent.worker.self
override def renderJson(request: HttpServletRequest): JValue = {
- val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState)
+ val workerState = workerEndpoint.askSync[WorkerStateResponse](RequestWorkerState)
JsonProtocol.writeWorkerState(workerState)
}
def render(request: HttpServletRequest): Seq[Node] = {
- val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState)
+ val workerState = workerEndpoint.askSync[WorkerStateResponse](RequestWorkerState)
val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs")
val runningExecutors = workerState.executors
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index 1a0598e50dcf..db696b04384b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -20,8 +20,8 @@ package org.apache.spark.deploy.worker.ui
import java.io.File
import javax.servlet.http.HttpServletRequest
-import org.apache.spark.Logging
import org.apache.spark.deploy.worker.Worker
+import org.apache.spark.internal.Logging
import org.apache.spark.ui.{SparkUI, WebUI}
import org.apache.spark.ui.JettyUtils._
import org.apache.spark.util.RpcUtils
@@ -34,7 +34,8 @@ class WorkerWebUI(
val worker: Worker,
val workDir: File,
requestedPort: Int)
- extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI")
+ extends WebUI(worker.securityMgr, worker.securityMgr.getSSLOptions("standalone"),
+ requestedPort, worker.conf, name = "WorkerUI")
with Logging {
private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index c2ebf3059621..b2b26ee107c0 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -19,32 +19,35 @@ package org.apache.spark.executor
import java.net.URL
import java.nio.ByteBuffer
-
-import org.apache.hadoop.conf.Configuration
+import java.util.Locale
+import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import scala.util.{Failure, Success}
+import scala.util.control.NonFatal
-import org.apache.spark.rpc._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.worker.WorkerWatcher
-import org.apache.spark.scheduler.TaskDescription
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc._
+import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
+import org.apache.spark.util.{ThreadUtils, Utils}
private[spark] class CoarseGrainedExecutorBackend(
override val rpcEnv: RpcEnv,
driverUrl: String,
executorId: String,
- hostPort: String,
+ hostname: String,
cores: Int,
userClassPath: Seq[URL],
env: SparkEnv)
extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
+ private[this] val stopping = new AtomicBoolean(false)
var executor: Executor = null
@volatile var driver: Option[RpcEndpointRef] = None
@@ -57,70 +60,77 @@ private[spark] class CoarseGrainedExecutorBackend(
rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
// This is a very fast action so we can use "ThreadUtils.sameThread"
driver = Some(ref)
- ref.ask[RegisterExecutorResponse](
- RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls))
+ ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, extractLogUrls))
}(ThreadUtils.sameThread).onComplete {
// This is a very fast action so we can use "ThreadUtils.sameThread"
- case Success(msg) => Utils.tryLogNonFatalError {
- Option(self).foreach(_.send(msg)) // msg must be RegisterExecutorResponse
- }
- case Failure(e) => {
- logError(s"Cannot register with driver: $driverUrl", e)
- System.exit(1)
- }
+ case Success(msg) =>
+ // Always receive `true`. Just ignore it
+ case Failure(e) =>
+ exitExecutor(1, s"Cannot register with driver: $driverUrl", e, notifyDriver = false)
}(ThreadUtils.sameThread)
}
def extractLogUrls: Map[String, String] = {
val prefix = "SPARK_LOG_URL_"
sys.env.filterKeys(_.startsWith(prefix))
- .map(e => (e._1.substring(prefix.length).toLowerCase, e._2))
+ .map(e => (e._1.substring(prefix.length).toLowerCase(Locale.ROOT), e._2))
}
override def receive: PartialFunction[Any, Unit] = {
- case RegisteredExecutor(hostname) =>
+ case RegisteredExecutor =>
logInfo("Successfully registered with driver")
- executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
+ try {
+ executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
+ } catch {
+ case NonFatal(e) =>
+ exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
+ }
case RegisterExecutorFailed(message) =>
- logError("Slave registration failed: " + message)
- System.exit(1)
+ exitExecutor(1, "Slave registration failed: " + message)
case LaunchTask(data) =>
if (executor == null) {
- logError("Received LaunchTask command but executor was null")
- System.exit(1)
+ exitExecutor(1, "Received LaunchTask command but executor was null")
} else {
- val taskDesc = ser.deserialize[TaskDescription](data.value)
+ val taskDesc = TaskDescription.decode(data.value)
logInfo("Got assigned task " + taskDesc.taskId)
- executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
- taskDesc.name, taskDesc.serializedTask)
+ executor.launchTask(this, taskDesc)
}
- case KillTask(taskId, _, interruptThread) =>
+ case KillTask(taskId, _, interruptThread, reason) =>
if (executor == null) {
- logError("Received KillTask command but executor was null")
- System.exit(1)
+ exitExecutor(1, "Received KillTask command but executor was null")
} else {
- executor.killTask(taskId, interruptThread)
+ executor.killTask(taskId, interruptThread, reason)
}
case StopExecutor =>
+ stopping.set(true)
logInfo("Driver commanded a shutdown")
// Cannot shutdown here because an ack may need to be sent back to the caller. So send
// a message to self to actually do the shutdown.
self.send(Shutdown)
case Shutdown =>
- executor.stop()
- stop()
- rpcEnv.shutdown()
+ stopping.set(true)
+ new Thread("CoarseGrainedExecutorBackend-stop-executor") {
+ override def run(): Unit = {
+ // executor.stop() will call `SparkEnv.stop()` which waits until RpcEnv stops totally.
+ // However, if `executor.stop()` runs in some thread of RpcEnv, RpcEnv won't be able to
+ // stop until `executor.stop()` returns, which becomes a dead-lock (See SPARK-14180).
+ // Therefore, we put this line in a new thread.
+ executor.stop()
+ }
+ }.start()
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
- if (driver.exists(_.address == remoteAddress)) {
- logError(s"Driver $remoteAddress disassociated! Shutting down.")
- System.exit(1)
+ if (stopping.get()) {
+ logInfo(s"Driver from $remoteAddress disconnected during shutdown")
+ } else if (driver.exists(_.address == remoteAddress)) {
+ exitExecutor(1, s"Driver $remoteAddress disassociated! Shutting down.", null,
+ notifyDriver = false)
} else {
logWarning(s"An unknown ($remoteAddress) driver disconnected.")
}
@@ -133,6 +143,33 @@ private[spark] class CoarseGrainedExecutorBackend(
case None => logWarning(s"Drop $msg because has not yet connected to driver")
}
}
+
+ /**
+ * This function can be overridden by child classes to handle executor exits
+ * differently. For example, when an executor goes down, the backend may not
+ * want to take the parent process down.
+ */
+ protected def exitExecutor(code: Int,
+ reason: String,
+ throwable: Throwable = null,
+ notifyDriver: Boolean = true) = {
+ val message = "Executor self-exiting due to : " + reason
+ if (throwable != null) {
+ logError(message, throwable)
+ } else {
+ logError(message)
+ }
+
+ if (notifyDriver && driver.nonEmpty) {
+ driver.get.ask[Boolean](
+ RemoveExecutor(executorId, new ExecutorLossReason(reason))
+ ).onFailure { case e =>
+ logWarning(s"Unable to notify the driver due to " + e.getMessage, e)
+ }(ThreadUtils.sameThread)
+ }
+
+ System.exit(code)
+ }
}
private[spark] object CoarseGrainedExecutorBackend extends Logging {
@@ -146,7 +183,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
workerUrl: Option[String],
userClassPath: Seq[URL]) {
- SignalLogger.register(log)
+ Utils.initDaemon(log)
SparkHadoopUtil.get.runAsSparkUser { () =>
// Debug code
@@ -163,8 +200,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
new SecurityManager(executorConf),
clientMode = true)
val driver = fetcher.setupEndpointRefByURI(driverUrl)
- val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
- Seq[(String, String)](("spark.app.id", appId))
+ val cfg = driver.askSync[SparkAppConfig](RetrieveSparkAppConfig)
+ val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()
// Create SparkEnv using properties we fetched from the driver.
@@ -180,25 +217,19 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
if (driverConf.contains("spark.yarn.credentials.file")) {
logInfo("Will periodically update credentials from: " +
driverConf.get("spark.yarn.credentials.file"))
- SparkHadoopUtil.get.startExecutorDelegationTokenRenewer(driverConf)
+ SparkHadoopUtil.get.startCredentialUpdater(driverConf)
}
val env = SparkEnv.createExecutorEnv(
- driverConf, executorId, hostname, port, cores, isLocal = false)
-
- // SparkEnv will set spark.executor.port if the rpc env is listening for incoming
- // connections (e.g., if it's using akka). Otherwise, the executor is running in
- // client mode only, and does not accept incoming connections.
- val sparkHostPort = env.conf.getOption("spark.executor.port").map { port =>
- hostname + ":" + port
- }.orNull
+ driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false)
+
env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
- env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env))
+ env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env))
workerUrl.foreach { url =>
env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
}
env.rpcEnv.awaitTermination()
- SparkHadoopUtil.get.stopExecutorDelegationTokenRenewer()
+ SparkHadoopUtil.get.stopCredentialUpdater()
}
}
@@ -251,13 +282,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
}
run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath)
+ System.exit(0)
}
private def printUsageAndExit() = {
// scalastyle:off println
System.err.println(
"""
- |"Usage: CoarseGrainedExecutorBackend [options]
+ |Usage: CoarseGrainedExecutorBackend [options]
|
| Options are:
| --driver-url
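A minimal sketch of the shutdown pattern behind the SPARK-14180 comment above: a component that must not block its own event loop hands the blocking stop() call to a fresh thread instead of running it inside the message handler. The trait and the thread name are illustrative only.

object StopInSeparateThreadSketch {
  trait Stoppable { def stop(): Unit }

  def shutdown(executor: Stoppable): Unit = {
    new Thread("backend-stop-executor") {
      override def run(): Unit = {
        // Blocking call; running it here keeps the dispatcher thread free,
        // so the event loop the call waits on can actually terminate.
        executor.stop()
      }
    }.start()
  }

  def main(args: Array[String]): Unit =
    shutdown(new Stoppable {
      def stop(): Unit = println("stopped")
    })
}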
diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
index 7d84889a2def..326e04241977 100644
--- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import org.apache.spark.{TaskCommitDenied, TaskEndReason}
+import org.apache.spark.{TaskCommitDenied, TaskFailedReason}
/**
* Exception thrown when a task attempts to commit output to HDFS but is denied by the driver.
@@ -29,5 +29,5 @@ private[spark] class CommitDeniedException(
attemptNumber: Int)
extends Exception(msg) {
- def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptNumber)
+ def toTaskFailedReason: TaskFailedReason = TaskCommitDenied(jobID, splitID, attemptNumber)
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 9e88d488c037..47c51c0474f2 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -18,28 +18,36 @@
package org.apache.spark.executor
import java.io.{File, NotSerializableException}
+import java.lang.Thread.UncaughtExceptionHandler
import java.lang.management.ManagementFactory
-import java.net.URL
+import java.net.{URI, URL}
import java.nio.ByteBuffer
-import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.Properties
+import java.util.concurrent._
+import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import scala.util.control.NonFatal
+import com.google.common.util.concurrent.ThreadFactoryBuilder
+
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
+import org.apache.spark.rpc.RpcTimeout
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
+import org.apache.spark.util.io.ChunkedByteBuffer
/**
* Spark executor, backed by a threadpool to run tasks.
*
* This can be used with Mesos, YARN, and the standalone scheduler.
- * An internal RPC interface (at the moment Akka) is used for communication with the driver,
+ * An internal RPC interface is used for communication with the driver,
* except in the case of Mesos fine-grained mode.
*/
private[spark] class Executor(
@@ -47,7 +55,8 @@ private[spark] class Executor(
executorHostname: String,
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
- isLocal: Boolean = false)
+ isLocal: Boolean = false,
+ uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler)
extends Logging {
logInfo(s"Starting executor ID $executorId on host $executorHostname")
@@ -73,25 +82,47 @@ private[spark] class Executor(
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
+ Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler)
}
// Start worker thread pool
- private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
+ private val threadPool = {
+ val threadFactory = new ThreadFactoryBuilder()
+ .setDaemon(true)
+ .setNameFormat("Executor task launch worker-%d")
+ .setThreadFactory(new ThreadFactory {
+ override def newThread(r: Runnable): Thread =
+ // Use UninterruptibleThread to run tasks so that code can run without being
+ // interrupted by `Thread.interrupt()`. Some libraries (see KAFKA-1894, HADOOP-10622)
+ // will hang forever if their methods are interrupted.
+ new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder
+ })
+ .build()
+ Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }
private val executorSource = new ExecutorSource(threadPool, executorId)
+ // Pool used for threads that supervise task killing / cancellation
+ private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
+ // For tasks which are in the process of being killed, this map holds the most recently created
+ // TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't
+ // a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding
+ // the integrity of the map's internal state). The purpose of this map is to prevent the creation
+ // of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to
+ // track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise
+ // create. The map key is a task id.
+ private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]()
if (!isLocal) {
env.metricsSystem.registerSource(executorSource)
env.blockManager.initialize(conf.getAppId)
}
- // Create an RpcEndpoint for receiving RPCs from the driver
- private val executorEndpoint = env.rpcEnv.setupEndpoint(
- ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId))
-
// Whether to load classes in user jars before those in Spark jars
private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false)
+ // Whether to monitor killed / interrupted tasks
+ private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", false)
+
// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
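The hunk above replaces the plain daemon cached thread pool with one whose worker threads are UninterruptibleThread instances. A minimal, self-contained sketch of that construction pattern follows; the UninterruptibleThread class here is a hypothetical stand-in for Spark's org.apache.spark.util.UninterruptibleThread, which can defer Thread.interrupt() while a critical section runs.

import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}

import com.google.common.util.concurrent.ThreadFactoryBuilder

// Hypothetical stand-in for org.apache.spark.util.UninterruptibleThread.
class UninterruptibleThread(target: Runnable, name: String) extends Thread(target, name)

object TaskLaunchPool {
  // The inner factory decides the Thread subclass; the outer ThreadFactoryBuilder then
  // applies the daemon flag and the "...-%d" name format to whatever thread it returns.
  private val threadFactory: ThreadFactory = new ThreadFactoryBuilder()
    .setDaemon(true)
    .setNameFormat("Executor task launch worker-%d")
    .setThreadFactory(new ThreadFactory {
      override def newThread(r: Runnable): Thread =
        new UninterruptibleThread(r, "unused") // name is overwritten by the builder
    })
    .build()

  val pool: ThreadPoolExecutor =
    Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
}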
@@ -99,10 +130,15 @@ private[spark] class Executor(
// Set the classloader for serializer
env.serializer.setDefaultClassLoader(replClassLoader)
+ // SPARK-21928. SerializerManager's internal instance of Kryo might get used in netty threads
+ // for fetching remote cached RDD blocks, so need to make sure it uses the right classloader too.
+ env.serializerManager.setDefaultClassLoader(replClassLoader)
- // Akka's message frame size. If task result is bigger than this, we use the block manager
+ // Max size of direct result. If task result is bigger than this, we use the block manager
// to send the result back.
- private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+ private val maxDirectResultSize = Math.min(
+ conf.getSizeAsBytes("spark.task.maxDirectResultSize", 1L << 20),
+ RpcUtils.maxMessageSizeBytes(conf))
// Limit of bytes for total size of results (default is 1GB)
private val maxResultSize = Utils.getMaxResultSize(conf)
@@ -113,30 +149,72 @@ private[spark] class Executor(
// Executor for the heartbeat task.
private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater")
+ // must be initialized before running startDriverHeartbeat()
+ private val heartbeatReceiverRef =
+ RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv)
+
+ /**
+ * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES`
+ * times, it should kill itself. The default value is 60, which means heartbeats are retried
+ * for about 10 minutes, since the heartbeat interval is 10s.
+ */
+ private val HEARTBEAT_MAX_FAILURES = conf.getInt("spark.executor.heartbeat.maxFailures", 60)
+
+ /**
+ * Counts the number of consecutive heartbeat failures. It should only be accessed in the
+ * heartbeat thread. Each successful heartbeat resets it to 0.
+ */
+ private var heartbeatFailures = 0
+
startDriverHeartbeater()
- def launchTask(
- context: ExecutorBackend,
- taskId: Long,
- attemptNumber: Int,
- taskName: String,
- serializedTask: ByteBuffer): Unit = {
- val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
- serializedTask)
- runningTasks.put(taskId, tr)
+ private[executor] def numRunningTasks: Int = runningTasks.size()
+
+ def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
+ val tr = new TaskRunner(context, taskDescription)
+ runningTasks.put(taskDescription.taskId, tr)
threadPool.execute(tr)
}
- def killTask(taskId: Long, interruptThread: Boolean): Unit = {
- val tr = runningTasks.get(taskId)
- if (tr != null) {
- tr.kill(interruptThread)
+ def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
+ val taskRunner = runningTasks.get(taskId)
+ if (taskRunner != null) {
+ if (taskReaperEnabled) {
+ val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized {
+ val shouldCreateReaper = taskReaperForTask.get(taskId) match {
+ case None => true
+ case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
+ }
+ if (shouldCreateReaper) {
+ val taskReaper = new TaskReaper(
+ taskRunner, interruptThread = interruptThread, reason = reason)
+ taskReaperForTask(taskId) = taskReaper
+ Some(taskReaper)
+ } else {
+ None
+ }
+ }
+ // Execute the TaskReaper from outside of the synchronized block.
+ maybeNewTaskReaper.foreach(taskReaperPool.execute)
+ } else {
+ taskRunner.kill(interruptThread = interruptThread, reason = reason)
+ }
}
}
+ /**
+ * Kills all tasks currently running in this executor. Executor back-ends can call this to
+ * kill tasks instead of taking the whole JVM down.
+ * @param interruptThread whether to interrupt the task threads
+ * @param reason the reason for killing the tasks
+ */
+ def killAllTasks(interruptThread: Boolean, reason: String) : Unit = {
+ runningTasks.keys().asScala.foreach(t =>
+ killTask(t, interruptThread = interruptThread, reason = reason))
+ }
+
def stop(): Unit = {
env.metricsSystem.report()
- env.rpcEnv.stop(executorEndpoint)
heartbeater.shutdown()
heartbeater.awaitTermination(10, TimeUnit.SECONDS)
threadPool.shutdown()
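The killTask() change above creates at most one TaskReaper per task, replacing an existing reaper only when the kill request escalates from interruptThread = false to true. A minimal sketch of that create-or-skip decision, with a hypothetical stripped-down reaper type, is shown here.

import scala.collection.mutable

// Hypothetical minimal reaper used only to illustrate the deduplication logic.
class MiniTaskReaper(val taskId: Long, val interruptThread: Boolean) extends Runnable {
  override def run(): Unit = { /* monitor the task until it exits */ }
}

object ReaperRegistry {
  private val reaperForTask = mutable.HashMap[Long, MiniTaskReaper]()

  /** Returns a new reaper to launch, or None if an existing reaper already covers this kill. */
  def reaperToLaunch(taskId: Long, interruptThread: Boolean): Option[MiniTaskReaper] =
    reaperForTask.synchronized {
      val shouldCreate = reaperForTask.get(taskId) match {
        case None => true
        // Only replace an existing reaper when the new request escalates to interruption.
        case Some(existing) => interruptThread && !existing.interruptThread
      }
      if (shouldCreate) {
        val reaper = new MiniTaskReaper(taskId, interruptThread)
        reaperForTask(taskId) = reaper
        Some(reaper)
      } else {
        None
      }
    }
}

// The caller executes the returned reaper outside the synchronized block:
//   ReaperRegistry.reaperToLaunch(42L, interruptThread = true).foreach(pool.execute)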
@@ -152,14 +230,25 @@ private[spark] class Executor(
class TaskRunner(
execBackend: ExecutorBackend,
- val taskId: Long,
- val attemptNumber: Int,
- taskName: String,
- serializedTask: ByteBuffer)
+ private val taskDescription: TaskDescription)
extends Runnable {
- /** Whether this task has been killed. */
- @volatile private var killed = false
+ val taskId = taskDescription.taskId
+ val threadName = s"Executor task launch worker for task $taskId"
+ private val taskName = taskDescription.name
+
+ /** If specified, this task has been killed and this option contains the reason. */
+ @volatile private var reasonIfKilled: Option[String] = None
+
+ @volatile private var threadId: Long = -1
+
+ def getThreadId: Long = threadId
+
+ /** Whether this task has been finished. */
+ @GuardedBy("TaskRunner.this")
+ private var finished = false
+
+ def isFinished: Boolean = synchronized { finished }
/** How much the JVM process has spent in GC when the task starts to run. */
@volatile var startGCTime: Long = _
@@ -170,89 +259,156 @@ private[spark] class Executor(
*/
@volatile var task: Task[Any] = _
- def kill(interruptThread: Boolean): Unit = {
- logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
- killed = true
+ def kill(interruptThread: Boolean, reason: String): Unit = {
+ logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason")
+ reasonIfKilled = Some(reason)
if (task != null) {
- task.kill(interruptThread)
+ synchronized {
+ if (!finished) {
+ task.kill(interruptThread, reason)
+ }
+ }
}
}
+ /**
+ * Set the finished flag to true and clear the current thread's interrupt status
+ */
+ private def setTaskFinishedAndClearInterruptStatus(): Unit = synchronized {
+ this.finished = true
+ // SPARK-14234 - Reset the interrupted status of the thread to avoid the
+ // ClosedByInterruptException during execBackend.statusUpdate which causes
+ // Executor to crash
+ Thread.interrupted()
+ // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there
+ // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False)
+ // is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup:
+ notifyAll()
+ }
+
override def run(): Unit = {
+ threadId = Thread.currentThread.getId
+ Thread.currentThread.setName(threadName)
+ val threadMXBean = ManagementFactory.getThreadMXBean
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTime = System.currentTimeMillis()
+ val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
+ threadMXBean.getCurrentThreadCpuTime
+ } else 0L
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
+ var taskStartCpu: Long = 0
startGCTime = computeTotalGcTime()
try {
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
- updateDependencies(taskFiles, taskJars)
- task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ // Must be set before updateDependencies() is called, in case fetching dependencies
+ // requires access to properties contained within (e.g. for access control).
+ Executor.taskDeserializationProps.set(taskDescription.properties)
+
+ updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
+ task = ser.deserialize[Task[Any]](
+ taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
+ task.localProperties = taskDescription.properties
task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
- if (killed) {
+ val killReason = reasonIfKilled
+ if (killReason.isDefined) {
// Throw an exception rather than returning, because returning within a try{} block
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
- throw new TaskKilledException
+ throw new TaskKilledException(killReason.get)
}
- logDebug("Task " + taskId + "'s epoch is " + task.epoch)
- env.mapOutputTracker.updateEpoch(task.epoch)
+ // The purpose of updating the epoch here is to invalidate the executor's map output status
+ // cache in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
+ // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
+ // we don't need to make any special calls here.
+ if (!isLocal) {
+ logDebug("Task " + taskId + "'s epoch is " + task.epoch)
+ env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
+ }
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
+ taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
+ threadMXBean.getCurrentThreadCpuTime
+ } else 0L
var threwException = true
- val (value, accumUpdates) = try {
+ val value = try {
val res = task.run(
taskAttemptId = taskId,
- attemptNumber = attemptNumber,
+ attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem)
threwException = false
res
} finally {
+ val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
- if (freedMemory > 0) {
+
+ if (freedMemory > 0 && !threwException) {
val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
- if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) {
+ if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
throw new SparkException(errMsg)
} else {
- logError(errMsg)
+ logWarning(errMsg)
}
}
+
+ if (releasedLocks.nonEmpty && !threwException) {
+ val errMsg =
+ s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
+ releasedLocks.mkString("[", ", ", "]")
+ if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
+ throw new SparkException(errMsg)
+ } else {
+ logInfo(errMsg)
+ }
+ }
+ }
+ task.context.fetchFailed.foreach { fetchFailure =>
+ // Uh-oh. It appears the user code has caught the fetch failure without throwing any
+ // other exceptions. It's *possible* this is what the user meant to do (though highly
+ // unlikely). So we will log an error and keep going.
+ logError(s"TID ${taskId} completed successfully though internally it encountered " +
+ s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
+ s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
}
val taskFinish = System.currentTimeMillis()
+ val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
+ threadMXBean.getCurrentThreadCpuTime
+ } else 0L
// If the task has been killed, let's fail it.
- if (task.killed) {
- throw new TaskKilledException
- }
+ task.context.killTaskIfInterrupted()
val resultSer = env.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
val valueBytes = resultSer.serialize(value)
val afterSerialization = System.currentTimeMillis()
- for (m <- task.metrics) {
- // Deserialization happens in two parts: first, we deserialize a Task object, which
- // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
- m.setExecutorDeserializeTime(
- (taskStart - deserializeStartTime) + task.executorDeserializeTime)
- // We need to subtract Task.run()'s deserialization time to avoid double-counting
- m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
- m.setJvmGCTime(computeTotalGcTime() - startGCTime)
- m.setResultSerializationTime(afterSerialization - beforeSerialization)
- m.updateAccumulators()
- }
-
- val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
+ // Deserialization happens in two parts: first, we deserialize a Task object, which
+ // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
+ task.metrics.setExecutorDeserializeTime(
+ (taskStart - deserializeStartTime) + task.executorDeserializeTime)
+ task.metrics.setExecutorDeserializeCpuTime(
+ (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
+ // We need to subtract Task.run()'s deserialization time to avoid double-counting
+ task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
+ task.metrics.setExecutorCpuTime(
+ (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
+ task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
+ task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)
+
+ // Note: accumulator updates must be collected after TaskMetrics is updated
+ val accumUpdates = task.collectAccumulatorUpdates()
+ // TODO: do not serialize value twice
+ val directResult = new DirectTaskResult(valueBytes, accumUpdates)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
@@ -263,10 +419,12 @@ private[spark] class Executor(
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
s"dropping it.")
ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
- } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
+ } else if (resultSize > maxDirectResultSize) {
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
- blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
+ blockId,
+ new ChunkedByteBuffer(serializedDirectResult.duplicate()),
+ StorageLevel.MEMORY_AND_DISK_SER)
logInfo(
s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
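This hunk swaps the Akka frame size for a dedicated spark.task.maxDirectResultSize threshold. The overall decision is three-way: results above maxResultSize are dropped, results above maxDirectResultSize travel via the block manager, and everything else is sent back inline. A small hedged sketch of that routing, using hypothetical names, follows.

// Hypothetical result-routing helper mirroring the thresholds used above.
sealed trait ResultRoute
case object Dropped extends ResultRoute           // only a size marker is sent back
case object ViaBlockManager extends ResultRoute   // an IndirectTaskResult referencing a block
case object Direct extends ResultRoute            // serialized bytes sent inline

def routeFor(resultSize: Long, maxResultSize: Long, maxDirectResultSize: Long): ResultRoute =
  if (maxResultSize > 0 && resultSize > maxResultSize) Dropped
  else if (resultSize > maxDirectResultSize) ViaBlockManager
  else Direct

// Example: a 2 MiB result with maxDirectResultSize = 1 MiB goes through the block manager.
// routeFor(2L << 20, 1L << 30, 1L << 20) == ViaBlockManager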
@@ -276,19 +434,40 @@ private[spark] class Executor(
}
}
+ setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
- case ffe: FetchFailedException =>
- val reason = ffe.toTaskEndReason
+ case t: TaskKilledException =>
+ logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
+ setTaskFinishedAndClearInterruptStatus()
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
+
+ case _: InterruptedException | NonFatal(_) if
+ task != null && task.reasonIfKilled.isDefined =>
+ val killReason = task.reasonIfKilled.getOrElse("unknown reason")
+ logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
+ setTaskFinishedAndClearInterruptStatus()
+ execBackend.statusUpdate(
+ taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
+
+ case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
+ val reason = task.context.fetchFailed.get.toTaskFailedReason
+ if (!t.isInstanceOf[FetchFailedException]) {
+ // there was a fetch failure in the task, but some user code wrapped that exception
+ // and threw something else. Regardless, we treat it as a fetch failure.
+ val fetchFailedCls = classOf[FetchFailedException].getName
+ logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
+ s"failed, but the ${fetchFailedCls} was hidden by another " +
+ s"exception. Spark is handling this like a fetch failure and ignoring the " +
+ s"other exception: $t")
+ }
+ setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
- case _: TaskKilledException | _: InterruptedException if task.killed =>
- logInfo(s"Executor killed $taskName (TID $taskId)")
- execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
-
- case cDE: CommitDeniedException =>
- val reason = cDE.toTaskEndReason
+ case CausedBy(cDE: CommitDeniedException) =>
+ val reason = cDE.toTaskFailedReason
+ setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
case t: Throwable =>
@@ -297,35 +476,165 @@ private[spark] class Executor(
// the default uncaught exception handler, which will terminate the Executor.
logError(s"Exception in $taskName (TID $taskId)", t)
- val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
- task.metrics.map { m =>
- m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
- m.setJvmGCTime(computeTotalGcTime() - startGCTime)
- m.updateAccumulators()
- m
- }
- }
- val serializedTaskEndReason = {
- try {
- ser.serialize(new ExceptionFailure(t, metrics))
- } catch {
- case _: NotSerializableException =>
- // t is not serializable so just send the stacktrace
- ser.serialize(new ExceptionFailure(t, metrics, false))
+ // SPARK-20904: Do not report the failure to the driver if it happened during shutdown.
+ // Because libraries may set up shutdown hooks that race with running tasks during shutdown,
+ // spurious failures may occur and can result in improper accounting in the driver (e.g.
+ // the task failure would not be ignored if the shutdown happened because of preemption,
+ // instead of an app issue).
+ if (!ShutdownHookManager.inShutdown()) {
+ // Collect latest accumulator values to report back to the driver
+ val accums: Seq[AccumulatorV2[_, _]] =
+ if (task != null) {
+ task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
+ task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
+ task.collectAccumulatorUpdates(taskFailed = true)
+ } else {
+ Seq.empty
+ }
+
+ val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
+
+ val serializedTaskEndReason = {
+ try {
+ ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
+ } catch {
+ case _: NotSerializableException =>
+ // t is not serializable so just send the stacktrace
+ ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
+ }
}
+ setTaskFinishedAndClearInterruptStatus()
+ execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
+ } else {
+ logInfo("Not reporting error to driver during JVM shutdown.")
}
- execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
- SparkUncaughtExceptionHandler.uncaughtException(t)
+ uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
}
} finally {
runningTasks.remove(taskId)
}
}
+
+ private def hasFetchFailure: Boolean = {
+ task != null && task.context != null && task.context.fetchFailed.isDefined
+ }
+ }
+
+ /**
+ * Supervises the killing / cancellation of a task by setting the task's killed flag, optionally
+ * sending a Thread.interrupt(), and monitoring the task until it finishes.
+ *
+ * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks
+ * may not be interruptible or may not respond to their "killed" flags being set. If a significant
+ * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but
+ * remain running then this can lead to a situation where new jobs and tasks are starved of
+ * resources that are being used by these zombie tasks.
+ *
+ * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie
+ * tasks. For backwards-compatibility / backportability this component is disabled by default
+ * and must be explicitly enabled by setting `spark.task.reaper.enabled=true`.
+ *
+ * A TaskReaper is created for a particular task when that task is killed / cancelled. Typically
+ * a task will have only one TaskReaper, but it's possible for a task to have up to two reapers
+ * in case kill is called twice with different values for the `interrupt` parameter.
+ *
+ * Once created, a TaskReaper will run until its supervised task has finished running. If the
+ * TaskReaper has not been configured to kill the JVM after a timeout (i.e. if
+ * `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely
+ * if the supervised task never exits.
+ */
+ private class TaskReaper(
+ taskRunner: TaskRunner,
+ val interruptThread: Boolean,
+ val reason: String)
+ extends Runnable {
+
+ private[this] val taskId: Long = taskRunner.taskId
+
+ private[this] val killPollingIntervalMs: Long =
+ conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s")
+
+ private[this] val killTimeoutMs: Long = conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1")
+
+ private[this] val takeThreadDump: Boolean =
+ conf.getBoolean("spark.task.reaper.threadDump", true)
+
+ override def run(): Unit = {
+ val startTimeMs = System.currentTimeMillis()
+ def elapsedTimeMs = System.currentTimeMillis() - startTimeMs
+ def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > killTimeoutMs
+ try {
+ // Only attempt to kill the task once. If interruptThread = false then a second kill
+ // attempt would be a no-op and if interruptThread = true then it may not be safe or
+ // effective to interrupt multiple times:
+ taskRunner.kill(interruptThread = interruptThread, reason = reason)
+ // Monitor the killed task until it exits. The synchronization logic here is complicated
+ // because we don't want to synchronize on the taskRunner while possibly taking a thread
+ // dump, but we also need to be careful to avoid races between checking whether the task
+ // has finished and wait()ing for it to finish.
+ var finished: Boolean = false
+ while (!finished && !timeoutExceeded()) {
+ taskRunner.synchronized {
+ // We need to synchronize on the TaskRunner while checking whether the task has
+ // finished in order to avoid a race where the task is marked as finished right after
+ // we check and before we call wait().
+ if (taskRunner.isFinished) {
+ finished = true
+ } else {
+ taskRunner.wait(killPollingIntervalMs)
+ }
+ }
+ if (taskRunner.isFinished) {
+ finished = true
+ } else {
+ logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms")
+ if (takeThreadDump) {
+ try {
+ Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread =>
+ if (thread.threadName == taskRunner.threadName) {
+ logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}")
+ }
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Exception thrown while obtaining thread dump: ", e)
+ }
+ }
+ }
+ }
+
+ if (!taskRunner.isFinished && timeoutExceeded()) {
+ if (isLocal) {
+ logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " +
+ "not killing JVM because we are running in local mode.")
+ } else {
+ // In non-local-mode, the exception thrown here will bubble up to the uncaught exception
+ // handler and cause the executor JVM to exit.
+ throw new SparkException(
+ s"Killing executor JVM because killed task $taskId could not be stopped within " +
+ s"$killTimeoutMs ms.")
+ }
+ }
+ } finally {
+ // Clean up entries in the taskReaperForTask map.
+ taskReaperForTask.synchronized {
+ taskReaperForTask.get(taskId).foreach { taskReaperInMap =>
+ if (taskReaperInMap eq this) {
+ taskReaperForTask.remove(taskId)
+ } else {
+ // This must have been a TaskReaper created with interruptThread == false for which a
+ // subsequent killTask() call with interruptThread == true for the same task overwrote
+ // the map entry.
+ }
+ }
+ }
+ }
+ }
}
/**
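The TaskReaper's monitoring loop above relies on a check-then-wait handshake with the TaskRunner: the runner calls notifyAll() under its own lock when it finishes, and the reaper checks isFinished and wait()s under that same lock so the wakeup cannot be lost. A minimal sketch of the pattern, with generic names rather than Spark's, is below.

// Generic sketch of the runner/reaper handshake; names are illustrative, not Spark's.
class MiniRunner {
  private var finished = false
  def isFinished: Boolean = synchronized { finished }
  def markFinished(): Unit = synchronized {
    finished = true
    notifyAll() // wake any supervisor blocked in wait()
  }
}

def superviseUntilDone(runner: MiniRunner, pollMs: Long, timeoutMs: Long): Boolean = {
  val start = System.currentTimeMillis()
  def timedOut = timeoutMs > 0 && System.currentTimeMillis() - start > timeoutMs
  var done = false
  while (!done && !timedOut) {
    runner.synchronized {
      // Check and wait under the same lock, so a finish signalled between the check
      // and the wait() cannot be missed.
      if (runner.isFinished) done = true else runner.wait(pollMs)
    }
  }
  done || runner.isFinished
}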
@@ -365,9 +674,9 @@ private[spark] class Executor(
val _userClassPathFirst: java.lang.Boolean = userClassPathFirst
val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
- val constructor = klass.getConstructor(classOf[SparkConf], classOf[String],
- classOf[ClassLoader], classOf[Boolean])
- constructor.newInstance(conf, classUri, parent, _userClassPathFirst)
+ val constructor = klass.getConstructor(classOf[SparkConf], classOf[SparkEnv],
+ classOf[String], classOf[ClassLoader], classOf[Boolean])
+ constructor.newInstance(conf, env, classUri, parent, _userClassPathFirst)
} catch {
case _: ClassNotFoundException =>
logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
@@ -383,7 +692,7 @@ private[spark] class Executor(
* Download any missing dependencies if we receive a new set of files and JARs from the
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
- private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
+ private def updateDependencies(newFiles: Map[String, Long], newJars: Map[String, Long]) {
lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
synchronized {
// Fetch missing dependencies
@@ -395,7 +704,7 @@ private[spark] class Executor(
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars) {
- val localName = name.split("/").last
+ val localName = new URI(name).getPath.split("/").last
val currentTimeStamp = currentJars.get(name)
.orElse(currentJars.get(localName))
.getOrElse(-1L)
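The one-line change above derives the local jar name from the URI path instead of the raw string, plausibly so that query strings or fragments in a download URL do not leak into the cached file name. A tiny illustration with a hypothetical URL:

import java.net.URI

val name = "http://repo.example.com/jars/app.jar?token=abc" // hypothetical URL
val naive = name.split("/").last                   // "app.jar?token=abc"
val viaUri = new URI(name).getPath.split("/").last // "app.jar"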
@@ -416,46 +725,38 @@ private[spark] class Executor(
}
}
- private val heartbeatReceiverRef =
- RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv)
-
/** Reports heartbeat and metrics for active tasks to the driver. */
private def reportHeartBeat(): Unit = {
- // list of (task id, metrics) to send back to the driver
- val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
+ // list of (task id, accumUpdates) to send back to the driver
+ val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulatorV2[_, _]])]()
val curGCTime = computeTotalGcTime()
for (taskRunner <- runningTasks.values().asScala) {
if (taskRunner.task != null) {
- taskRunner.task.metrics.foreach { metrics =>
- metrics.updateShuffleReadMetrics()
- metrics.updateInputMetrics()
- metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
- metrics.updateAccumulators()
-
- if (isLocal) {
- // JobProgressListener will hold an reference of it during
- // onExecutorMetricsUpdate(), then JobProgressListener can not see
- // the changes of metrics any more, so make a deep copy of it
- val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
- tasksMetrics += ((taskRunner.taskId, copiedMetrics))
- } else {
- // It will be copied by serialization
- tasksMetrics += ((taskRunner.taskId, metrics))
- }
- }
+ taskRunner.task.metrics.mergeShuffleReadMetrics()
+ taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+ accumUpdates += ((taskRunner.taskId, taskRunner.task.metrics.accumulators()))
}
}
- val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
+ val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId)
try {
- val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message)
+ val response = heartbeatReceiverRef.askSync[HeartbeatResponse](
+ message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s"))
if (response.reregisterBlockManager) {
logInfo("Told to re-register on heartbeat")
env.blockManager.reregister()
}
+ heartbeatFailures = 0
} catch {
- case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e)
+ case NonFatal(e) =>
+ logWarning("Issue communicating with driver in heartbeater", e)
+ heartbeatFailures += 1
+ if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) {
+ logError(s"Exit as unable to send heartbeats to driver " +
+ s"more than $HEARTBEAT_MAX_FAILURES times")
+ System.exit(ExecutorExitCode.HEARTBEAT_FAILURE)
+ }
}
}
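reportHeartBeat() now counts consecutive failures and exits the executor once the limit is reached. A minimal sketch of that pattern is below; sendHeartbeat stands in for the askSync call, and the exit code 56 mirrors the ExecutorExitCode.HEARTBEAT_FAILURE constant added later in this patch.

import scala.util.control.NonFatal

object HeartbeatLoop {
  val HeartbeatMaxFailures = 60 // ~10 minutes at the default 10s interval
  private var heartbeatFailures = 0

  def reportOnce(sendHeartbeat: () => Unit): Unit = {
    try {
      sendHeartbeat()       // e.g. an RPC to the driver's heartbeat receiver
      heartbeatFailures = 0 // any successful heartbeat resets the counter
    } catch {
      case NonFatal(e) =>
        heartbeatFailures += 1
        if (heartbeatFailures >= HeartbeatMaxFailures) {
          // Give up: the driver is unreachable, so the executor kills itself with a
          // dedicated exit code rather than lingering as a zombie.
          System.exit(56) // ExecutorExitCode.HEARTBEAT_FAILURE
        }
    }
  }
}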
@@ -474,3 +775,10 @@ private[spark] class Executor(
heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
}
}
+
+private[spark] object Executor {
+ // This is reserved for internal use by components that need to read task properties before a
+ // task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be
+ // used instead.
+ val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties]
+}
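The new Executor companion object exposes a ThreadLocal so that code running during task deserialization (for example, dependency fetching that needs access-control properties) can see the task's properties before the Task object exists. A hedged sketch of the handoff, with hypothetical names, follows; the remove() call is general ThreadLocal hygiene, not something this hunk shows.

import java.util.Properties

object ExecutorLike {
  val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties]
}

def runTaskLike(taskProps: Properties)(deserializeAndRun: () => Unit): Unit = {
  // Publish the properties before deserialization so helpers on this thread can read them.
  ExecutorLike.taskDeserializationProps.set(taskProps)
  try {
    deserializeAndRun()
  } finally {
    ExecutorLike.taskDeserializationProps.remove() // avoid leaking across pooled threads
  }
}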
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala
index e07cb31cbe4b..7153323d01a0 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala
@@ -25,6 +25,6 @@ import org.apache.spark.TaskState.TaskState
* A pluggable interface used by the Executor to send updates to the cluster scheduler.
*/
private[spark] trait ExecutorBackend {
- def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)
+ def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit
}
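ExecutorBackend is the pluggable seam the executor uses to report task state; the change above only makes the Unit return type explicit. For illustration, a hypothetical toy implementation (not Spark's) could simply log the updates:

import java.nio.ByteBuffer

// Hypothetical stand-ins for TaskState and ExecutorBackend, for illustration only.
object TaskStateLike extends Enumeration { val RUNNING, FINISHED, FAILED, KILLED = Value }

trait ExecutorBackendLike {
  def statusUpdate(taskId: Long, state: TaskStateLike.Value, data: ByteBuffer): Unit
}

class LoggingBackend extends ExecutorBackendLike {
  override def statusUpdate(taskId: Long, state: TaskStateLike.Value, data: ByteBuffer): Unit =
    println(s"task $taskId -> $state (${data.remaining()} bytes of payload)")
}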
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala
deleted file mode 100644
index cf362f846473..000000000000
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor
-
-import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
-import org.apache.spark.util.Utils
-
-/**
- * Driver -> Executor message to trigger a thread dump.
- */
-private[spark] case object TriggerThreadDump
-
-/**
- * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC.
- */
-private[spark]
-class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint {
-
- override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case TriggerThreadDump =>
- context.reply(Utils.getThreadDump())
- }
-
-}
-
-object ExecutorEndpoint {
- val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint"
-}
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
index ea36fb60bd54..99858f785600 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
@@ -39,6 +39,12 @@ object ExecutorExitCode {
/** ExternalBlockStore failed to create a local temporary directory after many attempts. */
val EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR = 55
+ /**
+ * Executor is unable to send heartbeats to the driver more than
+ * "spark.executor.heartbeat.maxFailures" times.
+ */
+ val HEARTBEAT_FAILURE = 56
+
def explainExitCode(exitCode: Int): String = {
exitCode match {
case UNCAUGHT_EXCEPTION => "Uncaught exception"
@@ -51,6 +57,8 @@ object ExecutorExitCode {
// TODO: replace external block store with concrete implementation name
case EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR =>
"ExternalBlockStore failed to create a local temporary directory."
+ case HEARTBEAT_FAILURE =>
+ "Unable to send heartbeats to driver."
case _ =>
"Unknown executor exit code (" + exitCode + ")" + (
if (exitCode > 128) {
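The new constant slots into explainExitCode alongside the existing codes. A compressed sketch of that mapping, including the usual "exit code above 128 means killed by signal (code - 128)" convention hinted at by the surrounding context, with only the new case shown concretely:

def explainExitCodeSketch(exitCode: Int): String = exitCode match {
  case 56 => "Unable to send heartbeats to driver." // HEARTBEAT_FAILURE
  case c if c > 128 =>
    // Conventional POSIX behaviour: 128 + N usually means the process died from signal N.
    s"Unknown executor exit code ($c), possibly died from signal ${c - 128}"
  case c => s"Unknown executor exit code ($c)"
}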
diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
new file mode 100644
index 000000000000..3d15f3a0396e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.LongAccumulator
+
+
+/**
+ * :: DeveloperApi ::
+ * Method by which input data was read. Network means that the data was read over the network
+ * from a remote block manager (which may have stored the data on-disk or in-memory).
+ * Operations are not thread-safe.
+ */
+@DeveloperApi
+object DataReadMethod extends Enumeration with Serializable {
+ type DataReadMethod = Value
+ val Memory, Disk, Hadoop, Network = Value
+}
+
+
+/**
+ * :: DeveloperApi ::
+ * A collection of accumulators that represents metrics about reading data from external systems.
+ */
+@DeveloperApi
+class InputMetrics private[spark] () extends Serializable {
+ private[executor] val _bytesRead = new LongAccumulator
+ private[executor] val _recordsRead = new LongAccumulator
+
+ /**
+ * Total number of bytes read.
+ */
+ def bytesRead: Long = _bytesRead.sum
+
+ /**
+ * Total number of records read.
+ */
+ def recordsRead: Long = _recordsRead.sum
+
+ private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v)
+ private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
+ private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v)
+}
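InputMetrics (and the sibling OutputMetrics below) follow one pattern: the mutable counters are package-private accumulators, the public API is read-only sums, and mutation goes through narrow private[spark] incrementers. A shape-only sketch of the same idea, with a plain AtomicLong standing in for Spark's LongAccumulator, is shown here.

import java.util.concurrent.atomic.AtomicLong

// Not Spark's class: a shape-only sketch of an accumulator-backed metrics holder.
class SimpleInputMetrics {
  private val _bytesRead = new AtomicLong(0L)
  private val _recordsRead = new AtomicLong(0L)

  // Read-only public view.
  def bytesRead: Long = _bytesRead.get()
  def recordsRead: Long = _recordsRead.get()

  // Narrow mutators, kept internal in the real classes (private[spark]).
  def incBytesRead(v: Long): Unit = { _bytesRead.addAndGet(v); () }
  def incRecordsRead(v: Long): Unit = { _recordsRead.addAndGet(v); () }
}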
diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
new file mode 100644
index 000000000000..dada9697c1cf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.LongAccumulator
+
+
+/**
+ * :: DeveloperApi ::
+ * Method by which output data was written.
+ * Operations are not thread-safe.
+ */
+@DeveloperApi
+object DataWriteMethod extends Enumeration with Serializable {
+ type DataWriteMethod = Value
+ val Hadoop = Value
+}
+
+
+/**
+ * :: DeveloperApi ::
+ * A collection of accumulators that represents metrics about writing data to external systems.
+ */
+@DeveloperApi
+class OutputMetrics private[spark] () extends Serializable {
+ private[executor] val _bytesWritten = new LongAccumulator
+ private[executor] val _recordsWritten = new LongAccumulator
+
+ /**
+ * Total number of bytes written.
+ */
+ def bytesWritten: Long = _bytesWritten.sum
+
+ /**
+ * Total number of records written.
+ */
+ def recordsWritten: Long = _recordsWritten.sum
+
+ private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v)
+ private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v)
+}
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
new file mode 100644
index 000000000000..8dd1a1ea059b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.LongAccumulator
+
+
+/**
+ * :: DeveloperApi ::
+ * A collection of accumulators that represent metrics about reading shuffle data.
+ * Operations are not thread-safe.
+ */
+@DeveloperApi
+class ShuffleReadMetrics private[spark] () extends Serializable {
+ private[executor] val _remoteBlocksFetched = new LongAccumulator
+ private[executor] val _localBlocksFetched = new LongAccumulator
+ private[executor] val _remoteBytesRead = new LongAccumulator
+ private[executor] val _localBytesRead = new LongAccumulator
+ private[executor] val _fetchWaitTime = new LongAccumulator
+ private[executor] val _recordsRead = new LongAccumulator
+
+ /**
+ * Number of remote blocks fetched in this shuffle by this task.
+ */
+ def remoteBlocksFetched: Long = _remoteBlocksFetched.sum
+
+ /**
+ * Number of local blocks fetched in this shuffle by this task.
+ */
+ def localBlocksFetched: Long = _localBlocksFetched.sum
+
+ /**
+ * Total number of remote bytes read from the shuffle by this task.
+ */
+ def remoteBytesRead: Long = _remoteBytesRead.sum
+
+ /**
+ * Shuffle data that was read from the local disk (as opposed to from a remote executor).
+ */
+ def localBytesRead: Long = _localBytesRead.sum
+
+ /**
+ * Time the task spent waiting for remote shuffle blocks. This only includes the time
+ * blocking on shuffle input data. For instance if block B is being fetched while the task is
+ * still not finished processing block A, it is not considered to be blocking on block B.
+ */
+ def fetchWaitTime: Long = _fetchWaitTime.sum
+
+ /**
+ * Total number of records read from the shuffle by this task.
+ */
+ def recordsRead: Long = _recordsRead.sum
+
+ /**
+ * Total bytes fetched in the shuffle by this task (both remote and local).
+ */
+ def totalBytesRead: Long = remoteBytesRead + localBytesRead
+
+ /**
+ * Number of blocks fetched in this shuffle by this task (remote or local).
+ */
+ def totalBlocksFetched: Long = remoteBlocksFetched + localBlocksFetched
+
+ private[spark] def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched.add(v)
+ private[spark] def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched.add(v)
+ private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v)
+ private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v)
+ private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v)
+ private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
+
+ private[spark] def setRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.setValue(v)
+ private[spark] def setLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.setValue(v)
+ private[spark] def setRemoteBytesRead(v: Long): Unit = _remoteBytesRead.setValue(v)
+ private[spark] def setLocalBytesRead(v: Long): Unit = _localBytesRead.setValue(v)
+ private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v)
+ private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v)
+
+ /**
+ * Resets the value of the current metrics (`this`) and merges all the independent
+ * [[TempShuffleReadMetrics]] into `this`.
+ */
+ private[spark] def setMergeValues(metrics: Seq[TempShuffleReadMetrics]): Unit = {
+ _remoteBlocksFetched.setValue(0)
+ _localBlocksFetched.setValue(0)
+ _remoteBytesRead.setValue(0)
+ _localBytesRead.setValue(0)
+ _fetchWaitTime.setValue(0)
+ _recordsRead.setValue(0)
+ metrics.foreach { metric =>
+ _remoteBlocksFetched.add(metric.remoteBlocksFetched)
+ _localBlocksFetched.add(metric.localBlocksFetched)
+ _remoteBytesRead.add(metric.remoteBytesRead)
+ _localBytesRead.add(metric.localBytesRead)
+ _fetchWaitTime.add(metric.fetchWaitTime)
+ _recordsRead.add(metric.recordsRead)
+ }
+ }
+}
+
+/**
+ * A temporary shuffle read metrics holder that is used to collect shuffle read metrics for each
+ * shuffle dependency; all temporary metrics are merged into the task-level [[ShuffleReadMetrics]]
+ * in the end.
+ */
+private[spark] class TempShuffleReadMetrics {
+ private[this] var _remoteBlocksFetched = 0L
+ private[this] var _localBlocksFetched = 0L
+ private[this] var _remoteBytesRead = 0L
+ private[this] var _localBytesRead = 0L
+ private[this] var _fetchWaitTime = 0L
+ private[this] var _recordsRead = 0L
+
+ def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v
+ def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v
+ def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v
+ def incLocalBytesRead(v: Long): Unit = _localBytesRead += v
+ def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v
+ def incRecordsRead(v: Long): Unit = _recordsRead += v
+
+ def remoteBlocksFetched: Long = _remoteBlocksFetched
+ def localBlocksFetched: Long = _localBlocksFetched
+ def remoteBytesRead: Long = _remoteBytesRead
+ def localBytesRead: Long = _localBytesRead
+ def fetchWaitTime: Long = _fetchWaitTime
+ def recordsRead: Long = _recordsRead
+}
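The TempShuffleReadMetrics / setMergeValues pair implements a per-dependency collect-then-merge scheme: each reader accumulates into its own unsynchronized temp object, and the task-level metrics are reset and rebuilt from all temps in one pass. A stripped-down sketch of that scheme (plain Longs instead of LongAccumulator, one counter only) follows.

import scala.collection.mutable.ArrayBuffer

// Shape-only sketch; not Spark's classes.
class TempReadMetrics {
  var recordsRead = 0L
  def incRecordsRead(v: Long): Unit = recordsRead += v
}

class MergedReadMetrics {
  private var recordsRead = 0L
  private val temps = ArrayBuffer[TempReadMetrics]()

  /** Each shuffle dependency gets its own temp holder, so readers never contend. */
  def createTempMetrics(): TempReadMetrics = synchronized {
    val t = new TempReadMetrics
    temps += t
    t
  }

  /** Reset and rebuild the merged value from every registered temp holder. */
  def merge(): Unit = synchronized {
    recordsRead = 0L
    temps.foreach(t => recordsRead += t.recordsRead)
  }

  def totalRecordsRead: Long = synchronized { recordsRead }
}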
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
new file mode 100644
index 000000000000..ada2e1bc0859
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.LongAccumulator
+
+
+/**
+ * :: DeveloperApi ::
+ * A collection of accumulators that represent metrics about writing shuffle data.
+ * Operations are not thread-safe.
+ */
+@DeveloperApi
+class ShuffleWriteMetrics private[spark] () extends Serializable {
+ private[executor] val _bytesWritten = new LongAccumulator
+ private[executor] val _recordsWritten = new LongAccumulator
+ private[executor] val _writeTime = new LongAccumulator
+
+ /**
+ * Number of bytes written for the shuffle by this task.
+ */
+ def bytesWritten: Long = _bytesWritten.sum
+
+ /**
+ * Total number of records written to the shuffle by this task.
+ */
+ def recordsWritten: Long = _recordsWritten.sum
+
+ /**
+ * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds.
+ */
+ def writeTime: Long = _writeTime.sum
+
+ private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
+ private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
+ private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v)
+ private[spark] def decBytesWritten(v: Long): Unit = {
+ _bytesWritten.setValue(bytesWritten - v)
+ }
+ private[spark] def decRecordsWritten(v: Long): Unit = {
+ _recordsWritten.setValue(recordsWritten - v)
+ }
+
+ // Legacy methods for backward compatibility.
+ // TODO: remove these once we make this class private.
+ @deprecated("use bytesWritten instead", "2.0.0")
+ def shuffleBytesWritten: Long = bytesWritten
+ @deprecated("use writeTime instead", "2.0.0")
+ def shuffleWriteTime: Long = writeTime
+ @deprecated("use recordsWritten instead", "2.0.0")
+ def shuffleRecordsWritten: Long = recordsWritten
+
+}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 42207a955359..a3ce3d1ccc5e 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,421 +17,304 @@
package org.apache.spark.executor
-import java.io.{IOException, ObjectInputStream}
-import java.util.concurrent.ConcurrentHashMap
-
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}
+import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.DataReadMethod.DataReadMethod
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.storage.{BlockId, BlockStatus}
-import org.apache.spark.util.Utils
+import org.apache.spark.util._
+
/**
* :: DeveloperApi ::
* Metrics tracked during the execution of a task.
*
- * This class is used to house metrics both for in-progress and completed tasks. In executors,
- * both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread
- * reads it to send in-progress metrics, and the task thread reads it to send metrics along with
- * the completed task.
+ * This class is a wrapper around a collection of internal accumulators that represent metrics
+ * associated with a task. The local values of these accumulators are sent from the executor
+ * to the driver when the task completes. These values are then merged into the corresponding
+ * accumulator previously registered on the driver.
*
- * So, when adding new fields, take into consideration that the whole object can be serialized for
- * shipping off at any time to consumers of the SparkListener interface.
+ * The accumulator updates are also sent to the driver periodically (on executor heartbeat)
+ * and when the task fails with an exception. The [[TaskMetrics]] object itself should never
+ * be sent to the driver.
*/
@DeveloperApi
-class TaskMetrics extends Serializable {
- /**
- * Host's name the task runs on
- */
- private var _hostname: String = _
- def hostname: String = _hostname
- private[spark] def setHostname(value: String) = _hostname = value
+class TaskMetrics private[spark] () extends Serializable {
+ // Each metric is internally represented as an accumulator
+ private val _executorDeserializeTime = new LongAccumulator
+ private val _executorDeserializeCpuTime = new LongAccumulator
+ private val _executorRunTime = new LongAccumulator
+ private val _executorCpuTime = new LongAccumulator
+ private val _resultSize = new LongAccumulator
+ private val _jvmGCTime = new LongAccumulator
+ private val _resultSerializationTime = new LongAccumulator
+ private val _memoryBytesSpilled = new LongAccumulator
+ private val _diskBytesSpilled = new LongAccumulator
+ private val _peakExecutionMemory = new LongAccumulator
+ private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)]
/**
- * Time taken on the executor to deserialize this task
+ * Time taken on the executor to deserialize this task.
*/
- private var _executorDeserializeTime: Long = _
- def executorDeserializeTime: Long = _executorDeserializeTime
- private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value
-
+ def executorDeserializeTime: Long = _executorDeserializeTime.sum
/**
- * Time the executor spends actually running the task (including fetching shuffle data)
+ * CPU Time taken on the executor to deserialize this task in nanoseconds.
*/
- private var _executorRunTime: Long = _
- def executorRunTime: Long = _executorRunTime
- private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value
+ def executorDeserializeCpuTime: Long = _executorDeserializeCpuTime.sum
/**
- * The number of bytes this task transmitted back to the driver as the TaskResult
+ * Time the executor spends actually running the task (including fetching shuffle data).
*/
- private var _resultSize: Long = _
- def resultSize: Long = _resultSize
- private[spark] def setResultSize(value: Long) = _resultSize = value
-
+ def executorRunTime: Long = _executorRunTime.sum
/**
- * Amount of time the JVM spent in garbage collection while executing this task
+ * CPU Time the executor spends actually running the task
+ * (including fetching shuffle data) in nanoseconds.
*/
- private var _jvmGCTime: Long = _
- def jvmGCTime: Long = _jvmGCTime
- private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value
+ def executorCpuTime: Long = _executorCpuTime.sum
/**
- * Amount of time spent serializing the task result
+ * The number of bytes this task transmitted back to the driver as the TaskResult.
*/
- private var _resultSerializationTime: Long = _
- def resultSerializationTime: Long = _resultSerializationTime
- private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value
+ def resultSize: Long = _resultSize.sum
/**
- * The number of in-memory bytes spilled by this task
+ * Amount of time the JVM spent in garbage collection while executing this task.
*/
- private var _memoryBytesSpilled: Long = _
- def memoryBytesSpilled: Long = _memoryBytesSpilled
- private[spark] def incMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled += value
- private[spark] def decMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled -= value
+ def jvmGCTime: Long = _jvmGCTime.sum
/**
- * The number of on-disk bytes spilled by this task
+ * Amount of time spent serializing the task result.
*/
- private var _diskBytesSpilled: Long = _
- def diskBytesSpilled: Long = _diskBytesSpilled
- private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value
- private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value
+ def resultSerializationTime: Long = _resultSerializationTime.sum
/**
- * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read
- * are stored here.
+ * The number of in-memory bytes spilled by this task.
*/
- private var _inputMetrics: Option[InputMetrics] = None
-
- def inputMetrics: Option[InputMetrics] = _inputMetrics
+ def memoryBytesSpilled: Long = _memoryBytesSpilled.sum
/**
- * This should only be used when recreating TaskMetrics, not when updating input metrics in
- * executors
+ * The number of on-disk bytes spilled by this task.
*/
- private[spark] def setInputMetrics(inputMetrics: Option[InputMetrics]) {
- _inputMetrics = inputMetrics
- }
+ def diskBytesSpilled: Long = _diskBytesSpilled.sum
/**
- * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much
- * data was written are stored here.
+ * Peak memory used by internal data structures created during shuffles, aggregations and
+ * joins. The value of this accumulator should be approximately the sum of the peak sizes
+ * across all such data structures created in this task. For SQL jobs, this only tracks all
+ * unsafe operators and ExternalSort.
*/
- var outputMetrics: Option[OutputMetrics] = None
+ def peakExecutionMemory: Long = _peakExecutionMemory.sum
/**
- * If this task reads from shuffle output, metrics on getting shuffle data will be collected here.
- * This includes read metrics aggregated over all the task's shuffle dependencies.
+ * Storage statuses of any blocks that have been updated as a result of this task.
*/
- private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None
+ def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = {
+ // This is called on driver. All accumulator updates have a fixed value. So it's safe to use
+ // `asScala` which accesses the internal values using `java.util.Iterator`.
+ _updatedBlockStatuses.value.asScala
+ }
- def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics
+ // Setters and increment-ers
+ private[spark] def setExecutorDeserializeTime(v: Long): Unit =
+ _executorDeserializeTime.setValue(v)
+ private[spark] def setExecutorDeserializeCpuTime(v: Long): Unit =
+ _executorDeserializeCpuTime.setValue(v)
+ private[spark] def setExecutorRunTime(v: Long): Unit = _executorRunTime.setValue(v)
+ private[spark] def setExecutorCpuTime(v: Long): Unit = _executorCpuTime.setValue(v)
+ private[spark] def setResultSize(v: Long): Unit = _resultSize.setValue(v)
+ private[spark] def setJvmGCTime(v: Long): Unit = _jvmGCTime.setValue(v)
+ private[spark] def setResultSerializationTime(v: Long): Unit =
+ _resultSerializationTime.setValue(v)
+ private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v)
+ private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v)
+ private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v)
+ private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit =
+ _updatedBlockStatuses.add(v)
+ private[spark] def setUpdatedBlockStatuses(v: java.util.List[(BlockId, BlockStatus)]): Unit =
+ _updatedBlockStatuses.setValue(v)
+ private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
+ _updatedBlockStatuses.setValue(v.asJava)
/**
- * This should only be used when recreating TaskMetrics, not when updating read metrics in
- * executors.
+ * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted
+ * data, defined only in tasks with input.
*/
- private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) {
- _shuffleReadMetrics = shuffleReadMetrics
- }
+ val inputMetrics: InputMetrics = new InputMetrics()
/**
- * ShuffleReadMetrics per dependency for collecting independently while task is in progress.
+ * Metrics related to writing data externally (e.g. to a distributed filesystem),
+ * defined only in tasks with output.
*/
- @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] =
- new ArrayBuffer[ShuffleReadMetrics]()
+ val outputMetrics: OutputMetrics = new OutputMetrics()
/**
- * If this task writes to shuffle output, metrics on the written shuffle data will be collected
- * here
+ * Metrics related to shuffle read aggregated across all shuffle dependencies.
+ * This is defined only if there are shuffle dependencies in this task.
*/
- var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
+ val shuffleReadMetrics: ShuffleReadMetrics = new ShuffleReadMetrics()
/**
- * Storage statuses of any blocks that have been updated as a result of this task.
+ * Metrics related to shuffle write, defined only in shuffle map stages.
*/
- var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None
+ val shuffleWriteMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics()
/**
+ * A list of [[TempShuffleReadMetrics]], one per shuffle dependency.
+ *
* A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization
- * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each
- * dependency, and merge these metrics before reporting them to the driver. This method returns
- * a ShuffleReadMetrics for a dependency and registers it for merging later.
+ * issues from readers in different threads, in-progress tasks use a [[TempShuffleReadMetrics]]
+ * for each dependency and merge these metrics before reporting them to the driver.
*/
- private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized {
- val readMetrics = new ShuffleReadMetrics()
- depsShuffleReadMetrics += readMetrics
- readMetrics
- }
+ @transient private lazy val tempShuffleReadMetrics = new ArrayBuffer[TempShuffleReadMetrics]
/**
- * Returns the input metrics object that the task should use. Currently, if
- * there exists an input metric with the same readMethod, we return that one
- * so the caller can accumulate bytes read. If the readMethod is different
- * than previously seen by this task, we return a new InputMetric but don't
- * record it.
+ * Create a [[TempShuffleReadMetrics]] for a particular shuffle dependency.
*
- * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed,
- * we can store all the different inputMetrics (one per readMethod).
+ * All usages are expected to be followed by a call to [[mergeShuffleReadMetrics]], which
+ * merges the temporary values synchronously. Otherwise, all temporary data collected will
+ * be lost.
*/
- private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): InputMetrics = {
- synchronized {
- _inputMetrics match {
- case None =>
- val metrics = new InputMetrics(readMethod)
- _inputMetrics = Some(metrics)
- metrics
- case Some(metrics @ InputMetrics(method)) if method == readMethod =>
- metrics
- case Some(InputMetrics(method)) =>
- new InputMetrics(readMethod)
- }
- }
+ private[spark] def createTempShuffleReadMetrics(): TempShuffleReadMetrics = synchronized {
+ val readMetrics = new TempShuffleReadMetrics
+ tempShuffleReadMetrics += readMetrics
+ readMetrics
}
/**
- * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
+ * Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`.
+ * This is expected to be called on executor heartbeat and at the end of a task.
*/
- private[spark] def updateShuffleReadMetrics(): Unit = synchronized {
- if (!depsShuffleReadMetrics.isEmpty) {
- val merged = new ShuffleReadMetrics()
- for (depMetrics <- depsShuffleReadMetrics) {
- merged.incFetchWaitTime(depMetrics.fetchWaitTime)
- merged.incLocalBlocksFetched(depMetrics.localBlocksFetched)
- merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched)
- merged.incRemoteBytesRead(depMetrics.remoteBytesRead)
- merged.incLocalBytesRead(depMetrics.localBytesRead)
- merged.incRecordsRead(depMetrics.recordsRead)
- }
- _shuffleReadMetrics = Some(merged)
+ private[spark] def mergeShuffleReadMetrics(): Unit = synchronized {
+ if (tempShuffleReadMetrics.nonEmpty) {
+ shuffleReadMetrics.setMergeValues(tempShuffleReadMetrics)
}
}
- private[spark] def updateInputMetrics(): Unit = synchronized {
- inputMetrics.foreach(_.updateBytesRead())
- }
-
- @throws(classOf[IOException])
- private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
- in.defaultReadObject()
- // Get the hostname from cached data, since hostname is the order of number of nodes in
- // cluster, so using cached hostname will decrease the object number and alleviate the GC
- // overhead.
- _hostname = TaskMetrics.getCachedHostName(_hostname)
- }
-
- private var _accumulatorUpdates: Map[Long, Any] = Map.empty
- @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null
-
- private[spark] def updateAccumulators(): Unit = synchronized {
- _accumulatorUpdates = _accumulatorsUpdater()
+ // Only used for testing.
+ private[spark] val testAccum = sys.props.get("spark.testing").map(_ => new LongAccumulator)
+
+
+ import InternalAccumulator._
+ @transient private[spark] lazy val nameToAccums = LinkedHashMap(
+ EXECUTOR_DESERIALIZE_TIME -> _executorDeserializeTime,
+ EXECUTOR_DESERIALIZE_CPU_TIME -> _executorDeserializeCpuTime,
+ EXECUTOR_RUN_TIME -> _executorRunTime,
+ EXECUTOR_CPU_TIME -> _executorCpuTime,
+ RESULT_SIZE -> _resultSize,
+ JVM_GC_TIME -> _jvmGCTime,
+ RESULT_SERIALIZATION_TIME -> _resultSerializationTime,
+ MEMORY_BYTES_SPILLED -> _memoryBytesSpilled,
+ DISK_BYTES_SPILLED -> _diskBytesSpilled,
+ PEAK_EXECUTION_MEMORY -> _peakExecutionMemory,
+ UPDATED_BLOCK_STATUSES -> _updatedBlockStatuses,
+ shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched,
+ shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched,
+ shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead,
+ shuffleRead.LOCAL_BYTES_READ -> shuffleReadMetrics._localBytesRead,
+ shuffleRead.FETCH_WAIT_TIME -> shuffleReadMetrics._fetchWaitTime,
+ shuffleRead.RECORDS_READ -> shuffleReadMetrics._recordsRead,
+ shuffleWrite.BYTES_WRITTEN -> shuffleWriteMetrics._bytesWritten,
+ shuffleWrite.RECORDS_WRITTEN -> shuffleWriteMetrics._recordsWritten,
+ shuffleWrite.WRITE_TIME -> shuffleWriteMetrics._writeTime,
+ input.BYTES_READ -> inputMetrics._bytesRead,
+ input.RECORDS_READ -> inputMetrics._recordsRead,
+ output.BYTES_WRITTEN -> outputMetrics._bytesWritten,
+ output.RECORDS_WRITTEN -> outputMetrics._recordsWritten
+ ) ++ testAccum.map(TEST_ACCUM -> _)
+
+ @transient private[spark] lazy val internalAccums: Seq[AccumulatorV2[_, _]] =
+ nameToAccums.values.toIndexedSeq
+
+ /* ========================== *
+ | OTHER THINGS |
+ * ========================== */
+
+ private[spark] def register(sc: SparkContext): Unit = {
+ nameToAccums.foreach {
+ case (name, acc) => acc.register(sc, name = Some(name), countFailedValues = true)
+ }
}
/**
- * Return the latest updates of accumulators in this task.
+ * External accumulators registered with this task.
*/
- def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates
+ @transient private[spark] lazy val externalAccums = new ArrayBuffer[AccumulatorV2[_, _]]
- private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = {
- _accumulatorsUpdater = accumulatorsUpdater
+ private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {
+ externalAccums += a
}
-}
-
-private[spark] object TaskMetrics {
- private val hostNameCache = new ConcurrentHashMap[String, String]()
- def empty: TaskMetrics = new TaskMetrics
+ private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums
- def getCachedHostName(host: String): String = {
- val canonicalHost = hostNameCache.putIfAbsent(host, host)
- if (canonicalHost != null) canonicalHost else host
+ private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = {
+ // The RESULT_SIZE accumulator is always zero at the executor; we need to send it back as its
+ // value will be updated on the driver side.
+ internalAccums.filter(a => !a.isZero || a == _resultSize)
}
}
-/**
- * :: DeveloperApi ::
- * Method by which input data was read. Network means that the data was read over the network
- * from a remote block manager (which may have stored the data on-disk or in-memory).
- */
-@DeveloperApi
-object DataReadMethod extends Enumeration with Serializable {
- type DataReadMethod = Value
- val Memory, Disk, Hadoop, Network = Value
-}
-
-/**
- * :: DeveloperApi ::
- * Method by which output data was written.
- */
-@DeveloperApi
-object DataWriteMethod extends Enumeration with Serializable {
- type DataWriteMethod = Value
- val Hadoop = Value
-}
-
-/**
- * :: DeveloperApi ::
- * Metrics about reading input data.
- */
-@DeveloperApi
-case class InputMetrics(readMethod: DataReadMethod.Value) {
- /**
- * This is volatile so that it is visible to the updater thread.
- */
- @volatile @transient var bytesReadCallback: Option[() => Long] = None
+private[spark] object TaskMetrics extends Logging {
+ import InternalAccumulator._
/**
- * Total bytes read.
+ * Create an empty task metrics that doesn't register its accumulators.
*/
- private var _bytesRead: Long = _
- def bytesRead: Long = _bytesRead
- def incBytesRead(bytes: Long): Unit = _bytesRead += bytes
-
- /**
- * Total records read.
- */
- private var _recordsRead: Long = _
- def recordsRead: Long = _recordsRead
- def incRecordsRead(records: Long): Unit = _recordsRead += records
-
- /**
- * Invoke the bytesReadCallback and mutate bytesRead.
- */
- def updateBytesRead() {
- bytesReadCallback.foreach { c =>
- _bytesRead = c()
+ def empty: TaskMetrics = {
+ val tm = new TaskMetrics
+ tm.nameToAccums.foreach { case (name, acc) =>
+ acc.metadata = AccumulatorMetadata(AccumulatorContext.newId(), Some(name), true)
}
+ tm
}
- /**
- * Register a function that can be called to get up-to-date information on how many bytes the task
- * has read from an input source.
- */
- def setBytesReadCallback(f: Option[() => Long]) {
- bytesReadCallback = f
+ def registered: TaskMetrics = {
+ val tm = empty
+ tm.internalAccums.foreach(AccumulatorContext.register)
+ tm
}
-}
-
-/**
- * :: DeveloperApi ::
- * Metrics about writing output data.
- */
-@DeveloperApi
-case class OutputMetrics(writeMethod: DataWriteMethod.Value) {
- /**
- * Total bytes written
- */
- private var _bytesWritten: Long = _
- def bytesWritten: Long = _bytesWritten
- private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value
-
- /**
- * Total records written
- */
- private var _recordsWritten: Long = 0L
- def recordsWritten: Long = _recordsWritten
- private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value
-}
-/**
- * :: DeveloperApi ::
- * Metrics pertaining to shuffle data read in a given task.
- */
-@DeveloperApi
-class ShuffleReadMetrics extends Serializable {
/**
- * Number of remote blocks fetched in this shuffle by this task
- */
- private var _remoteBlocksFetched: Int = _
- def remoteBlocksFetched: Int = _remoteBlocksFetched
- private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value
- private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value
-
- /**
- * Number of local blocks fetched in this shuffle by this task
- */
- private var _localBlocksFetched: Int = _
- def localBlocksFetched: Int = _localBlocksFetched
- private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value
- private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value
-
- /**
- * Time the task spent waiting for remote shuffle blocks. This only includes the time
- * blocking on shuffle input data. For instance if block B is being fetched while the task is
- * still not finished processing block A, it is not considered to be blocking on block B.
- */
- private var _fetchWaitTime: Long = _
- def fetchWaitTime: Long = _fetchWaitTime
- private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value
- private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value
-
- /**
- * Total number of remote bytes read from the shuffle by this task
- */
- private var _remoteBytesRead: Long = _
- def remoteBytesRead: Long = _remoteBytesRead
- private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value
- private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value
-
- /**
- * Shuffle data that was read from the local disk (as opposed to from a remote executor).
- */
- private var _localBytesRead: Long = _
- def localBytesRead: Long = _localBytesRead
- private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value
-
- /**
- * Total bytes fetched in the shuffle by this task (both remote and local).
- */
- def totalBytesRead: Long = _remoteBytesRead + _localBytesRead
-
- /**
- * Number of blocks fetched in this shuffle by this task (remote or local)
- */
- def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched
-
- /**
- * Total number of records read from the shuffle by this task
- */
- private var _recordsRead: Long = _
- def recordsRead: Long = _recordsRead
- private[spark] def incRecordsRead(value: Long) = _recordsRead += value
- private[spark] def decRecordsRead(value: Long) = _recordsRead -= value
-}
-
-/**
- * :: DeveloperApi ::
- * Metrics pertaining to shuffle data written in a given task.
- */
-@DeveloperApi
-class ShuffleWriteMetrics extends Serializable {
- /**
- * Number of bytes written for the shuffle by this task
- */
- @volatile private var _shuffleBytesWritten: Long = _
- def shuffleBytesWritten: Long = _shuffleBytesWritten
- private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value
- private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value
-
- /**
- * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
- */
- @volatile private var _shuffleWriteTime: Long = _
- def shuffleWriteTime: Long = _shuffleWriteTime
- private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value
- private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value
+ * Construct a [[TaskMetrics]] object from a list of [[AccumulableInfo]], called on driver only.
+ * The returned [[TaskMetrics]] is only used to get some internal metrics; we don't need to
+ * handle the external accumulator info passed in.
+ */
+ def fromAccumulatorInfos(infos: Seq[AccumulableInfo]): TaskMetrics = {
+ val tm = new TaskMetrics
+ infos.filter(info => info.name.isDefined && info.update.isDefined).foreach { info =>
+ val name = info.name.get
+ val value = info.update.get
+ if (name == UPDATED_BLOCK_STATUSES) {
+ tm.setUpdatedBlockStatuses(value.asInstanceOf[java.util.List[(BlockId, BlockStatus)]])
+ } else {
+ tm.nameToAccums.get(name).foreach(
+ _.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long])
+ )
+ }
+ }
+ tm
+ }
/**
- * Total number of records written to the shuffle by this task
+ * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only.
*/
- @volatile private var _shuffleRecordsWritten: Long = _
- def shuffleRecordsWritten: Long = _shuffleRecordsWritten
- private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value
- private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value
- private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value
+ def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
+ val tm = new TaskMetrics
+ for (acc <- accums) {
+ val name = acc.name
+ if (name.isDefined && tm.nameToAccums.contains(name.get)) {
+ val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]]
+ tmAcc.metadata = acc.metadata
+ tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
+ } else {
+ tm.externalAccums += acc
+ }
+ }
+ tm
+ }
}
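
The rewritten TaskMetrics above is essentially a facade over named accumulators: each getter reads the current value of a LongAccumulator (or the block-status collection accumulator), and fromAccumulators rebuilds a driver-side copy by merging every reported accumulator into the template registered under the same name. Below is a minimal, self-contained sketch of that merge-by-name pattern; the Acc class and the metric names are illustrative stand-ins, not Spark's LongAccumulator or InternalAccumulator keys.

    // Toy accumulator standing in for LongAccumulator; not a Spark class.
    final class Acc(var sum: Long = 0L) {
      def merge(other: Acc): Unit = sum += other.sum
    }

    object MergeByNameExample {
      def main(args: Array[String]): Unit = {
        // Driver-side template: one accumulator per known metric name.
        val nameToAccums = scala.collection.mutable.LinkedHashMap(
          "internal.metrics.executorRunTime" -> new Acc(),
          "internal.metrics.resultSize" -> new Acc())

        // Updates reported back by a finished task, keyed by the same names.
        val taskUpdates = Seq(
          "internal.metrics.executorRunTime" -> new Acc(1234L),
          "internal.metrics.resultSize" -> new Acc(2048L))

        // Merge each update into the accumulator registered under the same name,
        // mirroring what fromAccumulators does for internal metrics.
        for ((name, update) <- taskUpdates; acc <- nameToAccums.get(name)) {
          acc.merge(update)
        }
        nameToAccums.foreach { case (name, acc) => println(s"$name = ${acc.sum}") }
      }
    }
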
diff --git a/core/src/main/scala/org/apache/spark/executor/package-info.java b/core/src/main/scala/org/apache/spark/executor/package-info.java
index dd3b6815fb45..fb280964c490 100644
--- a/core/src/main/scala/org/apache/spark/executor/package-info.java
+++ b/core/src/main/scala/org/apache/spark/executor/package-info.java
@@ -18,4 +18,4 @@
/**
* Package for executor components used with various cluster managers.
*/
-package org.apache.spark.executor;
\ No newline at end of file
+package org.apache.spark.executor;
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
index 532850dd5771..978afaffab30 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -19,11 +19,10 @@ package org.apache.spark.input
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{BytesWritable, LongWritable}
-import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
-import org.apache.spark.Logging
-import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
/**
* Custom Input Format for reading and splitting flat binary files that contain records,
@@ -36,7 +35,7 @@ private[spark] object FixedLengthBinaryInputFormat {
/** Retrieves the record length property from a Hadoop configuration */
def getRecordLength(context: JobContext): Int = {
- SparkHadoopUtil.get.getConfigurationFromJobContext(context).get(RECORD_LENGTH_PROPERTY).toInt
+ context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt
}
}
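
With the SparkHadoopUtil reflection shim gone, getRecordLength reads the record-length property straight off context.getConfiguration. A small sketch of the same lookup against a bare Hadoop Configuration; the property key below is a placeholder, not the real RECORD_LENGTH_PROPERTY constant.

    import org.apache.hadoop.conf.Configuration

    object RecordLengthLookupExample {
      def main(args: Array[String]): Unit = {
        // Placeholder key; the input format uses its own RECORD_LENGTH_PROPERTY constant.
        val recordLengthKey = "example.fixedlength.recordLength"
        val conf = new Configuration(false)
        conf.set(recordLengthKey, "16")
        // Mirrors getRecordLength(context), which now calls context.getConfiguration directly.
        val recordLength = conf.get(recordLengthKey).toInt
        println(recordLength)
      }
    }
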
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
index 67a96925da01..549395314ba6 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -20,11 +20,10 @@ package org.apache.spark.input
import java.io.IOException
import org.apache.hadoop.fs.FSDataInputStream
-import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.FileSplit
-import org.apache.spark.deploy.SparkHadoopUtil
/**
* FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat.
@@ -83,16 +82,16 @@ private[spark] class FixedLengthBinaryRecordReader
// the actual file we will be reading from
val file = fileSplit.getPath
// job configuration
- val job = SparkHadoopUtil.get.getConfigurationFromJobContext(context)
+ val conf = context.getConfiguration
// check compression
- val codec = new CompressionCodecFactory(job).getCodec(file)
+ val codec = new CompressionCodecFactory(conf).getCodec(file)
if (codec != null) {
throw new IOException("FixedLengthRecordReader does not support reading compressed files")
}
// get the record length
recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
// get the filesystem
- val fs = file.getFileSystem(job)
+ val fs = file.getFileSystem(conf)
// open the File
fileInputStream = fs.open(file)
// seek to the splitStart position
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
index 280e7a5fe893..9606c4754314 100644
--- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -21,13 +21,15 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
import scala.collection.JavaConverters._
-import com.google.common.io.{Closeables, ByteStreams}
+import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}
-import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.config
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Since
/**
* A general format for reading whole files in as streams, byte arrays,
@@ -42,10 +44,14 @@ private[spark] abstract class StreamFileInputFormat[T]
* Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API
* which is set through setMaxSplitSize
*/
- def setMinPartitions(context: JobContext, minPartitions: Int) {
+ def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) {
+ val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES)
+ val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES)
+ val defaultParallelism = sc.defaultParallelism
val files = listStatus(context).asScala
- val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum
- val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong
+ val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum
+ val bytesPerCore = totalBytes / defaultParallelism
+ val maxSplitSize = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
super.setMaxSplitSize(maxSplitSize)
}
@@ -135,8 +141,7 @@ class PortableDataStream(
private val confBytes = {
val baos = new ByteArrayOutputStream()
- SparkHadoopUtil.get.getConfigurationFromJobContext(context).
- write(new DataOutputStream(baos))
+ context.getConfiguration.write(new DataOutputStream(baos))
baos.toByteArray
}
@@ -171,6 +176,7 @@ class PortableDataStream(
* Create a new DataInputStream from the split and context. The user of this method is responsible
* for closing the stream after usage.
*/
+ @Since("1.2.0")
def open(): DataInputStream = {
val pathp = split.getPath(index)
val fs = pathp.getFileSystem(conf)
@@ -180,6 +186,7 @@ class PortableDataStream(
/**
* Read the file as a byte array
*/
+ @Since("1.2.0")
def toArray(): Array[Byte] = {
val stream = open()
try {
@@ -189,15 +196,10 @@ class PortableDataStream(
}
}
- /**
- * Closing the PortableDataStream is not needed anymore. The user either can use the
- * PortableDataStream to get a DataInputStream (which the user needs to close after usage),
- * or a byte array.
- */
- @deprecated("Closing the PortableDataStream is not needed anymore.", "1.6.0")
- def close(): Unit = {
- }
-
+ @Since("1.2.0")
def getPath(): String = path
+
+ @Since("2.2.0")
+ def getConfiguration: Configuration = conf
}
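
The new setMinPartitions no longer averages the raw total length over the number of files. Each file is padded by an open cost, the padded total is spread over the default parallelism, and the split size is clamped between the open cost and the configured maximum. A quick stand-alone computation of that formula with made-up numbers; in the real code the inputs come from spark.files.maxPartitionBytes, spark.files.openCostInBytes and sc.defaultParallelism.

    object SplitSizeExample {
      def main(args: Array[String]): Unit = {
        // Illustrative values only.
        val defaultMaxSplitBytes = 128L * 1024 * 1024            // 128 MB
        val openCostInBytes      = 4L * 1024 * 1024              // 4 MB
        val defaultParallelism   = 8
        val fileSizes            = Seq(10L, 200L, 50L).map(_ * 1024 * 1024)

        val totalBytes   = fileSizes.map(_ + openCostInBytes).sum
        val bytesPerCore = totalBytes / defaultParallelism
        val maxSplitSize = math.min(defaultMaxSplitBytes, math.max(openCostInBytes, bytesPerCore))
        println(s"maxSplitSize = $maxSplitSize bytes")           // 35651584 (34 MB) for these inputs
      }
    }
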
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index 1ba34a11414a..fa34f1e886c7 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -20,6 +20,7 @@ package org.apache.spark.input
import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.InputSplit
import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat
@@ -33,14 +34,13 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext
*/
private[spark] class WholeTextFileInputFormat
- extends CombineFileInputFormat[String, String] with Configurable {
+ extends CombineFileInputFormat[Text, Text] with Configurable {
override protected def isSplitable(context: JobContext, file: Path): Boolean = false
override def createRecordReader(
split: InputSplit,
- context: TaskAttemptContext): RecordReader[String, String] = {
-
+ context: TaskAttemptContext): RecordReader[Text, Text] = {
val reader =
new ConfigurableCombineFileRecordReader(split, context, classOf[WholeTextFileRecordReader])
reader.setConf(getConf)
@@ -53,7 +53,7 @@ private[spark] class WholeTextFileInputFormat
*/
def setMinPartitions(context: JobContext, minPartitions: Int) {
val files = listStatus(context).asScala
- val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum
+ val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum
val maxSplitSize = Math.ceil(totalLen * 1.0 /
(if (minPartitions == 0) 1 else minPartitions)).toLong
super.setMaxSplitSize(maxSplitSize)
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
index 31bde8a78f3c..6b7f086678e9 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
@@ -17,17 +17,14 @@
package org.apache.spark.input
-import org.apache.hadoop.conf.{Configuration, Configurable => HConfigurable}
import com.google.common.io.{ByteStreams, Closeables}
-
+import org.apache.hadoop.conf.{Configurable => HConfigurable, Configuration}
import org.apache.hadoop.io.Text
import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.mapreduce.InputSplit
-import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader}
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
-import org.apache.spark.deploy.SparkHadoopUtil
-
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileRecordReader, CombineFileSplit}
/**
* A trait to implement [[org.apache.hadoop.conf.Configurable Configurable]] interface.
@@ -49,17 +46,16 @@ private[spark] class WholeTextFileRecordReader(
split: CombineFileSplit,
context: TaskAttemptContext,
index: Integer)
- extends RecordReader[String, String] with Configurable {
+ extends RecordReader[Text, Text] with Configurable {
private[this] val path = split.getPath(index)
- private[this] val fs = path.getFileSystem(
- SparkHadoopUtil.get.getConfigurationFromJobContext(context))
+ private[this] val fs = path.getFileSystem(context.getConfiguration)
// True means the current file has been processed, then skip it.
private[this] var processed = false
- private[this] val key = path.toString
- private[this] var value: String = null
+ private[this] val key: Text = new Text(path.toString)
+ private[this] var value: Text = null
override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {}
@@ -67,9 +63,9 @@ private[spark] class WholeTextFileRecordReader(
override def getProgress: Float = if (processed) 1.0f else 0.0f
- override def getCurrentKey: String = key
+ override def getCurrentKey: Text = key
- override def getCurrentValue: String = value
+ override def getCurrentValue: Text = value
override def nextKeyValue(): Boolean = {
if (!processed) {
@@ -83,7 +79,7 @@ private[spark] class WholeTextFileRecordReader(
ByteStreams.toByteArray(fileIn)
}
- value = new Text(innerBuffer).toString
+ value = new Text(innerBuffer)
Closeables.close(fileIn, false)
processed = true
true
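
Because the record reader now returns Hadoop Text keys and values instead of Strings, callers that need immutable Strings convert explicitly with toString. A tiny sketch of that conversion; the path and contents are made up.

    import org.apache.hadoop.io.Text

    object TextKeyExample {
      def main(args: Array[String]): Unit = {
        // The reader hands back mutable, reusable Text objects.
        val key   = new Text("hdfs://host/path/part-00000")
        val value = new Text("file contents ...")
        // Consumers that want plain Strings (e.g. for a (path, content) pair) convert explicitly.
        val pair: (String, String) = (key.toString, value.toString)
        println(pair._1)
      }
    }
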
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala
similarity index 75%
rename from core/src/main/scala/org/apache/spark/Logging.scala
rename to core/src/main/scala/org/apache/spark/internal/Logging.scala
index 69f6e06ee005..c7f2847731fc 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala
@@ -15,25 +15,21 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.internal
-import org.apache.log4j.{LogManager, PropertyConfigurator}
+import org.apache.log4j.{Level, LogManager, PropertyConfigurator}
import org.slf4j.{Logger, LoggerFactory}
import org.slf4j.impl.StaticLoggerBinder
-import org.apache.spark.annotation.Private
import org.apache.spark.util.Utils
/**
* Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows
* logging messages at different levels using methods that only evaluate parameters lazily if the
* log level is enabled.
- *
- * NOTE: DO NOT USE this class outside of Spark. It is intended as an internal utility.
- * This will likely be changed or removed in future releases.
*/
-@Private
trait Logging {
+
// Make the log field transient so that objects with Logging can
// be serialized and used on another machine
@transient private var log_ : Logger = null
@@ -47,7 +43,7 @@ trait Logging {
// Method to get or create the logger for this object
protected def log: Logger = {
if (log_ == null) {
- initializeIfNecessary()
+ initializeLogIfNecessary(false)
log_ = LoggerFactory.getLogger(logName)
}
log_
@@ -99,17 +95,17 @@ trait Logging {
log.isTraceEnabled
}
- private def initializeIfNecessary() {
+ protected def initializeLogIfNecessary(isInterpreter: Boolean): Unit = {
if (!Logging.initialized) {
Logging.initLock.synchronized {
if (!Logging.initialized) {
- initializeLogging()
+ initializeLogging(isInterpreter)
}
}
}
}
- private def initializeLogging() {
+ private def initializeLogging(isInterpreter: Boolean): Unit = {
// Don't use a logger in here, as this is itself occurring during initialization of a logger
// If Log4j 1.2 is being used, but is not initialized, load a default properties file
val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr
@@ -119,30 +115,32 @@ trait Logging {
val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass)
if (usingLog4j12) {
val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
+ // scalastyle:off println
if (!log4j12Initialized) {
- // scalastyle:off println
- if (Utils.isInInterpreter) {
- val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties"
- Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match {
- case Some(url) =>
- PropertyConfigurator.configure(url)
- System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps")
- System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")")
- case None =>
- System.err.println(s"Spark was unable to load $replDefaultLogProps")
- }
- } else {
- val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
- case Some(url) =>
- PropertyConfigurator.configure(url)
- System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
- case None =>
- System.err.println(s"Spark was unable to load $defaultLogProps")
- }
+ val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
+ case Some(url) =>
+ PropertyConfigurator.configure(url)
+ System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
+ case None =>
+ System.err.println(s"Spark was unable to load $defaultLogProps")
+ }
+ }
+
+ if (isInterpreter) {
+ // Use the repl's main class to define the default log level when running the shell,
+ // overriding the root logger's config if they're different.
+ val rootLogger = LogManager.getRootLogger()
+ val replLogger = LogManager.getLogger(logName)
+ val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN)
+ if (replLevel != rootLogger.getEffectiveLevel()) {
+ System.err.printf("Setting default log level to \"%s\".\n", replLevel)
+ System.err.println("To adjust logging level use sc.setLogLevel(newLevel). " +
+ "For SparkR, use setLogLevel(newLevel).")
+ rootLogger.setLevel(replLevel)
}
- // scalastyle:on println
}
+ // scalastyle:on println
}
Logging.initialized = true
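
The relocated Logging trait keeps its by-name logging methods, so message interpolation only runs when the corresponding level is enabled. A toy imitation of that pattern, deliberately not using Spark's trait or slf4j.

    // Minimal imitation of the lazy-logging idiom; not org.apache.spark.internal.Logging.
    trait ToyLogging {
      protected def isInfoEnabled: Boolean = true
      protected def logInfo(msg: => String): Unit = {
        if (isInfoEnabled) {
          val name = getClass.getName.stripSuffix("$")
          // `msg` is by-name, so the interpolation only happens when the level is enabled.
          println(s"INFO $name: $msg")
        }
      }
    }

    object ToyLoggingExample extends ToyLogging {
      def main(args: Array[String]): Unit = {
        def expensive(): String = (1 to 5).map(_ * 2).mkString(",")
        logInfo(s"computed ${expensive()}")  // the argument is only evaluated when logged
      }
    }
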
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
new file mode 100644
index 000000000000..e5d60a7ef098
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.config
+
+import java.util.concurrent.TimeUnit
+import java.util.regex.PatternSyntaxException
+
+import scala.util.matching.Regex
+
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
+
+private object ConfigHelpers {
+
+ def toNumber[T](s: String, converter: String => T, key: String, configType: String): T = {
+ try {
+ converter(s)
+ } catch {
+ case _: NumberFormatException =>
+ throw new IllegalArgumentException(s"$key should be $configType, but was $s")
+ }
+ }
+
+ def toBoolean(s: String, key: String): Boolean = {
+ try {
+ s.toBoolean
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new IllegalArgumentException(s"$key should be boolean, but was $s")
+ }
+ }
+
+ def stringToSeq[T](str: String, converter: String => T): Seq[T] = {
+ str.split(",").map(_.trim()).filter(_.nonEmpty).map(converter)
+ }
+
+ def seqToString[T](v: Seq[T], stringConverter: T => String): String = {
+ v.map(stringConverter).mkString(",")
+ }
+
+ def timeFromString(str: String, unit: TimeUnit): Long = JavaUtils.timeStringAs(str, unit)
+
+ def timeToString(v: Long, unit: TimeUnit): String = TimeUnit.MILLISECONDS.convert(v, unit) + "ms"
+
+ def byteFromString(str: String, unit: ByteUnit): Long = {
+ val (input, multiplier) =
+ if (str.length() > 0 && str.charAt(0) == '-') {
+ (str.substring(1), -1)
+ } else {
+ (str, 1)
+ }
+ multiplier * JavaUtils.byteStringAs(input, unit)
+ }
+
+ def byteToString(v: Long, unit: ByteUnit): String = unit.convertTo(v, ByteUnit.BYTE) + "b"
+
+ def regexFromString(str: String, key: String): Regex = {
+ try str.r catch {
+ case e: PatternSyntaxException =>
+ throw new IllegalArgumentException(s"$key should be a regex, but was $str", e)
+ }
+ }
+
+}
+
+/**
+ * A type-safe config builder. Provides methods for transforming the input data (which can be
+ * used, e.g., for validation) and creating the final config entry.
+ *
+ * One of the methods that return a [[ConfigEntry]] must be called to create a config entry that
+ * can be used with [[SparkConf]].
+ */
+private[spark] class TypedConfigBuilder[T](
+ val parent: ConfigBuilder,
+ val converter: String => T,
+ val stringConverter: T => String) {
+
+ import ConfigHelpers._
+
+ def this(parent: ConfigBuilder, converter: String => T) = {
+ this(parent, converter, Option(_).map(_.toString).orNull)
+ }
+
+ /** Apply a transformation to the user-provided values of the config entry. */
+ def transform(fn: T => T): TypedConfigBuilder[T] = {
+ new TypedConfigBuilder(parent, s => fn(converter(s)), stringConverter)
+ }
+
+ /** Checks if the user-provided value for the config matches the validator. */
+ def checkValue(validator: T => Boolean, errorMsg: String): TypedConfigBuilder[T] = {
+ transform { v =>
+ if (!validator(v)) throw new IllegalArgumentException(errorMsg)
+ v
+ }
+ }
+
+ /** Check that user-provided values for the config match a pre-defined set. */
+ def checkValues(validValues: Set[T]): TypedConfigBuilder[T] = {
+ transform { v =>
+ if (!validValues.contains(v)) {
+ throw new IllegalArgumentException(
+ s"The value of ${parent.key} should be one of ${validValues.mkString(", ")}, but was $v")
+ }
+ v
+ }
+ }
+
+ /** Turns the config entry into a sequence of values of the underlying type. */
+ def toSequence: TypedConfigBuilder[Seq[T]] = {
+ new TypedConfigBuilder(parent, stringToSeq(_, converter), seqToString(_, stringConverter))
+ }
+
+ /** Creates a [[ConfigEntry]] that does not have a default value. */
+ def createOptional: OptionalConfigEntry[T] = {
+ val entry = new OptionalConfigEntry[T](parent.key, converter, stringConverter, parent._doc,
+ parent._public)
+ parent._onCreate.foreach(_(entry))
+ entry
+ }
+
+ /** Creates a [[ConfigEntry]] that has a default value. */
+ def createWithDefault(default: T): ConfigEntry[T] = {
+ // Treat "String" as a special case, so that both createWithDefault and createWithDefaultString
+ // behave the same w.r.t. variable expansion of default values.
+ if (default.isInstanceOf[String]) {
+ createWithDefaultString(default.asInstanceOf[String])
+ } else {
+ val transformedDefault = converter(stringConverter(default))
+ val entry = new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter,
+ stringConverter, parent._doc, parent._public)
+ parent._onCreate.foreach(_(entry))
+ entry
+ }
+ }
+
+ /** Creates a [[ConfigEntry]] with a function to determine the default value */
+ def createWithDefaultFunction(defaultFunc: () => T): ConfigEntry[T] = {
+ val entry = new ConfigEntryWithDefaultFunction[T](parent.key, defaultFunc, converter,
+ stringConverter, parent._doc, parent._public)
+ parent._onCreate.foreach(_ (entry))
+ entry
+ }
+
+ /**
+ * Creates a [[ConfigEntry]] that has a default value. The default value is provided as a
+ * [[String]] and must be a valid value for the entry.
+ */
+ def createWithDefaultString(default: String): ConfigEntry[T] = {
+ val entry = new ConfigEntryWithDefaultString[T](parent.key, default, converter, stringConverter,
+ parent._doc, parent._public)
+ parent._onCreate.foreach(_(entry))
+ entry
+ }
+
+}
+
+/**
+ * Basic builder for Spark configurations. Provides methods for creating type-specific builders.
+ *
+ * @see TypedConfigBuilder
+ */
+private[spark] case class ConfigBuilder(key: String) {
+
+ import ConfigHelpers._
+
+ private[config] var _public = true
+ private[config] var _doc = ""
+ private[config] var _onCreate: Option[ConfigEntry[_] => Unit] = None
+
+ def internal(): ConfigBuilder = {
+ _public = false
+ this
+ }
+
+ def doc(s: String): ConfigBuilder = {
+ _doc = s
+ this
+ }
+
+ /**
+ * Registers a callback for when the config entry is finally instantiated. Currently used by
+ * SQLConf to keep track of SQL configuration entries.
+ */
+ def onCreate(callback: ConfigEntry[_] => Unit): ConfigBuilder = {
+ _onCreate = Option(callback)
+ this
+ }
+
+ def intConf: TypedConfigBuilder[Int] = {
+ new TypedConfigBuilder(this, toNumber(_, _.toInt, key, "int"))
+ }
+
+ def longConf: TypedConfigBuilder[Long] = {
+ new TypedConfigBuilder(this, toNumber(_, _.toLong, key, "long"))
+ }
+
+ def doubleConf: TypedConfigBuilder[Double] = {
+ new TypedConfigBuilder(this, toNumber(_, _.toDouble, key, "double"))
+ }
+
+ def booleanConf: TypedConfigBuilder[Boolean] = {
+ new TypedConfigBuilder(this, toBoolean(_, key))
+ }
+
+ def stringConf: TypedConfigBuilder[String] = {
+ new TypedConfigBuilder(this, v => v)
+ }
+
+ def timeConf(unit: TimeUnit): TypedConfigBuilder[Long] = {
+ new TypedConfigBuilder(this, timeFromString(_, unit), timeToString(_, unit))
+ }
+
+ def bytesConf(unit: ByteUnit): TypedConfigBuilder[Long] = {
+ new TypedConfigBuilder(this, byteFromString(_, unit), byteToString(_, unit))
+ }
+
+ def fallbackConf[T](fallback: ConfigEntry[T]): ConfigEntry[T] = {
+ new FallbackConfigEntry(key, _doc, _public, fallback)
+ }
+
+ def regexConf: TypedConfigBuilder[Regex] = {
+ new TypedConfigBuilder(this, regexFromString(_, this.key), _.toString)
+ }
+}
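
Putting the builder together: an entry is declared by naming its key, optionally documenting or hiding it, choosing a typed converter, and finishing with one of the create* methods. A hedged sketch of some hypothetical entries (the keys below are not real Spark configs), placed in the org.apache.spark.internal.config package so the private[spark] builder is visible.

    package org.apache.spark.internal.config

    import java.util.concurrent.TimeUnit

    import org.apache.spark.network.util.ByteUnit

    private[spark] object ExampleConfigs {
      val EXAMPLE_TIMEOUT = ConfigBuilder("spark.example.timeout")
        .doc("How long the example waits before giving up.")
        .timeConf(TimeUnit.MILLISECONDS)
        .createWithDefaultString("30s")

      val EXAMPLE_BUFFER_SIZE = ConfigBuilder("spark.example.buffer.size")
        .bytesConf(ByteUnit.KiB)
        .checkValue(_ > 0, "buffer size must be positive")
        .createWithDefaultString("64k")

      val EXAMPLE_PATHS = ConfigBuilder("spark.example.paths")
        .internal()
        .stringConf
        .toSequence
        .createWithDefault(Nil)
    }
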
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
new file mode 100644
index 000000000000..e86712e84d6a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.config
+
+/**
+ * An entry contains all meta information for a configuration.
+ *
+ * When applying variable substitution to config values, only references starting with "spark." are
+ * considered in the default namespace. For known Spark configuration keys (i.e. those created using
+ * `ConfigBuilder`), references will also consider the default value when it exists.
+ *
+ * Variable expansion is also applied to the default values of config entries that have a default
+ * value declared as a string.
+ *
+ * @param key the key for the configuration
+ * @param valueConverter how to convert a string to the value. It should throw an exception if the
+ * string does not have the required format.
+ * @param stringConverter how to convert a value to a string that the user can use as a valid
+ * string value. It's usually `toString`, but sometimes a custom converter
+ * is necessary. E.g., if T is List[String], `a, b, c` is better than
+ * `List(a, b, c)`.
+ * @param doc the documentation for the configuration
+ * @param isPublic if this configuration is public to the user. If it's `false`, this
+ * configuration is only used internally and we should not expose it to users.
+ * @tparam T the value type
+ */
+private[spark] abstract class ConfigEntry[T] (
+ val key: String,
+ val valueConverter: String => T,
+ val stringConverter: T => String,
+ val doc: String,
+ val isPublic: Boolean) {
+
+ import ConfigEntry._
+
+ registerEntry(this)
+
+ def defaultValueString: String
+
+ def readFrom(reader: ConfigReader): T
+
+ def defaultValue: Option[T] = None
+
+ override def toString: String = {
+ s"ConfigEntry(key=$key, defaultValue=$defaultValueString, doc=$doc, public=$isPublic)"
+ }
+
+}
+
+private class ConfigEntryWithDefault[T] (
+ key: String,
+ _defaultValue: T,
+ valueConverter: String => T,
+ stringConverter: T => String,
+ doc: String,
+ isPublic: Boolean)
+ extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) {
+
+ override def defaultValue: Option[T] = Some(_defaultValue)
+
+ override def defaultValueString: String = stringConverter(_defaultValue)
+
+ def readFrom(reader: ConfigReader): T = {
+ reader.get(key).map(valueConverter).getOrElse(_defaultValue)
+ }
+}
+
+private class ConfigEntryWithDefaultFunction[T] (
+ key: String,
+ _defaultFunction: () => T,
+ valueConverter: String => T,
+ stringConverter: T => String,
+ doc: String,
+ isPublic: Boolean)
+ extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) {
+
+ override def defaultValue: Option[T] = Some(_defaultFunction())
+
+ override def defaultValueString: String = stringConverter(_defaultFunction())
+
+ def readFrom(reader: ConfigReader): T = {
+ reader.get(key).map(valueConverter).getOrElse(_defaultFunction())
+ }
+}
+
+private class ConfigEntryWithDefaultString[T] (
+ key: String,
+ _defaultValue: String,
+ valueConverter: String => T,
+ stringConverter: T => String,
+ doc: String,
+ isPublic: Boolean)
+ extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) {
+
+ override def defaultValue: Option[T] = Some(valueConverter(_defaultValue))
+
+ override def defaultValueString: String = _defaultValue
+
+ def readFrom(reader: ConfigReader): T = {
+ val value = reader.get(key).getOrElse(reader.substitute(_defaultValue))
+ valueConverter(value)
+ }
+
+}
+
+
+/**
+ * A config entry that does not have a default value.
+ */
+private[spark] class OptionalConfigEntry[T](
+ key: String,
+ val rawValueConverter: String => T,
+ val rawStringConverter: T => String,
+ doc: String,
+ isPublic: Boolean)
+ extends ConfigEntry[Option[T]](key, s => Some(rawValueConverter(s)),
+ v => v.map(rawStringConverter).orNull, doc, isPublic) {
+
+ override def defaultValueString: String = "<undefined>"
+
+ override def readFrom(reader: ConfigReader): Option[T] = {
+ reader.get(key).map(rawValueConverter)
+ }
+
+}
+
+/**
+ * A config entry whose default value is defined by another config entry.
+ */
+private class FallbackConfigEntry[T] (
+ key: String,
+ doc: String,
+ isPublic: Boolean,
+ private[config] val fallback: ConfigEntry[T])
+ extends ConfigEntry[T](key, fallback.valueConverter, fallback.stringConverter, doc, isPublic) {
+
+ override def defaultValueString: String = s"<value of ${fallback.key}>"
+
+ override def readFrom(reader: ConfigReader): T = {
+ reader.get(key).map(valueConverter).getOrElse(fallback.readFrom(reader))
+ }
+
+}
+
+private[spark] object ConfigEntry {
+
+ private val knownConfigs = new java.util.concurrent.ConcurrentHashMap[String, ConfigEntry[_]]()
+
+ def registerEntry(entry: ConfigEntry[_]): Unit = {
+ val existing = knownConfigs.putIfAbsent(entry.key, entry)
+ require(existing == null, s"Config entry ${entry.key} already registered!")
+ }
+
+ def findEntry(key: String): ConfigEntry[_] = knownConfigs.get(key)
+
+}
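
readFrom is where the default takes effect: an entry asks the reader for its key and, when nothing is set, falls back to the declared default (with variable expansion applied to string defaults). A minimal sketch with a hypothetical key, using the ConfigReader added later in this patch and assuming the same org.apache.spark.internal.config package.

    package org.apache.spark.internal.config

    import java.util.{HashMap => JHashMap}

    object ReadFromExample {
      def main(args: Array[String]): Unit = {
        // Hypothetical entry; not a real Spark config.
        val maxRetries = ConfigBuilder("spark.example.maxRetries").intConf.createWithDefault(3)
        val reader = new ConfigReader(new JHashMap[String, String]())
        println(maxRetries.readFrom(reader))  // 3: the key is unset, so the default applies
      }
    }
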
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala
new file mode 100644
index 000000000000..97f56a64d600
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.config
+
+import java.util.{Map => JMap}
+
+/**
+ * A source of configuration values.
+ */
+private[spark] trait ConfigProvider {
+
+ def get(key: String): Option[String]
+
+}
+
+private[spark] class EnvProvider extends ConfigProvider {
+
+ override def get(key: String): Option[String] = sys.env.get(key)
+
+}
+
+private[spark] class SystemProvider extends ConfigProvider {
+
+ override def get(key: String): Option[String] = sys.props.get(key)
+
+}
+
+private[spark] class MapProvider(conf: JMap[String, String]) extends ConfigProvider {
+
+ override def get(key: String): Option[String] = Option(conf.get(key))
+
+}
+
+/**
+ * A config provider that only reads Spark config keys, and considers default values for known
+ * configs when fetching configuration values.
+ */
+private[spark] class SparkConfigProvider(conf: JMap[String, String]) extends ConfigProvider {
+
+ import ConfigEntry._
+
+ override def get(key: String): Option[String] = {
+ if (key.startsWith("spark.")) {
+ Option(conf.get(key)).orElse(defaultValueString(key))
+ } else {
+ None
+ }
+ }
+
+ private def defaultValueString(key: String): Option[String] = {
+ findEntry(key) match {
+ case e: ConfigEntryWithDefault[_] => Option(e.defaultValueString)
+ case e: ConfigEntryWithDefaultString[_] => Option(e.defaultValueString)
+ case e: FallbackConfigEntry[_] => get(e.fallback.key)
+ case _ => None
+ }
+ }
+
+}
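
SparkConfigProvider only answers for spark.-prefixed keys, and for keys registered through ConfigBuilder it serves the entry's declared default when the underlying map has no value. A small sketch with a hypothetical entry, again assuming the org.apache.spark.internal.config package so the private classes are accessible.

    package org.apache.spark.internal.config

    import java.util.{HashMap => JHashMap}

    object SparkConfigProviderExample {
      def main(args: Array[String]): Unit = {
        // Hypothetical key, registered only so the provider can find its default.
        val demo = ConfigBuilder("spark.example.demo").stringConf.createWithDefault("fallback")
        val provider = new SparkConfigProvider(new JHashMap[String, String]())
        println(provider.get(demo.key))           // Some(fallback), from the registered default
        println(provider.get("not.a.spark.key"))  // None: non-"spark." keys are ignored
      }
    }
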
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala
new file mode 100644
index 000000000000..c62de9bfd8fc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.config
+
+import java.util.{Map => JMap}
+
+import scala.collection.mutable.HashMap
+import scala.util.matching.Regex
+
+private object ConfigReader {
+
+ private val REF_RE = "\\$\\{(?:(\\w+?):)?(\\S+?)\\}".r
+
+}
+
+/**
+ * A helper class for reading config entries and performing variable substitution.
+ *
+ * If a config value contains variable references of the form "${prefix:variableName}", the
+ * reference will be replaced with the value of the variable depending on the prefix. By default,
+ * the following prefixes are handled:
+ *
+ * - no prefix: use the default config provider
+ * - system: looks for the value in the system properties
+ * - env: looks for the value in the environment
+ *
+ * Different prefixes can be bound to a `ConfigProvider`, which is used to read configuration
+ * values from the data source for the prefix, and both the system and env providers can be
+ * overridden.
+ *
+ * If the reference cannot be resolved, the original string will be retained.
+ *
+ * @param conf The config provider for the default namespace (no prefix).
+ */
+private[spark] class ConfigReader(conf: ConfigProvider) {
+
+ def this(conf: JMap[String, String]) = this(new MapProvider(conf))
+
+ private val bindings = new HashMap[String, ConfigProvider]()
+ bind(null, conf)
+ bindEnv(new EnvProvider())
+ bindSystem(new SystemProvider())
+
+ /**
+ * Binds a prefix to a provider. This method is not thread-safe and should be called
+ * before the instance is used to expand values.
+ */
+ def bind(prefix: String, provider: ConfigProvider): ConfigReader = {
+ bindings(prefix) = provider
+ this
+ }
+
+ def bind(prefix: String, values: JMap[String, String]): ConfigReader = {
+ bind(prefix, new MapProvider(values))
+ }
+
+ def bindEnv(provider: ConfigProvider): ConfigReader = bind("env", provider)
+
+ def bindSystem(provider: ConfigProvider): ConfigReader = bind("system", provider)
+
+ /**
+ * Reads a configuration key from the default provider, and apply variable substitution.
+ */
+ def get(key: String): Option[String] = conf.get(key).map(substitute)
+
+ /**
+ * Perform variable substitution on the given input string.
+ */
+ def substitute(input: String): String = substitute(input, Set())
+
+ private def substitute(input: String, usedRefs: Set[String]): String = {
+ if (input != null) {
+ ConfigReader.REF_RE.replaceAllIn(input, { m =>
+ val prefix = m.group(1)
+ val name = m.group(2)
+ val ref = if (prefix == null) name else s"$prefix:$name"
+ require(!usedRefs.contains(ref), s"Circular reference in $input: $ref")
+
+ val replacement = bindings.get(prefix)
+ .flatMap(_.get(name))
+ .map { v => substitute(v, usedRefs + ref) }
+ .getOrElse(m.matched)
+ Regex.quoteReplacement(replacement)
+ })
+ } else {
+ input
+ }
+ }
+
+}
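
The substitution rules compose: an unprefixed reference resolves against the default provider (and may itself contain further references), env: and system: go to their bound providers, and anything unresolvable is left verbatim. A short sketch with hypothetical keys, again placed in the org.apache.spark.internal.config package.

    package org.apache.spark.internal.config

    import java.util.{HashMap => JHashMap}

    object ConfigReaderExample {
      def main(args: Array[String]): Unit = {
        val conf = new JHashMap[String, String]()
        conf.put("spark.example.dir", "/data")
        conf.put("spark.example.output", "${spark.example.dir}/out-${env:USER}")

        val reader = new ConfigReader(conf)
        // "${spark.example.dir}" resolves from the map; "${env:USER}" from the environment,
        // or stays verbatim if that variable is unset.
        reader.get("spark.example.output").foreach(println)
      }
    }
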
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
new file mode 100644
index 000000000000..f65a9d750c51
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal
+
+import java.util.concurrent.TimeUnit
+
+import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.util.Utils
+
+package object config {
+
+ private[spark] val DRIVER_CLASS_PATH =
+ ConfigBuilder(SparkLauncher.DRIVER_EXTRA_CLASSPATH).stringConf.createOptional
+
+ private[spark] val DRIVER_JAVA_OPTIONS =
+ ConfigBuilder(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS).stringConf.createOptional
+
+ private[spark] val DRIVER_LIBRARY_PATH =
+ ConfigBuilder(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH).stringConf.createOptional
+
+ private[spark] val DRIVER_USER_CLASS_PATH_FIRST =
+ ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.createWithDefault(false)
+
+ private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory")
+ .bytesConf(ByteUnit.MiB)
+ .createWithDefaultString("1g")
+
+ private[spark] val EXECUTOR_CLASS_PATH =
+ ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional
+
+ private[spark] val EXECUTOR_JAVA_OPTIONS =
+ ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.createOptional
+
+ private[spark] val EXECUTOR_LIBRARY_PATH =
+ ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_LIBRARY_PATH).stringConf.createOptional
+
+ private[spark] val EXECUTOR_USER_CLASS_PATH_FIRST =
+ ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.createWithDefault(false)
+
+ private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory")
+ .bytesConf(ByteUnit.MiB)
+ .createWithDefaultString("1g")
+
+ private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal()
+ .booleanConf.createWithDefault(false)
+
+ private[spark] val CPUS_PER_TASK = ConfigBuilder("spark.task.cpus").intConf.createWithDefault(1)
+
+ private[spark] val DYN_ALLOCATION_MIN_EXECUTORS =
+ ConfigBuilder("spark.dynamicAllocation.minExecutors").intConf.createWithDefault(0)
+
+ private[spark] val DYN_ALLOCATION_INITIAL_EXECUTORS =
+ ConfigBuilder("spark.dynamicAllocation.initialExecutors")
+ .fallbackConf(DYN_ALLOCATION_MIN_EXECUTORS)
+
+ private[spark] val DYN_ALLOCATION_MAX_EXECUTORS =
+ ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue)
+
+ private[spark] val SHUFFLE_SERVICE_ENABLED =
+ ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false)
+
+ private[spark] val KEYTAB = ConfigBuilder("spark.yarn.keytab")
+ .doc("Location of user's keytab.")
+ .stringConf.createOptional
+
+ private[spark] val PRINCIPAL = ConfigBuilder("spark.yarn.principal")
+ .doc("Name of the Kerberos principal.")
+ .stringConf.createOptional
+
+ private[spark] val EXECUTOR_INSTANCES = ConfigBuilder("spark.executor.instances")
+ .intConf
+ .createOptional
+
+ private[spark] val PY_FILES = ConfigBuilder("spark.yarn.dist.pyFiles")
+ .internal()
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ private[spark] val MAX_TASK_FAILURES =
+ ConfigBuilder("spark.task.maxFailures")
+ .intConf
+ .createWithDefault(4)
+
+ // Blacklist confs
+ private[spark] val BLACKLIST_ENABLED =
+ ConfigBuilder("spark.blacklist.enabled")
+ .booleanConf
+ .createOptional
+
+ private[spark] val MAX_TASK_ATTEMPTS_PER_EXECUTOR =
+ ConfigBuilder("spark.blacklist.task.maxTaskAttemptsPerExecutor")
+ .intConf
+ .createWithDefault(1)
+
+ private[spark] val MAX_TASK_ATTEMPTS_PER_NODE =
+ ConfigBuilder("spark.blacklist.task.maxTaskAttemptsPerNode")
+ .intConf
+ .createWithDefault(2)
+
+ private[spark] val MAX_FAILURES_PER_EXEC =
+ ConfigBuilder("spark.blacklist.application.maxFailedTasksPerExecutor")
+ .intConf
+ .createWithDefault(2)
+
+ private[spark] val MAX_FAILURES_PER_EXEC_STAGE =
+ ConfigBuilder("spark.blacklist.stage.maxFailedTasksPerExecutor")
+ .intConf
+ .createWithDefault(2)
+
+ private[spark] val MAX_FAILED_EXEC_PER_NODE =
+ ConfigBuilder("spark.blacklist.application.maxFailedExecutorsPerNode")
+ .intConf
+ .createWithDefault(2)
+
+ private[spark] val MAX_FAILED_EXEC_PER_NODE_STAGE =
+ ConfigBuilder("spark.blacklist.stage.maxFailedExecutorsPerNode")
+ .intConf
+ .createWithDefault(2)
+
+ private[spark] val BLACKLIST_TIMEOUT_CONF =
+ ConfigBuilder("spark.blacklist.timeout")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createOptional
+
+ private[spark] val BLACKLIST_KILL_ENABLED =
+ ConfigBuilder("spark.blacklist.killBlacklistedExecutors")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val BLACKLIST_LEGACY_TIMEOUT_CONF =
+ ConfigBuilder("spark.scheduler.executorTaskBlacklistTime")
+ .internal()
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createOptional
+ // End blacklist confs
+
+ private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE =
+ ConfigBuilder("spark.scheduler.listenerbus.eventqueue.size")
+ .intConf
+ .createWithDefault(10000)
+
+ // This property sets the root namespace for metrics reporting
+ private[spark] val METRICS_NAMESPACE = ConfigBuilder("spark.metrics.namespace")
+ .stringConf
+ .createOptional
+
+ private[spark] val PYSPARK_DRIVER_PYTHON = ConfigBuilder("spark.pyspark.driver.python")
+ .stringConf
+ .createOptional
+
+ private[spark] val PYSPARK_PYTHON = ConfigBuilder("spark.pyspark.python")
+ .stringConf
+ .createOptional
+
+ // To limit memory usage, we only track information for a fixed number of tasks
+ private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks")
+ .intConf
+ .createWithDefault(100000)
+
+ // To limit how many applications are shown in the History Server summary ui
+ private[spark] val HISTORY_UI_MAX_APPS =
+ ConfigBuilder("spark.history.ui.maxApplications").intConf.createWithDefault(Integer.MAX_VALUE)
+
+ private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val IO_ENCRYPTION_KEYGEN_ALGORITHM =
+ ConfigBuilder("spark.io.encryption.keygen.algorithm")
+ .stringConf
+ .createWithDefault("HmacSHA1")
+
+ private[spark] val IO_ENCRYPTION_KEY_SIZE_BITS = ConfigBuilder("spark.io.encryption.keySizeBits")
+ .intConf
+ .checkValues(Set(128, 192, 256))
+ .createWithDefault(128)
+
+ private[spark] val IO_CRYPTO_CIPHER_TRANSFORMATION =
+ ConfigBuilder("spark.io.crypto.cipher.transformation")
+ .internal()
+ .stringConf
+ .createWithDefaultString("AES/CTR/NoPadding")
+
+ private[spark] val DRIVER_HOST_ADDRESS = ConfigBuilder("spark.driver.host")
+ .doc("Address of driver endpoints.")
+ .stringConf
+ .createWithDefault(Utils.localHostName())
+
+ private[spark] val DRIVER_BIND_ADDRESS = ConfigBuilder("spark.driver.bindAddress")
+ .doc("Address where to bind network listen sockets on the driver.")
+ .fallbackConf(DRIVER_HOST_ADDRESS)
+
+ private[spark] val BLOCK_MANAGER_PORT = ConfigBuilder("spark.blockManager.port")
+ .doc("Port to use for the block manager when a more specific setting is not provided.")
+ .intConf
+ .createWithDefault(0)
+
+ private[spark] val DRIVER_BLOCK_MANAGER_PORT = ConfigBuilder("spark.driver.blockManager.port")
+ .doc("Port to use for the block manager on the driver.")
+ .fallbackConf(BLOCK_MANAGER_PORT)
+
+ private[spark] val IGNORE_CORRUPT_FILES = ConfigBuilder("spark.files.ignoreCorruptFiles")
+ .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
+ "encountering corrupted or non-existing files and contents that have been read will still " +
+ "be returned.")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext")
+ .stringConf
+ .createOptional
+
+ private[spark] val FILES_MAX_PARTITION_BYTES = ConfigBuilder("spark.files.maxPartitionBytes")
+ .doc("The maximum number of bytes to pack into a single partition when reading files.")
+ .longConf
+ .createWithDefault(128 * 1024 * 1024)
+
+ private[spark] val FILES_OPEN_COST_IN_BYTES = ConfigBuilder("spark.files.openCostInBytes")
+ .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" +
+ " the same time. This is used when putting multiple files into a partition. It's better to" +
+ " over estimate, then the partitions with small files will be faster than partitions with" +
+ " bigger files.")
+ .longConf
+ .createWithDefault(4 * 1024 * 1024)
+
+ private[spark] val SECRET_REDACTION_PATTERN =
+ ConfigBuilder("spark.redaction.regex")
+ .doc("Regex to decide which Spark configuration properties and environment variables in " +
+ "driver and executor environments contain sensitive information. When this regex matches " +
+ "a property key or value, the value is redacted from the environment UI and various logs " +
+ "like YARN and event logs.")
+ .regexConf
+ .createWithDefault("(?i)secret|password".r)
+
+ private[spark] val STRING_REDACTION_PATTERN =
+ ConfigBuilder("spark.redaction.string.regex")
+ .doc("Regex to decide which parts of strings produced by Spark contain sensitive " +
+ "information. When this regex matches a string part, that string part is replaced by a " +
+ "dummy value. This is currently used to redact the output of SQL explain commands.")
+ .regexConf
+ .createOptional
+
+ private[spark] val AUTH_SECRET_BIT_LENGTH =
+ ConfigBuilder("spark.authenticate.secretBitLength")
+ .intConf
+ .createWithDefault(256)
+
+ private[spark] val NETWORK_AUTH_ENABLED =
+ ConfigBuilder("spark.authenticate")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val SASL_ENCRYPTION_ENABLED =
+ ConfigBuilder("spark.authenticate.enableSaslEncryption")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val NETWORK_ENCRYPTION_ENABLED =
+ ConfigBuilder("spark.network.crypto.enabled")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val CHECKPOINT_COMPRESS =
+ ConfigBuilder("spark.checkpoint.compress")
+ .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " +
+ "spark.io.compression.codec.")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD =
+ ConfigBuilder("spark.shuffle.accurateBlockThreshold")
+ .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " +
+ "record the size accurately if it's above this config. This helps to prevent OOM by " +
+ "avoiding underestimating shuffle block size when fetch shuffle blocks.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(100 * 1024 * 1024)
+
+ private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS =
+ ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress")
+ .doc("This configuration limits the number of remote blocks being fetched per reduce task" +
+ " from a given host port. When a large number of blocks are being requested from a given" +
+ " address in a single fetch or simultaneously, this could crash the serving executor or" +
+ " Node Manager. This is especially useful to reduce the load on the Node Manager when" +
+ " external shuffle is enabled. You can mitigate the issue by setting it to a lower value.")
+ .intConf
+ .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.")
+ .createWithDefault(Int.MaxValue)
+
+ private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM =
+ ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem")
+ .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " +
+ "above this threshold. This is to avoid a giant request takes too much memory. We can " +
+ "enable this config by setting a specific value(e.g. 200m). Note that this config can " +
+ "be enabled only when the shuffle shuffle service is newer than Spark-2.2 or the shuffle" +
+ " service is disabled.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(Long.MaxValue)
+}
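
For context, the entries above follow a builder pattern: ConfigBuilder(key) picks a type (intConf, booleanConf, bytesConf, ...), and createWithDefault / createOptional / fallbackConf then produces a typed entry that is later read through SparkConf. The self-contained sketch below mimics that pattern with simplified, hypothetical classes; it is not the real ConfigBuilder/ConfigEntry implementation, which lives elsewhere in this change.

// Simplified, hypothetical stand-ins for ConfigBuilder/ConfigEntry, illustrating the pattern only.
final case class ConfigEntrySketch[T](key: String, default: Option[T], parse: String => T) {
  // Read from a plain key/value map, falling back to the declared default when the key is unset.
  def readFrom(settings: Map[String, String]): Option[T] =
    settings.get(key).map(parse).orElse(default)
}

final class TypedBuilderSketch[T](key: String, parse: String => T) {
  def createWithDefault(default: T): ConfigEntrySketch[T] = ConfigEntrySketch(key, Some(default), parse)
  def createOptional: ConfigEntrySketch[T] = ConfigEntrySketch(key, None, parse)
}

final class ConfigBuilderSketch(key: String) {
  def intConf: TypedBuilderSketch[Int] = new TypedBuilderSketch(key, _.toInt)
  def booleanConf: TypedBuilderSketch[Boolean] = new TypedBuilderSketch(key, _.toBoolean)
  def stringConf: TypedBuilderSketch[String] = new TypedBuilderSketch(key, identity)
}

object ConfigSketchExample {
  // Mirrors MAX_TASK_FAILURES above, but against the sketch classes.
  val MAX_TASK_FAILURES = new ConfigBuilderSketch("spark.task.maxFailures").intConf.createWithDefault(4)

  def main(args: Array[String]): Unit = {
    println(MAX_TASK_FAILURES.readFrom(Map("spark.task.maxFailures" -> "8"))) // Some(8)
    println(MAX_TASK_FAILURES.readFrom(Map.empty))                            // Some(4), the default
  }
}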
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
new file mode 100644
index 000000000000..7efa9416362a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.io
+
+import org.apache.hadoop.fs._
+import org.apache.hadoop.mapreduce._
+
+import org.apache.spark.util.Utils
+
+
+/**
+ * An interface to define how a single Spark job commits its outputs. Three notes:
+ *
+ * 1. Implementations must be serializable, as the committer instance instantiated on the driver
+ * will be used for tasks on executors.
+ * 2. Implementations should have a constructor with either 2 or 3 arguments:
+ * (jobId: String, path: String) or (jobId: String, path: String, isAppend: Boolean).
+ * 3. A committer should not be reused across multiple Spark jobs.
+ *
+ * The proper call sequence is:
+ *
+ * 1. Driver calls setupJob.
+ * 2. As part of each task's execution, executor calls setupTask and then commitTask
+ * (or abortTask if task failed).
+ * 3. When all necessary tasks have completed successfully, the driver calls commitJob. If the
+ * job failed to execute (e.g. too many failed tasks), the driver should call abortJob.
+ */
+abstract class FileCommitProtocol {
+ import FileCommitProtocol._
+
+ /**
+ * Sets up a job. Must be called on the driver before any other methods can be invoked.
+ */
+ def setupJob(jobContext: JobContext): Unit
+
+ /**
+ * Commits a job after the writes succeed. Must be called on the driver.
+ */
+ def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit
+
+ /**
+ * Aborts a job after the writes fail. Must be called on the driver.
+ *
+ * Calling this function is a best-effort attempt, because it is possible that the driver
+ * just crashes (or is killed) before it can call abort.
+ */
+ def abortJob(jobContext: JobContext): Unit
+
+ /**
+ * Sets up a task within a job.
+ * Must be called before any other task related methods can be invoked.
+ */
+ def setupTask(taskContext: TaskAttemptContext): Unit
+
+ /**
+ * Notifies the commit protocol to add a new file, and gets back the full path that should be
+ * used. Must be called on the executors when running tasks.
+ *
+ * Note that the returned temp file may have an arbitrary path. The commit protocol only
+ * promises that the file will be at the location specified by the arguments after job commit.
+ *
+ * A full file path consists of the following parts:
+ * 1. the base path
+ * 2. some sub-directory within the base path, used to specify partitioning
+ * 3. file prefix, usually some unique job id with the task id
+ * 4. bucket id
+ * 5. source specific file extension, e.g. ".snappy.parquet"
+ *
+ * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest
+ * are left to the commit protocol implementation to decide.
+ *
+ * Important: it is the caller's responsibility to add uniquely identifying content to "ext"
+ * if a task is going to write out multiple files to the same dir. The file commit protocol only
+ * guarantees that files written by different tasks will not conflict.
+ */
+ def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String
+
+ /**
+ * Similar to newTaskTempFile(), but allows files to be committed to an absolute output location.
+ * Depending on the implementation, there may be weaker guarantees around adding files this way.
+ *
+ * Important: it is the caller's responsibility to add uniquely identifying content to "ext"
+ * if a task is going to write out multiple files to the same dir. The file commit protocol only
+ * guarantees that files written by different tasks will not conflict.
+ */
+ def newTaskTempFileAbsPath(
+ taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String
+
+ /**
+ * Commits a task after the writes succeed. Must be called on the executors when running tasks.
+ */
+ def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage
+
+ /**
+ * Aborts a task after the writes have failed. Must be called on the executors when running tasks.
+ *
+ * Calling this function is a best-effort attempt, because it is possible that the executor
+ * just crashes (or is killed) before it can call abort.
+ */
+ def abortTask(taskContext: TaskAttemptContext): Unit
+
+ /**
+ * Specifies that a file should be deleted with the commit of this job. The default
+ * implementation deletes the file immediately.
+ */
+ def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = {
+ fs.delete(path, recursive)
+ }
+
+ /**
+ * Called on the driver after a task commits. This can be used to access task commit messages
+ * before the job has finished. These same task commit messages will be passed to commitJob()
+ * if the entire job succeeds.
+ */
+ def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {}
+}
+
+
+object FileCommitProtocol {
+ class TaskCommitMessage(val obj: Any) extends Serializable
+
+ object EmptyTaskCommitMessage extends TaskCommitMessage(null)
+
+ /**
+ * Instantiates a FileCommitProtocol using the given className.
+ */
+ def instantiate(className: String, jobId: String, outputPath: String, isAppend: Boolean)
+ : FileCommitProtocol = {
+ val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]]
+
+ // First try the one with argument (jobId: String, outputPath: String, isAppend: Boolean).
+ // If that doesn't exist, try the one with (jobId: String, outputPath: String).
+ try {
+ val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean])
+ ctor.newInstance(jobId, outputPath, isAppend.asInstanceOf[java.lang.Boolean])
+ } catch {
+ case _: NoSuchMethodException =>
+ val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String])
+ ctor.newInstance(jobId, outputPath)
+ }
+ }
+}
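
The call sequence in the scaladoc above can be summarized with a short driver-side sketch. This is illustrative only: in a real job the per-task calls run on executors and error handling is richer, and runCommitSequence and its parameters are hypothetical names.

import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage

object CommitSequenceSketch {
  def runCommitSequence(
      committer: FileCommitProtocol,
      jobContext: JobContext,
      taskContexts: Seq[TaskAttemptContext]): Unit = {
    committer.setupJob(jobContext)                         // 1. driver
    try {
      val commits: Seq[TaskCommitMessage] = taskContexts.map { tc =>
        committer.setupTask(tc)                            // 2. per task (normally on executors)
        try {
          // ... write output files to paths returned by newTaskTempFile(...) ...
          committer.commitTask(tc)
        } catch {
          case t: Throwable =>
            committer.abortTask(tc)
            throw t
        }
      }
      committer.commitJob(jobContext, commits)             // 3. driver, once all tasks succeeded
    } catch {
      case t: Throwable =>
        committer.abortJob(jobContext)
        throw t
    }
  }
}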
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
new file mode 100644
index 000000000000..bc777eba402c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.io
+
+import java.util.{Date, UUID}
+
+import scala.collection.mutable
+import scala.util.Try
+
+import org.apache.hadoop.conf.Configurable
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
+
+/**
+ * A [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter
+ * (from the newer mapreduce API, not the old mapred API).
+ *
+ * Unlike Hadoop's OutputCommitter, this implementation is serializable.
+ *
+ * @param jobId the job's or stage's id
+ * @param path the job's output path, or null if committer acts as a noop
+ */
+class HadoopMapReduceCommitProtocol(jobId: String, path: String)
+ extends FileCommitProtocol with Serializable with Logging {
+
+ import FileCommitProtocol._
+
+ /** OutputCommitter from Hadoop is not serializable so marking it transient. */
+ @transient private var committer: OutputCommitter = _
+
+ /**
+ * Checks whether there are files to be committed to a valid output location.
+ *
+ * As committing and aborting a job occur on the driver, where `addedAbsPathFiles` is always null,
+ * it is necessary to check whether a valid output path is specified.
+ * [[HadoopMapReduceCommitProtocol#path]] need not be a valid [[org.apache.hadoop.fs.Path]] for
+ * committers not writing to distributed file systems.
+ */
+ private val hasValidPath = Try { new Path(path) }.isSuccess
+
+ /**
+ * Tracks files staged by this task for absolute output paths. These outputs are not managed by
+ * the Hadoop OutputCommitter, so we must move these to their final locations on job commit.
+ *
+ * The mapping is from the temp output path to the final desired output path of the file.
+ */
+ @transient private var addedAbsPathFiles: mutable.Map[String, String] = null
+
+ /**
+ * The staging directory for all files committed with absolute output paths.
+ */
+ private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId)
+
+ protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
+ val format = context.getOutputFormatClass.newInstance()
+ // If OutputFormat is Configurable, we should set conf to it.
+ format match {
+ case c: Configurable => c.setConf(context.getConfiguration)
+ case _ => ()
+ }
+ format.getOutputCommitter(context)
+ }
+
+ override def newTaskTempFile(
+ taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
+ val filename = getFilename(taskContext, ext)
+
+ val stagingDir: String = committer match {
+ // For FileOutputCommitter it has its own staging path called "work path".
+ case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path)
+ case _ => path
+ }
+
+ dir.map { d =>
+ new Path(new Path(stagingDir, d), filename).toString
+ }.getOrElse {
+ new Path(stagingDir, filename).toString
+ }
+ }
+
+ override def newTaskTempFileAbsPath(
+ taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
+ val filename = getFilename(taskContext, ext)
+ val absOutputPath = new Path(absoluteDir, filename).toString
+
+ // Include a UUID here to prevent file collisions for one task writing to different dirs.
+ // In principle we could include hash(absoluteDir) instead but this is simpler.
+ val tmpOutputPath = new Path(
+ absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString
+
+ addedAbsPathFiles(tmpOutputPath) = absOutputPath
+ tmpOutputPath
+ }
+
+ private def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
+ // The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet
+ // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
+ // the file name is fine and won't overflow.
+ val split = taskContext.getTaskAttemptID.getTaskID.getId
+ f"part-$split%05d-$jobId$ext"
+ }
+
+ override def setupJob(jobContext: JobContext): Unit = {
+ // Setup IDs
+ val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0)
+ val taskId = new TaskID(jobId, TaskType.MAP, 0)
+ val taskAttemptId = new TaskAttemptID(taskId, 0)
+
+ // Set up the configuration object
+ jobContext.getConfiguration.set("mapreduce.job.id", jobId.toString)
+ jobContext.getConfiguration.set("mapreduce.task.id", taskAttemptId.getTaskID.toString)
+ jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString)
+ jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true)
+ jobContext.getConfiguration.setInt("mapreduce.task.partition", 0)
+
+ val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId)
+ committer = setupCommitter(taskAttemptContext)
+ committer.setupJob(jobContext)
+ }
+
+ override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
+ committer.commitJob(jobContext)
+ val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]])
+ .foldLeft(Map[String, String]())(_ ++ _)
+ logDebug(s"Committing files staged for absolute locations $filesToMove")
+ if (hasValidPath) {
+ val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
+ for ((src, dst) <- filesToMove) {
+ fs.rename(new Path(src), new Path(dst))
+ }
+ fs.delete(absPathStagingDir, true)
+ }
+ }
+
+ override def abortJob(jobContext: JobContext): Unit = {
+ committer.abortJob(jobContext, JobStatus.State.FAILED)
+ if (hasValidPath) {
+ val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
+ fs.delete(absPathStagingDir, true)
+ }
+ }
+
+ override def setupTask(taskContext: TaskAttemptContext): Unit = {
+ committer = setupCommitter(taskContext)
+ committer.setupTask(taskContext)
+ addedAbsPathFiles = mutable.Map[String, String]()
+ }
+
+ override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
+ val attemptId = taskContext.getTaskAttemptID
+ SparkHadoopMapRedUtil.commitTask(
+ committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId)
+ new TaskCommitMessage(addedAbsPathFiles.toMap)
+ }
+
+ override def abortTask(taskContext: TaskAttemptContext): Unit = {
+ committer.abortTask(taskContext)
+ // best effort cleanup of other staged files
+ for ((src, _) <- addedAbsPathFiles) {
+ val tmp = new Path(src)
+ tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false)
+ }
+ }
+}
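
Because setupCommitter is protected, this class is designed to be subclassed when a different Hadoop OutputCommitter is wanted. A hypothetical subclass might look like the sketch below; the choice of FileOutputCommitter here is only an example, not something this patch prescribes.

import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat}

import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol

// Always commit through a FileOutputCommitter rooted at the job's output path, regardless of
// what the configured OutputFormat would normally return.
class FixedCommitterProtocol(jobId: String, path: String)
  extends HadoopMapReduceCommitProtocol(jobId, path) {

  override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
    new FileOutputCommitter(FileOutputFormat.getOutputPath(context), context)
  }
}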
diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala
new file mode 100644
index 000000000000..dd72f9430366
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.io
+
+import java.text.SimpleDateFormat
+import java.util.{Date, Locale}
+
+import scala.reflect.ClassTag
+import scala.util.DynamicVariable
+
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred.{JobConf, JobID}
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+
+import org.apache.spark.{SparkConf, SparkException, TaskContext}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.executor.OutputMetrics
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+
+/**
+ * A helper object that saves an RDD using a Hadoop OutputFormat
+ * (from the newer mapreduce API, not the old mapred API).
+ */
+private[spark]
+object SparkHadoopMapReduceWriter extends Logging {
+
+ /**
+ * The basic workflow of this command is:
+ * 1. Driver-side setup: prepare the data source and the Hadoop configuration for the write job
+ * to be issued.
+ * 2. Issue a write job consisting of one or more executor-side tasks, each of which writes all
+ * rows within an RDD partition.
+ * 3. If no exception is thrown in a task, commit that task, otherwise abort it; if any
+ * exception is thrown during the task commit, also abort that task.
+ * 4. If all tasks are committed, commit the job, otherwise abort the job; if any exception is
+ * thrown during the job commit, also abort the job.
+ */
+ def write[K, V: ClassTag](
+ rdd: RDD[(K, V)],
+ hadoopConf: Configuration): Unit = {
+ // Extract context and configuration from RDD.
+ val sparkContext = rdd.context
+ val commitJobId = rdd.id
+ val sparkConf = rdd.conf
+ val conf = new SerializableConfiguration(hadoopConf)
+
+ // Set up a job.
+ val jobTrackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date())
+ val jobAttemptId = new TaskAttemptID(jobTrackerId, commitJobId, TaskType.MAP, 0, 0)
+ val jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId)
+ val format = jobContext.getOutputFormatClass
+
+ if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(sparkConf)) {
+ // FileOutputFormat ignores the filesystem parameter
+ val jobFormat = format.newInstance
+ jobFormat.checkOutputSpecs(jobContext)
+ }
+
+ val committer = FileCommitProtocol.instantiate(
+ className = classOf[HadoopMapReduceCommitProtocol].getName,
+ jobId = commitJobId.toString,
+ outputPath = conf.value.get("mapreduce.output.fileoutputformat.outputdir"),
+ isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol]
+ committer.setupJob(jobContext)
+
+ // Try to write all RDD partitions as a Hadoop OutputFormat.
+ try {
+ val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => {
+ // SPARK-24552: Generate a unique "attempt ID" based on the stage and task attempt numbers.
+ // Assumes that there won't be more than Short.MaxValue attempts, at least not concurrently.
+ val attemptId = (context.stageAttemptNumber << 16) | context.attemptNumber
+
+ executeTask(
+ context = context,
+ jobTrackerId = jobTrackerId,
+ commitJobId = commitJobId,
+ sparkPartitionId = context.partitionId,
+ sparkAttemptNumber = attemptId,
+ committer = committer,
+ hadoopConf = conf.value,
+ outputFormat = format.asInstanceOf[Class[OutputFormat[K, V]]],
+ iterator = iter)
+ })
+
+ committer.commitJob(jobContext, ret)
+ logInfo(s"Job ${jobContext.getJobID} committed.")
+ } catch {
+ case cause: Throwable =>
+ logError(s"Aborting job ${jobContext.getJobID}.", cause)
+ committer.abortJob(jobContext)
+ throw new SparkException("Job aborted.", cause)
+ }
+ }
+
+ /** Write an RDD partition out in a single Spark task. */
+ private def executeTask[K, V: ClassTag](
+ context: TaskContext,
+ jobTrackerId: String,
+ commitJobId: Int,
+ sparkPartitionId: Int,
+ sparkAttemptNumber: Int,
+ committer: FileCommitProtocol,
+ hadoopConf: Configuration,
+ outputFormat: Class[_ <: OutputFormat[K, V]],
+ iterator: Iterator[(K, V)]): TaskCommitMessage = {
+ // Set up a task.
+ val attemptId = new TaskAttemptID(jobTrackerId, commitJobId, TaskType.REDUCE,
+ sparkPartitionId, sparkAttemptNumber)
+ val taskContext = new TaskAttemptContextImpl(hadoopConf, attemptId)
+ committer.setupTask(taskContext)
+
+ val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context)
+
+ // Initiate the writer.
+ val taskFormat = outputFormat.newInstance()
+ // If OutputFormat is Configurable, we should set conf to it.
+ taskFormat match {
+ case c: Configurable => c.setConf(hadoopConf)
+ case _ => ()
+ }
+ var writer = taskFormat.getRecordWriter(taskContext)
+ .asInstanceOf[RecordWriter[K, V]]
+ require(writer != null, "Unable to obtain RecordWriter")
+ var recordsWritten = 0L
+
+ // Write all rows in RDD partition.
+ try {
+ val ret = Utils.tryWithSafeFinallyAndFailureCallbacks {
+ // Write rows out, release resource and commit the task.
+ while (iterator.hasNext) {
+ val pair = iterator.next()
+ writer.write(pair._1, pair._2)
+
+ // Update bytes written metric every few records
+ SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten)
+ recordsWritten += 1
+ }
+ if (writer != null) {
+ writer.close(taskContext)
+ writer = null
+ }
+ committer.commitTask(taskContext)
+ }(catchBlock = {
+ // If there is an error, release resource and then abort the task.
+ try {
+ if (writer != null) {
+ writer.close(taskContext)
+ writer = null
+ }
+ } finally {
+ committer.abortTask(taskContext)
+ logError(s"Task ${taskContext.getTaskAttemptID} aborted.")
+ }
+ })
+
+ outputMetrics.setBytesWritten(callback())
+ outputMetrics.setRecordsWritten(recordsWritten)
+
+ ret
+ } catch {
+ case t: Throwable =>
+ throw new SparkException("Task failed while writing rows", t)
+ }
+ }
+}
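
Elsewhere this writer is reached through the pair-RDD save path, but it can be illustrated directly. The sketch below is hedged: the output directory is a placeholder, the configuration keys are the standard Hadoop mapreduce ones read by the code above, and the object is placed in a Spark-internal package only because the writer is private[spark].

package org.apache.spark.internal.io  // assumed placement; SparkHadoopMapReduceWriter is private[spark]

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapreduce.OutputFormat
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

import org.apache.spark.SparkContext

object WriterSketch {
  def saveAsText(sc: SparkContext): Unit = {
    val rdd = sc.parallelize(Seq("a" -> 1, "b" -> 2)).map { case (k, v) =>
      (new Text(k), new IntWritable(v))
    }
    val conf = new Configuration(sc.hadoopConfiguration)
    conf.setClass("mapreduce.job.outputformat.class",
      classOf[TextOutputFormat[Text, IntWritable]], classOf[OutputFormat[_, _]])
    conf.set("mapreduce.output.fileoutputformat.outputdir", "/tmp/writer-sketch") // placeholder
    SparkHadoopMapReduceWriter.write(rdd, conf)
  }
}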
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala
similarity index 78%
rename from core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
rename to core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala
index ac6eaab20d8d..acc9c3857100 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala
@@ -15,17 +15,18 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.internal.io
import java.io.IOException
-import java.text.NumberFormat
-import java.text.SimpleDateFormat
-import java.util.Date
+import java.text.{NumberFormat, SimpleDateFormat}
+import java.util.{Date, Locale}
-import org.apache.hadoop.mapred._
import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.mapreduce.TaskType
+import org.apache.spark.SerializableWritable
+import org.apache.spark.internal.Logging
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.HadoopRDD
import org.apache.spark.util.SerializableJobConf
@@ -37,10 +38,7 @@ import org.apache.spark.util.SerializableJobConf
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
private[spark]
-class SparkHadoopWriter(jobConf: JobConf)
- extends Logging
- with SparkHadoopMapRedUtil
- with Serializable {
+class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable {
private val now = new Date()
private val conf = new SerializableJobConf(jobConf)
@@ -68,12 +66,12 @@ class SparkHadoopWriter(jobConf: JobConf)
def setup(jobid: Int, splitid: Int, attemptid: Int) {
setIDs(jobid, splitid, attemptid)
- HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(now),
+ HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now),
jobid, splitID, attemptID, conf.value)
}
def open() {
- val numfmt = NumberFormat.getInstance()
+ val numfmt = NumberFormat.getInstance(Locale.US)
numfmt.setMinimumIntegerDigits(5)
numfmt.setGroupingUsed(false)
@@ -131,7 +129,7 @@ class SparkHadoopWriter(jobConf: JobConf)
private def getJobContext(): JobContext = {
if (jobContext == null) {
- jobContext = newJobContext(conf.value, jID.value)
+ jobContext = new JobContextImpl(conf.value, jID.value)
}
jobContext
}
@@ -143,34 +141,19 @@ class SparkHadoopWriter(jobConf: JobConf)
taskContext
}
+ protected def newTaskAttemptContext(
+ conf: JobConf,
+ attemptId: TaskAttemptID): TaskAttemptContext = {
+ new TaskAttemptContextImpl(conf, attemptId)
+ }
+
private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
jobID = jobid
splitID = splitid
attemptID = attemptid
- jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid))
+ jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobid))
taID = new SerializableWritable[TaskAttemptID](
- new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID))
- }
-}
-
-private[spark]
-object SparkHadoopWriter {
- def createJobID(time: Date, id: Int): JobID = {
- val formatter = new SimpleDateFormat("yyyyMMddHHmm")
- val jobtrackerID = formatter.format(time)
- new JobID(jobtrackerID, id)
- }
-
- def createPathFromString(path: String, conf: JobConf): Path = {
- if (path == null) {
- throw new IllegalArgumentException("Output path is null")
- }
- val outputPath = new Path(path)
- val fs = outputPath.getFileSystem(conf)
- if (outputPath == null || fs == null) {
- throw new IllegalArgumentException("Incorrectly formatted output path")
- }
- outputPath.makeQualified(fs)
+ new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID))
}
}
diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala
new file mode 100644
index 000000000000..de828a6d6156
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.io
+
+import java.text.SimpleDateFormat
+import java.util.{Date, Locale}
+
+import scala.util.DynamicVariable
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred.{JobConf, JobID}
+
+import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.executor.OutputMetrics
+
+/**
+ * A helper object that provides common utilities used when saving an RDD using a Hadoop
+ * OutputFormat (from both the old mapred API and the new mapreduce API).
+ */
+private[spark]
+object SparkHadoopWriterUtils {
+
+ private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256
+
+ def createJobID(time: Date, id: Int): JobID = {
+ val jobtrackerID = createJobTrackerID(time)
+ new JobID(jobtrackerID, id)
+ }
+
+ def createJobTrackerID(time: Date): String = {
+ new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time)
+ }
+
+ def createPathFromString(path: String, conf: JobConf): Path = {
+ if (path == null) {
+ throw new IllegalArgumentException("Output path is null")
+ }
+ val outputPath = new Path(path)
+ val fs = outputPath.getFileSystem(conf)
+ if (fs == null) {
+ throw new IllegalArgumentException("Incorrectly formatted output path")
+ }
+ outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ }
+
+ // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation
+ // setting can take effect:
+ def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = {
+ val validationDisabled = disableOutputSpecValidation.value
+ val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true)
+ enabledInConf && !validationDisabled
+ }
+
+ // TODO: these don't seem like the right abstractions.
+ // We should abstract the duplicate code in a less awkward way.
+
+ def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, () => Long) = {
+ val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
+ (context.taskMetrics().outputMetrics, bytesWrittenCallback)
+ }
+
+ def maybeUpdateOutputMetrics(
+ outputMetrics: OutputMetrics,
+ callback: () => Long,
+ recordsWritten: Long): Unit = {
+ if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) {
+ outputMetrics.setBytesWritten(callback())
+ outputMetrics.setRecordsWritten(recordsWritten)
+ }
+ }
+
+ /**
+ * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case
+ * basis; see SPARK-4835 for more details.
+ */
+ val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
+}
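
These metrics helpers are meant to be driven from a task's write loop, as executeTask does above. A minimal sketch of that pattern follows; it assumes placement in a Spark-internal package because the OutputMetrics setters are private[spark], and writeAll/writeRecord are illustrative names.

package org.apache.spark.internal.io  // assumed placement; the metric setters are private[spark]

import org.apache.spark.TaskContext

object MetricsLoopSketch {
  def writeAll[T](context: TaskContext, records: Iterator[T])(writeRecord: T => Unit): Unit = {
    val (outputMetrics, bytesWrittenCallback) =
      SparkHadoopWriterUtils.initHadoopOutputMetrics(context)
    var recordsWritten = 0L
    records.foreach { record =>
      writeRecord(record)
      // Refresh the bytes-written metric only every few records (256 above), keeping the
      // filesystem callback off the hot path.
      SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, bytesWrittenCallback, recordsWritten)
      recordsWritten += 1
    }
    // Final update so the metrics reflect the complete output.
    outputMetrics.setBytesWritten(bytesWrittenCallback())
    outputMetrics.setRecordsWritten(recordsWritten)
  }
}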
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index ca74eedf89be..0cb16f0627b7 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -17,10 +17,11 @@
package org.apache.spark.io
-import java.io.{IOException, InputStream, OutputStream}
+import java.io._
+import java.util.Locale
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
-import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream}
+import net.jpountz.lz4.LZ4BlockOutputStream
import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}
import org.apache.spark.SparkConf
@@ -32,9 +33,8 @@ import org.apache.spark.util.Utils
* CompressionCodec allows the customization of choosing different compression implementations
* to be used in block storage.
*
- * Note: The wire protocol for a codec is not guaranteed compatible across versions of Spark.
- * This is intended for use as an internal compression utility within a single
- * Spark application.
+ * @note The wire protocol for a codec is not guaranteed compatible across versions of Spark.
+ * This is intended for use as an internal compression utility within a single Spark application.
*/
@DeveloperApi
trait CompressionCodec {
@@ -49,7 +49,8 @@ private[spark] object CompressionCodec {
private val configKey = "spark.io.compression.codec"
private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = {
- codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec]
+ (codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec]
+ || codec.isInstanceOf[LZ4CompressionCodec])
}
private val shortCompressionCodecNames = Map(
@@ -66,13 +67,13 @@ private[spark] object CompressionCodec {
}
def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
- val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
+ val codecClass =
+ shortCompressionCodecNames.getOrElse(codecName.toLowerCase(Locale.ROOT), codecName)
val codec = try {
val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
} catch {
- case e: ClassNotFoundException => None
- case e: IllegalArgumentException => None
+ case _: ClassNotFoundException | _: IllegalArgumentException => None
}
codec.getOrElse(throw new IllegalArgumentException(s"Codec [$codecName] is not available. " +
s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC"))
@@ -92,20 +93,19 @@ private[spark] object CompressionCodec {
}
}
- val FALLBACK_COMPRESSION_CODEC = "lzf"
- val DEFAULT_COMPRESSION_CODEC = "snappy"
+ val FALLBACK_COMPRESSION_CODEC = "snappy"
+ val DEFAULT_COMPRESSION_CODEC = "lz4"
val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq
}
-
/**
* :: DeveloperApi ::
* LZ4 implementation of [[org.apache.spark.io.CompressionCodec]].
* Block size can be configured by `spark.io.compression.lz4.blockSize`.
*
- * Note: The wire protocol for this codec is not guaranteed to be compatible across versions
- * of Spark. This is intended for use as an internal compression utility within a single Spark
- * application.
+ * @note The wire protocol for this codec is not guaranteed to be compatible across versions
+ * of Spark. This is intended for use as an internal compression utility within a single Spark
+ * application.
*/
@DeveloperApi
class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec {
@@ -123,9 +123,9 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec {
* :: DeveloperApi ::
* LZF implementation of [[org.apache.spark.io.CompressionCodec]].
*
- * Note: The wire protocol for this codec is not guaranteed to be compatible across versions
- * of Spark. This is intended for use as an internal compression utility within a single Spark
- * application.
+ * @note The wire protocol for this codec is not guaranteed to be compatible across versions
+ * of Spark. This is intended for use as an internal compression utility within a single Spark
+ * application.
*/
@DeveloperApi
class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
@@ -143,18 +143,13 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
* Snappy implementation of [[org.apache.spark.io.CompressionCodec]].
* Block size can be configured by `spark.io.compression.snappy.blockSize`.
*
- * Note: The wire protocol for this codec is not guaranteed to be compatible across versions
- * of Spark. This is intended for use as an internal compression utility within a single Spark
- * application.
+ * @note The wire protocol for this codec is not guaranteed to be compatible across versions
+ * of Spark. This is intended for use as an internal compression utility within a single Spark
+ * application.
*/
@DeveloperApi
class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec {
-
- try {
- Snappy.getNativeLibraryVersion
- } catch {
- case e: Error => throw new IllegalArgumentException(e)
- }
+ val version = SnappyCompressionCodec.version
override def compressedOutputStream(s: OutputStream): OutputStream = {
val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt
@@ -165,7 +160,20 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec {
}
/**
- * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close
+ * This object guards against a memory leak bug in the snappy-java library
+ * (https://github.com/xerial/snappy-java/issues/131).
+ * Until a new version of the library fixes this, we call the method only once and cache the result.
+ */
+private final object SnappyCompressionCodec {
+ private lazy val version: String = try {
+ Snappy.getNativeLibraryVersion
+ } catch {
+ case e: Error => throw new IllegalArgumentException(e)
+ }
+}
+
+/**
+ * Wrapper over `SnappyOutputStream` which guards against write-after-close and double-close
* issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version
* of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107.
*/
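
With this change lz4 becomes the default codec, snappy the fallback, and LZ4 streams are treated as safely concatenable. A small round-trip sketch against the public codec classes (the payload is arbitrary):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkConf
import org.apache.spark.io.LZ4CompressionCodec

object CodecRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val codec = new LZ4CompressionCodec(new SparkConf()) // lz4 is now the default codec
    val buffer = new ByteArrayOutputStream()
    val out = codec.compressedOutputStream(buffer)
    out.write("hello, codec".getBytes(StandardCharsets.UTF_8))
    out.close()

    val in = codec.compressedInputStream(new ByteArrayInputStream(buffer.toByteArray))
    val roundTripped = scala.io.Source.fromInputStream(in, "UTF-8").mkString
    println(roundTripped) // "hello, codec"
  }
}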
diff --git a/core/src/main/scala/org/apache/spark/io/package-info.java b/core/src/main/scala/org/apache/spark/io/package-info.java
index bea1bfdb6375..1a466602806e 100644
--- a/core/src/main/scala/org/apache/spark/io/package-info.java
+++ b/core/src/main/scala/org/apache/spark/io/package-info.java
@@ -18,4 +18,4 @@
/**
* IO codecs used for compression.
*/
-package org.apache.spark.io;
\ No newline at end of file
+package org.apache.spark.io;
diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala
index 3ea984c501e0..a5d41a1eeb47 100644
--- a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala
+++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala
@@ -21,7 +21,7 @@ import java.net.{InetAddress, Socket}
import org.apache.spark.SPARK_VERSION
import org.apache.spark.launcher.LauncherProtocol._
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.{ThreadUtils, Utils}
/**
* A class that can be used to talk to a launcher server. Users should extend this class to
@@ -88,12 +88,20 @@ private[spark] abstract class LauncherBackend {
*/
protected def onDisconnected() : Unit = { }
+ private def fireStopRequest(): Unit = {
+ val thread = LauncherBackend.threadFactory.newThread(new Runnable() {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ onStopRequest()
+ }
+ })
+ thread.start()
+ }
private class BackendConnection(s: Socket) extends LauncherConnection(s) {
override protected def handle(m: Message): Unit = m match {
case _: Stop =>
- onStopRequest()
+ fireStopRequest()
case _ =>
throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}")
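
The fireStopRequest change moves the potentially slow (or throwing) onStopRequest callback off the connection's handler thread. Below is a generic sketch of that pattern with illustrative names only; the real code uses LauncherBackend.threadFactory and Utils.tryLogNonFatalError.

import java.util.concurrent.ThreadFactory

object StopDispatchSketch {
  private val threadFactory: ThreadFactory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "launcher-stop-handler")
      t.setDaemon(true)
      t
    }
  }

  // Returns immediately; the callback runs (and may fail) on its own daemon thread.
  def fireStop(onStopRequest: () => Unit): Unit = {
    threadFactory.newThread(new Runnable {
      override def run(): Unit =
        try onStopRequest()
        catch { case e: Exception => System.err.println(s"Stop handler failed: $e") }
    }).start()
  }
}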
diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala
index a2add6161728..4216b2627309 100644
--- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala
@@ -37,11 +37,8 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm
override def buildCommand(env: JMap[String, String]): JList[String] = {
val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator))
- cmd.add(s"-Xms${memoryMb}M")
cmd.add(s"-Xmx${memoryMb}M")
command.javaOpts.foreach(cmd.add)
- CommandBuilderUtils.addPermGenSizeOpt(cmd)
- addOptionString(cmd, getenv("SPARK_JAVA_OPTS"))
cmd
}
diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index f7298e8d5c62..db8aff94ea1e 100644
--- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -18,61 +18,13 @@
package org.apache.spark.mapred
import java.io.IOException
-import java.lang.reflect.Modifier
-import org.apache.hadoop.mapred._
import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext}
import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}
-import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.executor.CommitDeniedException
-import org.apache.spark.{Logging, SparkEnv, TaskContext}
-import org.apache.spark.util.{Utils => SparkUtils}
-
-private[spark]
-trait SparkHadoopMapRedUtil {
- def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
- "org.apache.hadoop.mapred.JobContext")
- val ctor = klass.getDeclaredConstructor(classOf[JobConf],
- classOf[org.apache.hadoop.mapreduce.JobID])
- // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private.
- // Make it accessible if it's not in order to access it.
- if (!Modifier.isPublic(ctor.getModifiers)) {
- ctor.setAccessible(true)
- }
- ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
- }
-
- def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
- "org.apache.hadoop.mapred.TaskAttemptContext")
- val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
- // See above
- if (!Modifier.isPublic(ctor.getModifiers)) {
- ctor.setAccessible(true)
- }
- ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
- }
-
- def newTaskAttemptID(
- jtIdentifier: String,
- jobId: Int,
- isMap: Boolean,
- taskId: Int,
- attemptId: Int): TaskAttemptID = {
- new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
- }
-
- private def firstAvailableClass(first: String, second: String): Class[_] = {
- try {
- SparkUtils.classForName(first)
- } catch {
- case e: ClassNotFoundException =>
- SparkUtils.classForName(second)
- }
- }
-}
+import org.apache.spark.internal.Logging
object SparkHadoopMapRedUtil extends Logging {
/**
@@ -81,11 +33,8 @@ object SparkHadoopMapRedUtil extends Logging {
* the driver in order to determine whether this attempt can commit (please see SPARK-4879 for
* details).
*
- * Output commit coordinator is only contacted when the following two configurations are both set
- * to `true`:
- *
- * - `spark.speculation`
- * - `spark.hadoop.outputCommitCoordination.enabled`
+ * Output commit coordinator is only used when `spark.hadoop.outputCommitCoordination.enabled`
+ * is set to true (which is the default).
*/
def commitTask(
committer: MapReduceOutputCommitter,
@@ -93,7 +42,7 @@ object SparkHadoopMapRedUtil extends Logging {
jobId: Int,
splitId: Int): Unit = {
- val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext)
+ val mrTaskAttemptID = mrTaskContext.getTaskAttemptID
// Called after we have decided to commit
def performCommit(): Unit = {
@@ -112,17 +61,17 @@ object SparkHadoopMapRedUtil extends Logging {
if (committer.needsTaskCommit(mrTaskContext)) {
val shouldCoordinateWithDriver: Boolean = {
val sparkConf = SparkEnv.get.conf
- // We only need to coordinate with the driver if there are multiple concurrent task
- // attempts, which should only occur if speculation is enabled
- val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false)
- // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs
- sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled)
+ // We only need to coordinate with the driver if there are concurrent task attempts.
+ // Note that this could happen even when speculation is not enabled (e.g. see SPARK-8029).
+ // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs.
+ sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", defaultValue = true)
}
if (shouldCoordinateWithDriver) {
val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
- val taskAttemptNumber = TaskContext.get().attemptNumber()
- val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber)
+ val ctx = TaskContext.get()
+ val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(),
+ splitId, ctx.attemptNumber())
if (canCommit) {
performCommit()
@@ -132,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging {
logInfo(message)
// We need to abort the task so that the driver can reschedule new attempts, if necessary
committer.abortTask(mrTaskContext)
- throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber)
+ throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.attemptNumber())
}
} else {
// Speculation is disabled or a user has chosen to manually bypass the commit coordination
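
A hedged usage sketch of the coordinated commit path, mirroring how HadoopMapReduceCommitProtocol.commitTask above calls it; it assumes it runs inside an active Spark task so that SparkEnv and TaskContext are available, and finishTask is an illustrative name.

import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext}

import org.apache.spark.mapred.SparkHadoopMapRedUtil

object CommitTaskSketch {
  def finishTask(committer: OutputCommitter, taskContext: TaskAttemptContext): Unit = {
    val attemptId = taskContext.getTaskAttemptID
    // Asks the driver's OutputCommitCoordinator for permission when
    // spark.hadoop.outputCommitCoordination.enabled is true (the default), then commits;
    // a CommitDeniedException is thrown if another attempt won the commit.
    SparkHadoopMapRedUtil.commitTask(
      committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId)
  }
}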
diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
deleted file mode 100644
index 943ebcb7bd0a..000000000000
--- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mapreduce
-
-import java.lang.{Boolean => JBoolean, Integer => JInteger}
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID}
-import org.apache.spark.util.Utils
-
-private[spark]
-trait SparkHadoopMapReduceUtil {
- def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
- val klass = firstAvailableClass(
- "org.apache.hadoop.mapreduce.task.JobContextImpl", // hadoop2, hadoop2-yarn
- "org.apache.hadoop.mapreduce.JobContext") // hadoop1
- val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[JobID])
- ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
- }
-
- def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = {
- val klass = firstAvailableClass(
- "org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl", // hadoop2, hadoop2-yarn
- "org.apache.hadoop.mapreduce.TaskAttemptContext") // hadoop1
- val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[TaskAttemptID])
- ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
- }
-
- def newTaskAttemptID(
- jtIdentifier: String,
- jobId: Int,
- isMap: Boolean,
- taskId: Int,
- attemptId: Int): TaskAttemptID = {
- val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID")
- try {
- // First, attempt to use the old-style constructor that takes a boolean isMap
- // (not available in YARN)
- val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
- classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
- new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
- } catch {
- case exc: NoSuchMethodException => {
- // If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
- val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType")
- .asInstanceOf[Class[Enum[_]]]
- val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
- taskTypeClass, if (isMap) "MAP" else "REDUCE")
- val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
- classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
- new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
- }
- }
- }
-
- private def firstAvailableClass(first: String, second: String): Class[_] = {
- try {
- Utils.classForName(first)
- } catch {
- case e: ClassNotFoundException =>
- Utils.classForName(second)
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala
new file mode 100644
index 000000000000..f1915857ea43
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+
+/**
+ * Implements policies and bookkeeping for sharing an adjustable-sized pool of memory between tasks.
+ *
+ * Tries to ensure that each task gets a reasonable share of memory, instead of some task ramping up
+ * to a large amount first and then causing others to spill to disk repeatedly.
+ *
+ * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
+ * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
+ * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever this
+ * set changes. This is all done by synchronizing access to mutable state and using wait() and
+ * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across
+ * tasks was performed by the ShuffleMemoryManager.
+ *
+ * @param lock a [[MemoryManager]] instance to synchronize on
+ * @param memoryMode the type of memory tracked by this pool (on- or off-heap)
+ */
+private[memory] class ExecutionMemoryPool(
+ lock: Object,
+ memoryMode: MemoryMode
+ ) extends MemoryPool(lock) with Logging {
+
+ private[this] val poolName: String = memoryMode match {
+ case MemoryMode.ON_HEAP => "on-heap execution"
+ case MemoryMode.OFF_HEAP => "off-heap execution"
+ }
+
+ /**
+ * Map from taskAttemptId -> memory consumption in bytes
+ */
+ @GuardedBy("lock")
+ private val memoryForTask = new mutable.HashMap[Long, Long]()
+
+ override def memoryUsed: Long = lock.synchronized {
+ memoryForTask.values.sum
+ }
+
+ /**
+ * Returns the memory consumption, in bytes, for the given task.
+ */
+ def getMemoryUsageForTask(taskAttemptId: Long): Long = lock.synchronized {
+ memoryForTask.getOrElse(taskAttemptId, 0L)
+ }
+
+ /**
+ * Try to acquire up to `numBytes` of memory for the given task and return the number of bytes
+ * obtained, or 0 if none can be allocated.
+ *
+ * This call may block until there is enough free memory in some situations, to make sure each
+ * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of
+ * active tasks) before it is forced to spill. This can happen if the number of tasks increases
+ * but an older task had a lot of memory already.
+ *
+ * @param numBytes number of bytes to acquire
+ * @param taskAttemptId the task attempt acquiring memory
+ * @param maybeGrowPool a callback that potentially grows the size of this pool. It takes in
+ * one parameter (Long) that represents the desired amount of memory by
+ * which this pool should be expanded.
+ * @param computeMaxPoolSize a callback that returns the maximum allowable size of this pool
+ * at this given moment. This is not a field because the max pool
+ * size is variable in certain cases. For instance, in unified
+ * memory management, the execution pool can be expanded by evicting
+ * cached blocks, thereby shrinking the storage pool.
+ *
+ * @return the number of bytes granted to the task.
+ */
+ private[memory] def acquireMemory(
+ numBytes: Long,
+ taskAttemptId: Long,
+ maybeGrowPool: Long => Unit = (additionalSpaceNeeded: Long) => Unit,
+ computeMaxPoolSize: () => Long = () => poolSize): Long = lock.synchronized {
+ assert(numBytes > 0, s"invalid number of bytes requested: $numBytes")
+
+ // TODO: clean up this clunky method signature
+
+ // Add this task to the taskMemory map just so we can keep an accurate count of the number
+ // of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory`
+ if (!memoryForTask.contains(taskAttemptId)) {
+ memoryForTask(taskAttemptId) = 0L
+ // This will later cause waiting tasks to wake up and check numTasks again
+ lock.notifyAll()
+ }
+
+ // Keep looping until we're either sure that we don't want to grant this request (because this
+ // task would have more than 1 / numActiveTasks of the memory) or we have enough free
+ // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
+ // TODO: simplify this to limit each task to its own slot
+ while (true) {
+ val numActiveTasks = memoryForTask.keys.size
+ val curMem = memoryForTask(taskAttemptId)
+
+ // In every iteration of this loop, we should first try to reclaim any borrowed execution
+ // space from storage. This is necessary because of the potential race condition where new
+ // storage blocks may steal the free execution memory that this task was waiting for.
+ maybeGrowPool(numBytes - memoryFree)
+
+ // Maximum size the pool would have after potentially growing the pool.
+ // This is used to compute the upper bound of how much memory each task can occupy. This
+ // must take into account potential free memory as well as the amount this pool currently
+ // occupies. Otherwise, we may run into SPARK-12155 where, in unified memory management,
+ // we did not take into account space that could have been freed by evicting cached blocks.
+ val maxPoolSize = computeMaxPoolSize()
+ val maxMemoryPerTask = maxPoolSize / numActiveTasks
+ val minMemoryPerTask = poolSize / (2 * numActiveTasks)
+
+ // How much we can grant this task; keep its share within 0 <= X <= 1 / numActiveTasks
+ val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem))
+ // Only give it as much memory as is free, which might be none if it reached 1 / numTasks
+ val toGrant = math.min(maxToGrant, memoryFree)
+
+ // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
+ // if we can't give it this much now, wait for other tasks to free up memory
+ // (this happens if older tasks allocated lots of memory before N grew)
+ if (toGrant < numBytes && curMem + toGrant < minMemoryPerTask) {
+ logInfo(s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free")
+ lock.wait()
+ } else {
+ memoryForTask(taskAttemptId) += toGrant
+ return toGrant
+ }
+ }
+ 0L // Never reached
+ }
+
+ /**
+ * Release `numBytes` of memory acquired by the given task.
+ */
+ def releaseMemory(numBytes: Long, taskAttemptId: Long): Unit = lock.synchronized {
+ val curMem = memoryForTask.getOrElse(taskAttemptId, 0L)
+ var memoryToFree = if (curMem < numBytes) {
+ logWarning(
+ s"Internal error: release called on $numBytes bytes but task only has $curMem bytes " +
+ s"of memory from the $poolName pool")
+ curMem
+ } else {
+ numBytes
+ }
+ if (memoryForTask.contains(taskAttemptId)) {
+ memoryForTask(taskAttemptId) -= memoryToFree
+ if (memoryForTask(taskAttemptId) <= 0) {
+ memoryForTask.remove(taskAttemptId)
+ }
+ }
+ lock.notifyAll() // Notify waiters in acquireMemory() that memory has been freed
+ }
+
+ /**
+ * Release all memory for the given task and mark it as inactive (e.g. when a task ends).
+ * @return the number of bytes freed.
+ */
+ def releaseAllMemoryForTask(taskAttemptId: Long): Long = lock.synchronized {
+ val numBytesToFree = getMemoryUsageForTask(taskAttemptId)
+ releaseMemory(numBytesToFree, taskAttemptId)
+ numBytesToFree
+ }
+
+}
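
To make the arbitration above concrete, here is a stand-alone sketch of one iteration of the grant computation in acquireMemory. The names (FairShareSketch, grantForOneIteration) are hypothetical; the real method loops, waits on the lock, and retries instead of returning a Left.

    // Sketch of one iteration of ExecutionMemoryPool.acquireMemory's grant logic.
    // All names are illustrative; the real method also waits and retries under the lock.
    object FairShareSketch {
      def grantForOneIteration(
          numBytes: Long,        // bytes requested by the task
          curMem: Long,          // bytes the task already holds
          memoryFree: Long,      // free bytes in the pool right now
          poolSize: Long,        // current pool size
          maxPoolSize: Long,     // pool size after potentially evicting storage
          numActiveTasks: Int): Either[String, Long] = {
        val maxMemoryPerTask = maxPoolSize / numActiveTasks
        val minMemoryPerTask = poolSize / (2 * numActiveTasks)
        // Cap the task's total share at 1 / numActiveTasks of the (grown) pool.
        val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem))
        // Never hand out more than is actually free.
        val toGrant = math.min(maxToGrant, memoryFree)
        if (toGrant < numBytes && curMem + toGrant < minMemoryPerTask) {
          Left("would block: task cannot yet reach 1 / 2N of the pool")
        } else {
          Right(toGrant)
        }
      }

      def main(args: Array[String]): Unit = {
        // 4 tasks sharing a 400 MB pool; this task holds nothing yet and asks for 200 MB.
        val mb = 1024L * 1024L
        println(grantForOneIteration(200 * mb, 0L, 150 * mb, 400 * mb, 400 * mb, 4))
        // => Right(104857600): capped at 400 MB / 4 = 100 MB, not the full request.
      }
    }
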
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
index b0cf2696a397..82442cf56154 100644
--- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
@@ -19,14 +19,11 @@ package org.apache.spark.memory
import javax.annotation.concurrent.GuardedBy
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
-import com.google.common.annotations.VisibleForTesting
-
-import org.apache.spark.util.Utils
-import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging}
-import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.memory.MemoryStore
+import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.memory.MemoryAllocator
@@ -36,65 +33,62 @@ import org.apache.spark.unsafe.memory.MemoryAllocator
* In this context, execution memory refers to that used for computation in shuffles, joins,
* sorts and aggregations, while storage memory refers to that used for caching and propagating
* internal data across the cluster. There exists one MemoryManager per JVM.
- *
- * The MemoryManager abstract base class itself implements policies for sharing execution memory
- * between tasks; it tries to ensure that each task gets a reasonable share of memory, instead of
- * some task ramping up to a large amount first and then causing others to spill to disk repeatedly.
- * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
- * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
- * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
- * this set changes. This is all done by synchronizing access to mutable state and using wait() and
- * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across
- * tasks was performed by the ShuffleMemoryManager.
*/
-private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) extends Logging {
+private[spark] abstract class MemoryManager(
+ conf: SparkConf,
+ numCores: Int,
+ onHeapStorageMemory: Long,
+ onHeapExecutionMemory: Long) extends Logging {
// -- Methods related to memory allocation policies and bookkeeping ------------------------------
- // The memory store used to evict cached blocks
- private var _memoryStore: MemoryStore = _
- protected def memoryStore: MemoryStore = {
- if (_memoryStore == null) {
- throw new IllegalArgumentException("memory store not initialized yet")
- }
- _memoryStore
- }
+ @GuardedBy("this")
+ protected val onHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.ON_HEAP)
+ @GuardedBy("this")
+ protected val offHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.OFF_HEAP)
+ @GuardedBy("this")
+ protected val onHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.ON_HEAP)
+ @GuardedBy("this")
+ protected val offHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.OFF_HEAP)
+
+ onHeapStorageMemoryPool.incrementPoolSize(onHeapStorageMemory)
+ onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory)
- // Amount of execution/storage memory in use, accesses must be synchronized on `this`
- @GuardedBy("this") protected var _executionMemoryUsed: Long = 0
- @GuardedBy("this") protected var _storageMemoryUsed: Long = 0
- // Map from taskAttemptId -> memory consumption in bytes
- @GuardedBy("this") private val executionMemoryForTask = new mutable.HashMap[Long, Long]()
+ protected[this] val maxOffHeapMemory = conf.getSizeAsBytes("spark.memory.offHeap.size", 0)
+ protected[this] val offHeapStorageMemory =
+ (maxOffHeapMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong
+
+ offHeapExecutionMemoryPool.incrementPoolSize(maxOffHeapMemory - offHeapStorageMemory)
+ offHeapStorageMemoryPool.incrementPoolSize(offHeapStorageMemory)
/**
- * Set the [[MemoryStore]] used by this manager to evict cached blocks.
- * This must be set after construction due to initialization ordering constraints.
+ * Total available on-heap memory for storage, in bytes. This amount can vary over time,
+ * depending on the MemoryManager implementation.
+ * In this model, this is equivalent to the amount of memory not occupied by execution.
*/
- final def setMemoryStore(store: MemoryStore): Unit = {
- _memoryStore = store
- }
+ def maxOnHeapStorageMemory: Long
/**
- * Total available memory for execution, in bytes.
+ * Total available off-heap memory for storage, in bytes. This amount can vary over time,
+ * depending on the MemoryManager implementation.
*/
- def maxExecutionMemory: Long
+ def maxOffHeapStorageMemory: Long
/**
- * Total available memory for storage, in bytes.
+ * Set the [[MemoryStore]] used by this manager to evict cached blocks.
+ * This must be set after construction due to initialization ordering constraints.
*/
- def maxStorageMemory: Long
-
- // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985)
+ final def setMemoryStore(store: MemoryStore): Unit = synchronized {
+ onHeapStorageMemoryPool.setMemoryStore(store)
+ offHeapStorageMemoryPool.setMemoryStore(store)
+ }
/**
* Acquire N bytes of memory to cache the given block, evicting existing ones if necessary.
- * Blocks evicted in the process, if any, are added to `evictedBlocks`.
+ *
* @return whether all N bytes were successfully granted.
*/
- def acquireStorageMemory(
- blockId: BlockId,
- numBytes: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean
+ def acquireStorageMemory(blockId: BlockId, numBytes: Long, memoryMode: MemoryMode): Boolean
/**
* Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary.
@@ -102,197 +96,115 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte
* This extra method allows subclasses to differentiate behavior between acquiring storage
* memory and acquiring unroll memory. For instance, the memory management model in Spark
* 1.5 and before places a limit on the amount of space that can be freed from unrolling.
- * Blocks evicted in the process, if any, are added to `evictedBlocks`.
*
* @return whether all N bytes were successfully granted.
*/
- def acquireUnrollMemory(
- blockId: BlockId,
- numBytes: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized {
- acquireStorageMemory(blockId, numBytes, evictedBlocks)
- }
-
- /**
- * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
- * Blocks evicted in the process, if any, are added to `evictedBlocks`.
- * @return number of bytes successfully granted (<= N).
- */
- @VisibleForTesting
- private[memory] def doAcquireExecutionMemory(
- numBytes: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long
+ def acquireUnrollMemory(blockId: BlockId, numBytes: Long, memoryMode: MemoryMode): Boolean
/**
- * Try to acquire up to `numBytes` of execution memory for the current task and return the number
- * of bytes obtained, or 0 if none can be allocated.
+ * Try to acquire up to `numBytes` of execution memory for the current task and return the
+ * number of bytes obtained, or 0 if none can be allocated.
*
* This call may block until there is enough free memory in some situations, to make sure each
* task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of
 * active tasks) before it is forced to spill. This can happen if the number of tasks increases
* but an older task had a lot of memory already.
- *
- * Subclasses should override `doAcquireExecutionMemory` in order to customize the policies
- * that control global sharing of memory between execution and storage.
*/
private[memory]
- final def acquireExecutionMemory(numBytes: Long, taskAttemptId: Long): Long = synchronized {
- assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
-
- // Add this task to the taskMemory map just so we can keep an accurate count of the number
- // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
- if (!executionMemoryForTask.contains(taskAttemptId)) {
- executionMemoryForTask(taskAttemptId) = 0L
- // This will later cause waiting tasks to wake up and check numTasks again
- notifyAll()
- }
-
- // Once the cross-task memory allocation policy has decided to grant more memory to a task,
- // this method is called in order to actually obtain that execution memory, potentially
- // triggering eviction of storage memory:
- def acquire(toGrant: Long): Long = synchronized {
- val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
- val acquired = doAcquireExecutionMemory(toGrant, evictedBlocks)
- // Register evicted blocks, if any, with the active task metrics
- Option(TaskContext.get()).foreach { tc =>
- val metrics = tc.taskMetrics()
- val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
- metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq)
- }
- executionMemoryForTask(taskAttemptId) += acquired
- acquired
- }
-
- // Keep looping until we're either sure that we don't want to grant this request (because this
- // task would have more than 1 / numActiveTasks of the memory) or we have enough free
- // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
- // TODO: simplify this to limit each task to its own slot
- while (true) {
- val numActiveTasks = executionMemoryForTask.keys.size
- val curMem = executionMemoryForTask(taskAttemptId)
- val freeMemory = maxExecutionMemory - executionMemoryForTask.values.sum
-
- // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
- // don't let it be negative
- val maxToGrant =
- math.min(numBytes, math.max(0, (maxExecutionMemory / numActiveTasks) - curMem))
- // Only give it as much memory as is free, which might be none if it reached 1 / numTasks
- val toGrant = math.min(maxToGrant, freeMemory)
-
- if (curMem < maxExecutionMemory / (2 * numActiveTasks)) {
- // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
- // if we can't give it this much now, wait for other tasks to free up memory
- // (this happens if older tasks allocated lots of memory before N grew)
- if (
- freeMemory >= math.min(maxToGrant, maxExecutionMemory / (2 * numActiveTasks) - curMem)) {
- return acquire(toGrant)
- } else {
- logInfo(
- s"TID $taskAttemptId waiting for at least 1/2N of execution memory pool to be free")
- wait()
- }
- } else {
- return acquire(toGrant)
- }
- }
- 0L // Never reached
- }
-
- @VisibleForTesting
- private[memory] def releaseExecutionMemory(numBytes: Long): Unit = synchronized {
- if (numBytes > _executionMemoryUsed) {
- logWarning(s"Attempted to release $numBytes bytes of execution " +
- s"memory when we only have ${_executionMemoryUsed} bytes")
- _executionMemoryUsed = 0
- } else {
- _executionMemoryUsed -= numBytes
- }
- }
+ def acquireExecutionMemory(
+ numBytes: Long,
+ taskAttemptId: Long,
+ memoryMode: MemoryMode): Long
/**
* Release numBytes of execution memory belonging to the given task.
*/
private[memory]
- final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized {
- val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L)
- if (curMem < numBytes) {
- if (Utils.isTesting) {
- throw new SparkException(
- s"Internal error: release called on $numBytes bytes but task only has $curMem")
- } else {
- logWarning(s"Internal error: release called on $numBytes bytes but task only has $curMem")
- }
- }
- if (executionMemoryForTask.contains(taskAttemptId)) {
- executionMemoryForTask(taskAttemptId) -= numBytes
- if (executionMemoryForTask(taskAttemptId) <= 0) {
- executionMemoryForTask.remove(taskAttemptId)
- }
- releaseExecutionMemory(numBytes)
+ def releaseExecutionMemory(
+ numBytes: Long,
+ taskAttemptId: Long,
+ memoryMode: MemoryMode): Unit = synchronized {
+ memoryMode match {
+ case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.releaseMemory(numBytes, taskAttemptId)
+ case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.releaseMemory(numBytes, taskAttemptId)
}
- notifyAll() // Notify waiters in acquireExecutionMemory() that memory has been freed
}
/**
* Release all memory for the given task and mark it as inactive (e.g. when a task ends).
+ *
* @return the number of bytes freed.
*/
private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = synchronized {
- val numBytesToFree = getExecutionMemoryUsageForTask(taskAttemptId)
- releaseExecutionMemory(numBytesToFree, taskAttemptId)
- numBytesToFree
+ onHeapExecutionMemoryPool.releaseAllMemoryForTask(taskAttemptId) +
+ offHeapExecutionMemoryPool.releaseAllMemoryForTask(taskAttemptId)
}
/**
* Release N bytes of storage memory.
*/
- def releaseStorageMemory(numBytes: Long): Unit = synchronized {
- if (numBytes > _storageMemoryUsed) {
- logWarning(s"Attempted to release $numBytes bytes of storage " +
- s"memory when we only have ${_storageMemoryUsed} bytes")
- _storageMemoryUsed = 0
- } else {
- _storageMemoryUsed -= numBytes
+ def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = synchronized {
+ memoryMode match {
+ case MemoryMode.ON_HEAP => onHeapStorageMemoryPool.releaseMemory(numBytes)
+ case MemoryMode.OFF_HEAP => offHeapStorageMemoryPool.releaseMemory(numBytes)
}
}
/**
* Release all storage memory acquired.
*/
- def releaseAllStorageMemory(): Unit = synchronized {
- _storageMemoryUsed = 0
+ final def releaseAllStorageMemory(): Unit = synchronized {
+ onHeapStorageMemoryPool.releaseAllMemory()
+ offHeapStorageMemoryPool.releaseAllMemory()
}
/**
* Release N bytes of unroll memory.
*/
- def releaseUnrollMemory(numBytes: Long): Unit = synchronized {
- releaseStorageMemory(numBytes)
+ final def releaseUnrollMemory(numBytes: Long, memoryMode: MemoryMode): Unit = synchronized {
+ releaseStorageMemory(numBytes, memoryMode)
}
/**
* Execution memory currently in use, in bytes.
*/
final def executionMemoryUsed: Long = synchronized {
- _executionMemoryUsed
+ onHeapExecutionMemoryPool.memoryUsed + offHeapExecutionMemoryPool.memoryUsed
}
/**
* Storage memory currently in use, in bytes.
*/
final def storageMemoryUsed: Long = synchronized {
- _storageMemoryUsed
+ onHeapStorageMemoryPool.memoryUsed + offHeapStorageMemoryPool.memoryUsed
}
/**
* Returns the execution memory consumption, in bytes, for the given task.
*/
private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = synchronized {
- executionMemoryForTask.getOrElse(taskAttemptId, 0L)
+ onHeapExecutionMemoryPool.getMemoryUsageForTask(taskAttemptId) +
+ offHeapExecutionMemoryPool.getMemoryUsageForTask(taskAttemptId)
}
// -- Fields related to Tungsten managed memory -------------------------------------------------
+ /**
+ * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using
+ * sun.misc.Unsafe.
+ */
+ final val tungstenMemoryMode: MemoryMode = {
+ if (conf.getBoolean("spark.memory.offHeap.enabled", false)) {
+ require(conf.getSizeAsBytes("spark.memory.offHeap.size", 0) > 0,
+ "spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true")
+ require(Platform.unaligned(),
+ "No support for unaligned Unsafe. Set spark.memory.offHeap.enabled to false.")
+ MemoryMode.OFF_HEAP
+ } else {
+ MemoryMode.ON_HEAP
+ }
+ }
+
/**
* The default page size, in bytes.
*
@@ -306,21 +218,22 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte
val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
// Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case
val safetyFactor = 16
- val size = ByteArrayMethods.nextPowerOf2(maxExecutionMemory / cores / safetyFactor)
+ val maxTungstenMemory: Long = tungstenMemoryMode match {
+ case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.poolSize
+ case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.poolSize
+ }
+ val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor)
val default = math.min(maxPageSize, math.max(minPageSize, size))
conf.getSizeAsBytes("spark.buffer.pageSize", default)
}
- /**
- * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using
- * sun.misc.Unsafe.
- */
- final val tungstenMemoryIsAllocatedInHeap: Boolean =
- !conf.getBoolean("spark.unsafe.offHeap", false)
-
/**
* Allocates memory for use by Unsafe/Tungsten code.
*/
- private[memory] final val tungstenMemoryAllocator: MemoryAllocator =
- if (tungstenMemoryIsAllocatedInHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE
+ private[memory] final val tungstenMemoryAllocator: MemoryAllocator = {
+ tungstenMemoryMode match {
+ case MemoryMode.ON_HEAP => MemoryAllocator.HEAP
+ case MemoryMode.OFF_HEAP => MemoryAllocator.UNSAFE
+ }
+ }
}
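
The default page-size heuristic at the end of the class can be approximated in isolation. The sketch below assumes the 1 MB / 64 MB clamp bounds and a local nextPowerOf2 helper in place of Spark's internal constants and ByteArrayMethods.nextPowerOf2; treat it as illustrative arithmetic, not the exact implementation.

    // Approximate sketch of MemoryManager.pageSizeBytes' default computation.
    // The min/max page sizes and nextPowerOf2 are assumptions, not Spark's exact internals.
    object PageSizeSketch {
      private def nextPowerOf2(n: Long): Long = {
        val highBit = java.lang.Long.highestOneBit(n)
        if (highBit == n) n else highBit << 1
      }

      def defaultPageSize(maxTungstenMemory: Long, numCores: Int): Long = {
        val minPageSize = 1L * 1024 * 1024        // assumed lower bound (1 MB)
        val maxPageSize = 64L * minPageSize       // assumed upper bound (64 MB)
        val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
        // Because of rounding to the next power of 2, the effective safety factor can reach 8.
        val safetyFactor = 16
        val size = nextPowerOf2(maxTungstenMemory / cores / safetyFactor)
        math.min(maxPageSize, math.max(minPageSize, size))
      }

      def main(args: Array[String]): Unit = {
        // A 2 GB execution pool shared by 8 cores yields a 16 MB default page size.
        println(defaultPageSize(2L * 1024 * 1024 * 1024, 8))  // 16777216
      }
    }
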
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala
new file mode 100644
index 000000000000..1b9edf9c43bd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import javax.annotation.concurrent.GuardedBy
+
+/**
+ * Manages bookkeeping for an adjustable-sized region of memory. This class is internal to
+ * the [[MemoryManager]]. See subclasses for more details.
+ *
+ * @param lock a [[MemoryManager]] instance, used for synchronization. We purposely erase the type
+ * to `Object` to avoid programming errors, since this object should only be used for
+ * synchronization purposes.
+ */
+private[memory] abstract class MemoryPool(lock: Object) {
+
+ @GuardedBy("lock")
+ private[this] var _poolSize: Long = 0
+
+ /**
+ * Returns the current size of the pool, in bytes.
+ */
+ final def poolSize: Long = lock.synchronized {
+ _poolSize
+ }
+
+ /**
+ * Returns the amount of free memory in the pool, in bytes.
+ */
+ final def memoryFree: Long = lock.synchronized {
+ _poolSize - memoryUsed
+ }
+
+ /**
+ * Expands the pool by `delta` bytes.
+ */
+ final def incrementPoolSize(delta: Long): Unit = lock.synchronized {
+ require(delta >= 0)
+ _poolSize += delta
+ }
+
+ /**
+ * Shrinks the pool by `delta` bytes.
+ */
+ final def decrementPoolSize(delta: Long): Unit = lock.synchronized {
+ require(delta >= 0)
+ require(delta <= _poolSize)
+ require(_poolSize - delta >= memoryUsed)
+ _poolSize -= delta
+ }
+
+ /**
+ * Returns the amount of used memory in this pool (in bytes).
+ */
+ def memoryUsed: Long
+}
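
A toy stand-in makes the bookkeeping contract easy to exercise: the pool may only shrink while the remaining capacity still covers what is in use. ToyPool below is hypothetical and only mirrors the invariants; it is not a Spark class.

    // Toy stand-in illustrating MemoryPool's size/usage invariants (not a Spark class).
    class ToyPool(lock: Object) {
      private var _poolSize: Long = 0L
      private var _used: Long = 0L

      def poolSize: Long = lock.synchronized { _poolSize }
      def memoryUsed: Long = lock.synchronized { _used }
      def memoryFree: Long = lock.synchronized { _poolSize - _used }

      def incrementPoolSize(delta: Long): Unit = lock.synchronized {
        require(delta >= 0)
        _poolSize += delta
      }

      // Shrinking is only legal while the remaining capacity still covers current usage.
      def decrementPoolSize(delta: Long): Unit = lock.synchronized {
        require(delta >= 0 && delta <= _poolSize && _poolSize - delta >= _used)
        _poolSize -= delta
      }

      def use(bytes: Long): Unit = lock.synchronized {
        require(bytes <= memoryFree)
        _used += bytes
      }
    }

    object ToyPoolDemo {
      def main(args: Array[String]): Unit = {
        val pool = new ToyPool(new Object)
        pool.incrementPoolSize(100)
        pool.use(60)
        pool.decrementPoolSize(40)      // ok: 60 bytes of usage still fit in the remaining 60
        // pool.decrementPoolSize(10)   // would break the invariant and fail the require()
        println(pool.poolSize + " " + pool.memoryFree)  // 60 0
      }
    }
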
diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
index 9c2c2e90a228..a6f7db0600e6 100644
--- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
@@ -17,11 +17,8 @@
package org.apache.spark.memory
-import scala.collection.mutable
-
import org.apache.spark.SparkConf
-import org.apache.spark.storage.{BlockId, BlockStatus}
-
+import org.apache.spark.storage.BlockId
/**
* A [[MemoryManager]] that statically partitions the heap space into disjoint regions.
@@ -32,10 +29,14 @@ import org.apache.spark.storage.{BlockId, BlockStatus}
*/
private[spark] class StaticMemoryManager(
conf: SparkConf,
- override val maxExecutionMemory: Long,
- override val maxStorageMemory: Long,
+ maxOnHeapExecutionMemory: Long,
+ override val maxOnHeapStorageMemory: Long,
numCores: Int)
- extends MemoryManager(conf, numCores) {
+ extends MemoryManager(
+ conf,
+ numCores,
+ maxOnHeapStorageMemory,
+ maxOnHeapExecutionMemory) {
def this(conf: SparkConf, numCores: Int) {
this(
@@ -45,86 +46,68 @@ private[spark] class StaticMemoryManager(
numCores)
}
+ // The StaticMemoryManager does not support off-heap storage memory:
+ offHeapExecutionMemoryPool.incrementPoolSize(offHeapStorageMemoryPool.poolSize)
+ offHeapStorageMemoryPool.decrementPoolSize(offHeapStorageMemoryPool.poolSize)
+
// Max number of bytes worth of blocks to evict when unrolling
- private val maxMemoryToEvictForUnroll: Long = {
- (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong
+ private val maxUnrollMemory: Long = {
+ (maxOnHeapStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong
}
- /**
- * Acquire N bytes of memory for execution.
- * @return number of bytes successfully granted (<= N).
- */
- override def doAcquireExecutionMemory(
- numBytes: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
- assert(numBytes >= 0)
- assert(_executionMemoryUsed <= maxExecutionMemory)
- val bytesToGrant = math.min(numBytes, maxExecutionMemory - _executionMemoryUsed)
- _executionMemoryUsed += bytesToGrant
- bytesToGrant
- }
+ override def maxOffHeapStorageMemory: Long = 0L
- /**
- * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary.
- * Blocks evicted in the process, if any, are added to `evictedBlocks`.
- * @return whether all N bytes were successfully granted.
- */
override def acquireStorageMemory(
blockId: BlockId,
numBytes: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized {
- acquireStorageMemory(blockId, numBytes, numBytes, evictedBlocks)
+ memoryMode: MemoryMode): Boolean = synchronized {
+ require(memoryMode != MemoryMode.OFF_HEAP,
+ "StaticMemoryManager does not support off-heap storage memory")
+ if (numBytes > maxOnHeapStorageMemory) {
+ // Fail fast if the block simply won't fit
+ logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " +
+ s"memory limit ($maxOnHeapStorageMemory bytes)")
+ false
+ } else {
+ onHeapStorageMemoryPool.acquireMemory(blockId, numBytes)
+ }
}
- /**
- * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary.
- *
- * This evicts at most M bytes worth of existing blocks, where M is a fraction of the storage
- * space specified by `spark.storage.unrollFraction`. Blocks evicted in the process, if any,
- * are added to `evictedBlocks`.
- *
- * @return whether all N bytes were successfully granted.
- */
override def acquireUnrollMemory(
blockId: BlockId,
numBytes: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized {
- val currentUnrollMemory = memoryStore.currentUnrollMemory
- val maxNumBytesToFree = math.max(0, maxMemoryToEvictForUnroll - currentUnrollMemory)
- val numBytesToFree = math.min(numBytes, maxNumBytesToFree)
- acquireStorageMemory(blockId, numBytes, numBytesToFree, evictedBlocks)
+ memoryMode: MemoryMode): Boolean = synchronized {
+ require(memoryMode != MemoryMode.OFF_HEAP,
+ "StaticMemoryManager does not support off-heap unroll memory")
+ val currentUnrollMemory = onHeapStorageMemoryPool.memoryStore.currentUnrollMemory
+ val freeMemory = onHeapStorageMemoryPool.memoryFree
+ // When unrolling, we will use all of the existing free memory, and, if necessary,
+ // some extra space freed from evicting cached blocks. We must place a cap on the
+ // amount of memory to be evicted by unrolling, however, otherwise unrolling one
+ // big block can blow away the entire cache.
+ val maxNumBytesToFree = math.max(0, maxUnrollMemory - currentUnrollMemory - freeMemory)
+ // Keep it within the range 0 <= X <= maxNumBytesToFree
+ val numBytesToFree = math.max(0, math.min(maxNumBytesToFree, numBytes - freeMemory))
+ onHeapStorageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree)
}
- /**
- * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary.
- *
- * @param blockId the ID of the block we are acquiring storage memory for
- * @param numBytesToAcquire the size of this block
- * @param numBytesToFree the size of space to be freed through evicting blocks
- * @param evictedBlocks a holder for blocks evicted in the process
- * @return whether all N bytes were successfully granted.
- */
- private def acquireStorageMemory(
- blockId: BlockId,
- numBytesToAcquire: Long,
- numBytesToFree: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized {
- assert(numBytesToAcquire >= 0)
- assert(numBytesToFree >= 0)
- memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks)
- assert(_storageMemoryUsed <= maxStorageMemory)
- val enoughMemory = _storageMemoryUsed + numBytesToAcquire <= maxStorageMemory
- if (enoughMemory) {
- _storageMemoryUsed += numBytesToAcquire
+ private[memory]
+ override def acquireExecutionMemory(
+ numBytes: Long,
+ taskAttemptId: Long,
+ memoryMode: MemoryMode): Long = synchronized {
+ memoryMode match {
+ case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId)
+ case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId)
}
- enoughMemory
}
-
}
private[spark] object StaticMemoryManager {
+ private val MIN_MEMORY_BYTES = 32 * 1024 * 1024
+
/**
* Return the total amount of memory available for the storage region, in bytes.
*/
@@ -135,12 +118,25 @@ private[spark] object StaticMemoryManager {
(systemMaxMemory * memoryFraction * safetyFraction).toLong
}
-
/**
* Return the total amount of memory available for the execution region, in bytes.
*/
private def getMaxExecutionMemory(conf: SparkConf): Long = {
val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory)
+
+ if (systemMaxMemory < MIN_MEMORY_BYTES) {
+ throw new IllegalArgumentException(s"System memory $systemMaxMemory must " +
+ s"be at least $MIN_MEMORY_BYTES. Please increase heap size using the --driver-memory " +
+ s"option or spark.driver.memory in Spark configuration.")
+ }
+ if (conf.contains("spark.executor.memory")) {
+ val executorMemory = conf.getSizeAsBytes("spark.executor.memory")
+ if (executorMemory < MIN_MEMORY_BYTES) {
+ throw new IllegalArgumentException(s"Executor memory $executorMemory must be at least " +
+ s"$MIN_MEMORY_BYTES. Please increase executor memory using the " +
+ s"--executor-memory option or spark.executor.memory in Spark configuration.")
+ }
+ }
val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
(systemMaxMemory * memoryFraction * safetyFraction).toLong
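
For a sense of scale, the static sizing works out roughly as follows for a 1 GB heap. The execution-side defaults (0.2 and 0.8) appear in the hunk above; the storage-side defaults (0.6 and 0.9) are assumed here.

    // Back-of-the-envelope sketch of StaticMemoryManager's region sizing for a 1 GB heap.
    // The storage-side fractions (0.6 and 0.9) are assumed defaults; the execution-side
    // fractions (0.2 and 0.8) come from the hunk above.
    object StaticSizingSketch {
      def main(args: Array[String]): Unit = {
        val systemMaxMemory = 1024L * 1024 * 1024                 // 1 GB heap
        val maxExecution = (systemMaxMemory * 0.2 * 0.8).toLong   // ~164 MB
        val maxStorage   = (systemMaxMemory * 0.6 * 0.9).toLong   // ~553 MB
        println(s"execution=$maxExecution storage=$maxStorage")
        // The two regions are disjoint: unused storage memory can never help execution,
        // which is the main limitation the UnifiedMemoryManager removes.
      }
    }
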
diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala
new file mode 100644
index 000000000000..4c6b639015a9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import javax.annotation.concurrent.GuardedBy
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.memory.MemoryStore
+
+/**
+ * Performs bookkeeping for managing an adjustable-size pool of memory that is used for storage
+ * (caching).
+ *
+ * @param lock a [[MemoryManager]] instance to synchronize on
+ * @param memoryMode the type of memory tracked by this pool (on- or off-heap)
+ */
+private[memory] class StorageMemoryPool(
+ lock: Object,
+ memoryMode: MemoryMode
+ ) extends MemoryPool(lock) with Logging {
+
+ private[this] val poolName: String = memoryMode match {
+ case MemoryMode.ON_HEAP => "on-heap storage"
+ case MemoryMode.OFF_HEAP => "off-heap storage"
+ }
+
+ @GuardedBy("lock")
+ private[this] var _memoryUsed: Long = 0L
+
+ override def memoryUsed: Long = lock.synchronized {
+ _memoryUsed
+ }
+
+ private var _memoryStore: MemoryStore = _
+ def memoryStore: MemoryStore = {
+ if (_memoryStore == null) {
+ throw new IllegalStateException("memory store not initialized yet")
+ }
+ _memoryStore
+ }
+
+ /**
+ * Set the [[MemoryStore]] used by this manager to evict cached blocks.
+ * This must be set after construction due to initialization ordering constraints.
+ */
+ final def setMemoryStore(store: MemoryStore): Unit = {
+ _memoryStore = store
+ }
+
+ /**
+ * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary.
+ *
+ * @return whether all N bytes were successfully granted.
+ */
+ def acquireMemory(blockId: BlockId, numBytes: Long): Boolean = lock.synchronized {
+ val numBytesToFree = math.max(0, numBytes - memoryFree)
+ acquireMemory(blockId, numBytes, numBytesToFree)
+ }
+
+ /**
+ * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary.
+ *
+ * @param blockId the ID of the block we are acquiring storage memory for
+ * @param numBytesToAcquire the size of this block
+ * @param numBytesToFree the amount of space to be freed through evicting blocks
+ * @return whether all N bytes were successfully granted.
+ */
+ def acquireMemory(
+ blockId: BlockId,
+ numBytesToAcquire: Long,
+ numBytesToFree: Long): Boolean = lock.synchronized {
+ assert(numBytesToAcquire >= 0)
+ assert(numBytesToFree >= 0)
+ assert(memoryUsed <= poolSize)
+ if (numBytesToFree > 0) {
+ memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, memoryMode)
+ }
+ // NOTE: If the memory store evicts blocks, then those evictions will synchronously call
+ // back into this StorageMemoryPool in order to free memory. Therefore, these variables
+ // should have been updated.
+ val enoughMemory = numBytesToAcquire <= memoryFree
+ if (enoughMemory) {
+ _memoryUsed += numBytesToAcquire
+ }
+ enoughMemory
+ }
+
+ def releaseMemory(size: Long): Unit = lock.synchronized {
+ if (size > _memoryUsed) {
+ logWarning(s"Attempted to release $size bytes of storage " +
+ s"memory when we only have ${_memoryUsed} bytes")
+ _memoryUsed = 0
+ } else {
+ _memoryUsed -= size
+ }
+ }
+
+ def releaseAllMemory(): Unit = lock.synchronized {
+ _memoryUsed = 0
+ }
+
+ /**
+ * Free space to shrink the size of this storage memory pool by `spaceToFree` bytes.
+ * Note: this method doesn't actually reduce the pool size but relies on the caller to do so.
+ *
+ * @return number of bytes to be removed from the pool's capacity.
+ */
+ def freeSpaceToShrinkPool(spaceToFree: Long): Long = lock.synchronized {
+ val spaceFreedByReleasingUnusedMemory = math.min(spaceToFree, memoryFree)
+ val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory
+ if (remainingSpaceToFree > 0) {
+ // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks:
+ val spaceFreedByEviction =
+ memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, memoryMode)
+ // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do
+ // not need to decrement _memoryUsed here. However, we do need to decrement the pool size.
+ spaceFreedByReleasingUnusedMemory + spaceFreedByEviction
+ } else {
+ spaceFreedByReleasingUnusedMemory
+ }
+ }
+}
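
The two-step reclaim in freeSpaceToShrinkPool (use free space first, then evict) can be sketched with a stand-in eviction callback; evictBlocks below is hypothetical and merely reports how many bytes it could free.

    // Sketch of StorageMemoryPool.freeSpaceToShrinkPool's two-step reclaim: use free
    // space first, then fall back to evicting blocks. evictBlocks is a stand-in for
    // MemoryStore.evictBlocksToFreeSpace.
    object ShrinkPoolSketch {
      def spaceFreedForShrink(
          spaceToFree: Long,
          memoryFree: Long,
          evictBlocks: Long => Long): Long = {
        val freedWithoutEviction = math.min(spaceToFree, memoryFree)
        val remaining = spaceToFree - freedWithoutEviction
        if (remaining > 0) freedWithoutEviction + evictBlocks(remaining)
        else freedWithoutEviction
      }

      def main(args: Array[String]): Unit = {
        // Ask the pool to give up 100 bytes when only 30 are free; eviction covers the rest.
        val freed = spaceFreedForShrink(100L, 30L, needed => math.min(needed, 70L))
        println(freed)  // 100: 30 from free space + 70 from evicted blocks
      }
    }
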
diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
index a3093030a0f9..df193552bed3 100644
--- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
@@ -17,20 +17,17 @@
package org.apache.spark.memory
-import scala.collection.mutable
-
import org.apache.spark.SparkConf
-import org.apache.spark.storage.{BlockStatus, BlockId}
-
+import org.apache.spark.storage.BlockId
/**
* A [[MemoryManager]] that enforces a soft boundary between execution and storage such that
* either side can borrow memory from the other.
*
- * The region shared between execution and storage is a fraction of the total heap space
- * configurable through `spark.memory.fraction` (default 0.75). The position of the boundary
+ * The region shared between execution and storage is a fraction of (the total heap space - 300MB)
+ * configurable through `spark.memory.fraction` (default 0.6). The position of the boundary
* within this space is further determined by `spark.memory.storageFraction` (default 0.5).
- * This means the size of the storage region is 0.75 * 0.5 = 0.375 of the heap space by default.
+ * This means the size of the storage region is 0.6 * 0.5 = 0.3 of the usable heap space by default.
*
* Storage can borrow as much execution memory as is free until execution reclaims its space.
* When this happens, cached blocks will be evicted from memory until sufficient borrowed
@@ -41,105 +38,197 @@ import org.apache.spark.storage.{BlockStatus, BlockId}
* The implication is that attempts to cache blocks may fail if execution has already eaten
* up most of the storage space, in which case the new blocks will be evicted immediately
* according to their respective storage levels.
+ *
+ * @param onHeapStorageRegionSize Size of the storage region, in bytes.
+ * This region is not statically reserved; execution can borrow from
+ * it if necessary. Cached blocks can be evicted only if actual
+ * storage memory usage exceeds this region.
*/
-private[spark] class UnifiedMemoryManager(
+private[spark] class UnifiedMemoryManager private[memory] (
conf: SparkConf,
- maxMemory: Long,
+ val maxHeapMemory: Long,
+ onHeapStorageRegionSize: Long,
numCores: Int)
- extends MemoryManager(conf, numCores) {
-
- def this(conf: SparkConf, numCores: Int) {
- this(conf, UnifiedMemoryManager.getMaxMemory(conf), numCores)
- }
+ extends MemoryManager(
+ conf,
+ numCores,
+ onHeapStorageRegionSize,
+ maxHeapMemory - onHeapStorageRegionSize) {
- /**
- * Size of the storage region, in bytes.
- *
- * This region is not statically reserved; execution can borrow from it if necessary.
- * Cached blocks can be evicted only if actual storage memory usage exceeds this region.
- */
- private val storageRegionSize: Long = {
- (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong
+ private def assertInvariants(): Unit = {
+ assert(onHeapExecutionMemoryPool.poolSize + onHeapStorageMemoryPool.poolSize == maxHeapMemory)
+ assert(
+ offHeapExecutionMemoryPool.poolSize + offHeapStorageMemoryPool.poolSize == maxOffHeapMemory)
}
- /**
- * Total amount of memory, in bytes, not currently occupied by either execution or storage.
- */
- private def totalFreeMemory: Long = synchronized {
- assert(_executionMemoryUsed <= maxMemory)
- assert(_storageMemoryUsed <= maxMemory)
- assert(_executionMemoryUsed + _storageMemoryUsed <= maxMemory)
- maxMemory - _executionMemoryUsed - _storageMemoryUsed
- }
+ assertInvariants()
- /**
- * Total available memory for execution, in bytes.
- * In this model, this is equivalent to the amount of memory not occupied by storage.
- */
- override def maxExecutionMemory: Long = synchronized {
- maxMemory - _storageMemoryUsed
+ override def maxOnHeapStorageMemory: Long = synchronized {
+ maxHeapMemory - onHeapExecutionMemoryPool.memoryUsed
}
- /**
- * Total available memory for storage, in bytes.
- * In this model, this is equivalent to the amount of memory not occupied by execution.
- */
- override def maxStorageMemory: Long = synchronized {
- maxMemory - _executionMemoryUsed
+ override def maxOffHeapStorageMemory: Long = synchronized {
+ maxOffHeapMemory - offHeapExecutionMemoryPool.memoryUsed
}
/**
- * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
+ * Try to acquire up to `numBytes` of execution memory for the current task and return the
+ * number of bytes obtained, or 0 if none can be allocated.
*
- * This method evicts blocks only up to the amount of memory borrowed by storage.
- * Blocks evicted in the process, if any, are added to `evictedBlocks`.
- * @return number of bytes successfully granted (<= N).
+ * This call may block until there is enough free memory in some situations, to make sure each
+ * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of
+ * active tasks) before it is forced to spill. This can happen if the number of tasks increases
+ * but an older task had a lot of memory already.
*/
- private[memory] override def doAcquireExecutionMemory(
+ override private[memory] def acquireExecutionMemory(
numBytes: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
+ taskAttemptId: Long,
+ memoryMode: MemoryMode): Long = synchronized {
+ assertInvariants()
assert(numBytes >= 0)
- val memoryBorrowedByStorage = math.max(0, _storageMemoryUsed - storageRegionSize)
- // If there is not enough free memory AND storage has borrowed some execution memory,
- // then evict as much memory borrowed by storage as needed to grant this request
- val shouldEvictStorage = totalFreeMemory < numBytes && memoryBorrowedByStorage > 0
- if (shouldEvictStorage) {
- val spaceToEnsure = math.min(numBytes, memoryBorrowedByStorage)
- memoryStore.ensureFreeSpace(spaceToEnsure, evictedBlocks)
+ val (executionPool, storagePool, storageRegionSize, maxMemory) = memoryMode match {
+ case MemoryMode.ON_HEAP => (
+ onHeapExecutionMemoryPool,
+ onHeapStorageMemoryPool,
+ onHeapStorageRegionSize,
+ maxHeapMemory)
+ case MemoryMode.OFF_HEAP => (
+ offHeapExecutionMemoryPool,
+ offHeapStorageMemoryPool,
+ offHeapStorageMemory,
+ maxOffHeapMemory)
+ }
+
+ /**
+ * Grow the execution pool by evicting cached blocks, thereby shrinking the storage pool.
+ *
+ * When acquiring memory for a task, the execution pool may need to make multiple
+ * attempts. Each attempt must be able to evict storage in case another task jumps in
+ * and caches a large block between the attempts. This is called once per attempt.
+ */
+ def maybeGrowExecutionPool(extraMemoryNeeded: Long): Unit = {
+ if (extraMemoryNeeded > 0) {
+ // There is not enough free memory in the execution pool, so try to reclaim memory from
+ // storage. We can reclaim any free memory from the storage pool. If the storage pool
+ // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim
+ // the memory that storage has borrowed from execution.
+ val memoryReclaimableFromStorage = math.max(
+ storagePool.memoryFree,
+ storagePool.poolSize - storageRegionSize)
+ if (memoryReclaimableFromStorage > 0) {
+ // Only reclaim as much space as is necessary and available:
+ val spaceToReclaim = storagePool.freeSpaceToShrinkPool(
+ math.min(extraMemoryNeeded, memoryReclaimableFromStorage))
+ storagePool.decrementPoolSize(spaceToReclaim)
+ executionPool.incrementPoolSize(spaceToReclaim)
+ }
+ }
+ }
+
+ /**
+ * The size the execution pool would have after evicting storage memory.
+ *
+ * The execution memory pool divides this quantity among the active tasks evenly to cap
+ * the execution memory allocation for each task. It is important to keep this greater
+ * than the execution pool size, which doesn't take into account potential memory that
+ * could be freed by evicting storage. Otherwise we may hit SPARK-12155.
+ *
+ * Additionally, this quantity should be kept below `maxMemory` to arbitrate fairness
+ * in execution memory allocation across tasks. Otherwise, a task may occupy more than
+ * its fair share of execution memory, mistakenly thinking that other tasks can acquire
+ * the portion of storage memory that cannot be evicted.
+ */
+ def computeMaxExecutionPoolSize(): Long = {
+ maxMemory - math.min(storagePool.memoryUsed, storageRegionSize)
}
- val bytesToGrant = math.min(numBytes, totalFreeMemory)
- _executionMemoryUsed += bytesToGrant
- bytesToGrant
+
+ executionPool.acquireMemory(
+ numBytes, taskAttemptId, maybeGrowExecutionPool, computeMaxExecutionPoolSize)
}
- /**
- * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary.
- * Blocks evicted in the process, if any, are added to `evictedBlocks`.
- * @return whether all N bytes were successfully granted.
- */
override def acquireStorageMemory(
blockId: BlockId,
numBytes: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized {
+ memoryMode: MemoryMode): Boolean = synchronized {
+ assertInvariants()
assert(numBytes >= 0)
- memoryStore.ensureFreeSpace(blockId, numBytes, evictedBlocks)
- val enoughMemory = totalFreeMemory >= numBytes
- if (enoughMemory) {
- _storageMemoryUsed += numBytes
+ val (executionPool, storagePool, maxMemory) = memoryMode match {
+ case MemoryMode.ON_HEAP => (
+ onHeapExecutionMemoryPool,
+ onHeapStorageMemoryPool,
+ maxOnHeapStorageMemory)
+ case MemoryMode.OFF_HEAP => (
+ offHeapExecutionMemoryPool,
+ offHeapStorageMemoryPool,
+ maxOffHeapStorageMemory)
+ }
+ if (numBytes > maxMemory) {
+ // Fail fast if the block simply won't fit
+ logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " +
+ s"memory limit ($maxMemory bytes)")
+ return false
+ }
+ if (numBytes > storagePool.memoryFree) {
+ // There is not enough free memory in the storage pool, so try to borrow free memory from
+ // the execution pool.
+ val memoryBorrowedFromExecution = Math.min(executionPool.memoryFree,
+ numBytes - storagePool.memoryFree)
+ executionPool.decrementPoolSize(memoryBorrowedFromExecution)
+ storagePool.incrementPoolSize(memoryBorrowedFromExecution)
}
- enoughMemory
+ storagePool.acquireMemory(blockId, numBytes)
}
+ override def acquireUnrollMemory(
+ blockId: BlockId,
+ numBytes: Long,
+ memoryMode: MemoryMode): Boolean = synchronized {
+ acquireStorageMemory(blockId, numBytes, memoryMode)
+ }
}
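
The reclaim bound used by maybeGrowExecutionPool is max(storagePool.memoryFree, storagePool.poolSize - storageRegionSize). A worked sketch with illustrative byte counts:

    // Worked sketch of maybeGrowExecutionPool's reclaim bound. Storage has borrowed past
    // its region, so execution may claw back both the free storage memory and the borrowed part.
    object ReclaimSketch {
      def main(args: Array[String]): Unit = {
        val storageRegionSize = 300L   // storage's "protected" region
        val storagePoolSize   = 380L   // storage currently holds 80 borrowed bytes of capacity
        val storageMemoryFree = 50L    // of which 50 are unused
        val extraMemoryNeeded = 120L   // execution's shortfall

        val memoryReclaimableFromStorage =
          math.max(storageMemoryFree, storagePoolSize - storageRegionSize)          // max(50, 80) = 80
        val spaceToReclaim = math.min(extraMemoryNeeded, memoryReclaimableFromStorage)  // 80
        println(spaceToReclaim)
        // Execution can grow by at most 80 bytes here; anything cached inside the 300-byte
        // storage region itself is never evicted on execution's behalf.
      }
    }
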
-private object UnifiedMemoryManager {
+object UnifiedMemoryManager {
+
+ // Set aside a fixed amount of memory for non-storage, non-execution purposes.
+ // This serves a function similar to `spark.memory.fraction`, but guarantees that we reserve
+ // sufficient memory for the system even for small heaps. E.g. if we have a 1GB JVM, then
+ // the memory used for execution and storage will be (1024 - 300) * 0.6 = 434MB by default.
+ private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024
+
+ def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = {
+ val maxMemory = getMaxMemory(conf)
+ new UnifiedMemoryManager(
+ conf,
+ maxHeapMemory = maxMemory,
+ onHeapStorageRegionSize =
+ (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong,
+ numCores = numCores)
+ }
/**
* Return the total amount of memory shared between execution and storage, in bytes.
*/
private def getMaxMemory(conf: SparkConf): Long = {
- val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory)
- val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75)
- (systemMaxMemory * memoryFraction).toLong
+ val systemMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory)
+ val reservedMemory = conf.getLong("spark.testing.reservedMemory",
+ if (conf.contains("spark.testing")) 0 else RESERVED_SYSTEM_MEMORY_BYTES)
+ val minSystemMemory = (reservedMemory * 1.5).ceil.toLong
+ if (systemMemory < minSystemMemory) {
+ throw new IllegalArgumentException(s"System memory $systemMemory must " +
+ s"be at least $minSystemMemory. Please increase heap size using the --driver-memory " +
+ s"option or spark.driver.memory in Spark configuration.")
+ }
+ // SPARK-12759 Check executor memory to fail fast if memory is insufficient
+ if (conf.contains("spark.executor.memory")) {
+ val executorMemory = conf.getSizeAsBytes("spark.executor.memory")
+ if (executorMemory < minSystemMemory) {
+ throw new IllegalArgumentException(s"Executor memory $executorMemory must be at least " +
+ s"$minSystemMemory. Please increase executor memory using the " +
+ s"--executor-memory option or spark.executor.memory in Spark configuration.")
+ }
+ }
+ val usableMemory = systemMemory - reservedMemory
+ val memoryFraction = conf.getDouble("spark.memory.fraction", 0.6)
+ (usableMemory * memoryFraction).toLong
}
}
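
Putting the sizing constants together, a rough sketch of what getMaxMemory and the default storage region come to for a 4 GB heap (illustrative arithmetic only):

    // Worked sketch of UnifiedMemoryManager.getMaxMemory and the default storage region
    // for a 4 GB heap, using the constants from the hunk above.
    object UnifiedSizingSketch {
      def main(args: Array[String]): Unit = {
        val systemMemory   = 4L * 1024 * 1024 * 1024             // 4 GB heap
        val reservedMemory = 300L * 1024 * 1024                  // RESERVED_SYSTEM_MEMORY_BYTES
        val usableMemory   = systemMemory - reservedMemory       // 3796 MB
        val maxMemory      = (usableMemory * 0.6).toLong         // spark.memory.fraction = 0.6
        val storageRegion  = (maxMemory * 0.5).toLong            // spark.memory.storageFraction = 0.5
        println(s"unified pool=${maxMemory / (1024 * 1024)} MB, " +
          s"storage region=${storageRegion / (1024 * 1024)} MB")
        // => unified pool=2277 MB, storage region=1138 MB; the rest of the heap (including the
        //    300 MB reserve) is left for user data structures and Spark's internal metadata.
      }
    }
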
diff --git a/core/src/main/scala/org/apache/spark/memory/package.scala b/core/src/main/scala/org/apache/spark/memory/package.scala
new file mode 100644
index 000000000000..3d00cd9cb637
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/memory/package.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * This package implements Spark's memory management system. This system consists of two main
+ * components, a JVM-wide memory manager and a per-task manager:
+ *
+ * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM.
+ * This component implements the policies for dividing the available memory across tasks and for
+ * allocating memory between storage (memory used for caching and data transfer) and execution
+ * (memory used by computations, such as shuffles, joins, sorts, and aggregations).
+ * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual
+ * tasks. Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide
+ * MemoryManager.
+ *
+ * Internally, each of these components has additional abstractions for memory bookkeeping:
+ *
+ * - [[org.apache.spark.memory.MemoryConsumer]]s are clients of the TaskMemoryManager and
+ * correspond to individual operators and data structures within a task. The TaskMemoryManager
+ * receives memory allocation requests from MemoryConsumers and issues callbacks to consumers
+ * in order to trigger spilling when running low on memory.
+ * - [[org.apache.spark.memory.MemoryPool]]s are a bookkeeping abstraction used by the
+ * MemoryManager to track the division of memory between storage and execution.
+ *
+ * Diagrammatically:
+ *
+ * {{{
+ * +-------------+
+ * | MemConsumer |----+ +------------------------+
+ * +-------------+ | +-------------------+ | MemoryManager |
+ * +--->| TaskMemoryManager |----+ | |
+ * +-------------+ | +-------------------+ | | +------------------+ |
+ * | MemConsumer |----+ | | | StorageMemPool | |
+ * +-------------+ +-------------------+ | | +------------------+ |
+ * | TaskMemoryManager |----+ | |
+ * +-------------------+ | | +------------------+ |
+ * +---->| |OnHeapExecMemPool | |
+ * * | | +------------------+ |
+ * * | | |
+ * +-------------+ * | | +------------------+ |
+ * | MemConsumer |----+ | | |OffHeapExecMemPool| |
+ * +-------------+ | +-------------------+ | | +------------------+ |
+ * +--->| TaskMemoryManager |----+ | |
+ * +-------------------+ +------------------------+
+ * }}}
+ *
+ *
+ * There are two implementations of [[org.apache.spark.memory.MemoryManager]] which vary in how
+ * they handle the sizing of their memory pools:
+ *
+ * - [[org.apache.spark.memory.UnifiedMemoryManager]], the default in Spark 1.6+, enforces soft
+ * boundaries between storage and execution memory, allowing requests for memory in one region
+ * to be fulfilled by borrowing memory from the other.
+ * - [[org.apache.spark.memory.StaticMemoryManager]] enforces hard boundaries between storage
+ * and execution memory by statically partitioning Spark's memory and preventing storage and
+ * execution from borrowing memory from each other. This mode is retained only for legacy
+ * compatibility purposes.
+ */
+package object memory
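
To make the architecture described above concrete, here is a minimal, self-contained sketch of the consumer/manager/pool relationship. All type and method names below are illustrative stand-ins, not the real Spark memory APIs:

```scala
// Illustrative only: a toy model of the MemoryConsumer / TaskMemoryManager /
// MemoryPool relationship sketched in the package doc. Real Spark types differ.
trait ToyConsumer {
  // Called by the manager when memory is scarce; returns the number of bytes freed.
  def spill(needed: Long): Long
}

class ToyPool(var capacity: Long) {
  var used: Long = 0L
  def free: Long = capacity - used
}

class ToyTaskMemoryManager(pool: ToyPool) {
  private val consumers = scala.collection.mutable.Set[ToyConsumer]()

  def register(c: ToyConsumer): Unit = consumers += c

  // Grant as much of the request as possible, asking other consumers to spill if needed.
  def acquire(requester: ToyConsumer, bytes: Long): Long = {
    if (pool.free < bytes) {
      consumers.filter(_ != requester).foreach { c =>
        if (pool.free < bytes) pool.used -= c.spill(bytes - pool.free)
      }
    }
    val granted = math.min(bytes, pool.free)
    pool.used += granted
    granted
  }
}
```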
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
index dd2d325d8703..a4056508c181 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
@@ -24,8 +24,9 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.matching.Regex
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
-import org.apache.spark.{Logging, SparkConf}
private[spark] class MetricsConfig(conf: SparkConf) extends Logging {
@@ -34,7 +35,7 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging {
private val DEFAULT_METRICS_CONF_FILENAME = "metrics.properties"
private[metrics] val properties = new Properties()
- private[metrics] var propertyCategories: mutable.HashMap[String, Properties] = null
+ private[metrics] var perInstanceSubProperties: mutable.HashMap[String, Properties] = null
private def setDefaultProperties(prop: Properties) {
prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet")
@@ -43,6 +44,10 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging {
prop.setProperty("applications.sink.servlet.path", "/metrics/applications/json")
}
+ /**
+ * Load properties from various places, based on precedence.
+ * If the same property is set again later on in the method, it overwrites the previous value.
+ */
def initialize() {
// Add default properties in case there's no properties file
setDefaultProperties(properties)
@@ -57,16 +62,47 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging {
case _ =>
}
- propertyCategories = subProperties(properties, INSTANCE_REGEX)
- if (propertyCategories.contains(DEFAULT_PREFIX)) {
- val defaultProperty = propertyCategories(DEFAULT_PREFIX).asScala
- for((inst, prop) <- propertyCategories if (inst != DEFAULT_PREFIX);
- (k, v) <- defaultProperty if (prop.get(k) == null)) {
+ // Now, let's populate a list of sub-properties per instance, instance being the prefix that
+ // appears before the first dot in the property name.
+ // Then add the default properties (those with prefix "*") to the sub-properties per instance,
+ // unless that exact same sub-property is already defined for that instance.
+ //
+ // For example, if properties has ("*.class"->"default_class", "*.path"->"default_path",
+ // "driver.path"->"driver_path"), then for driver-specific sub-properties we'd like the output
+ // to be ("driver"->Map("path"->"driver_path", "class"->"default_class")).
+ // Note how "class" got added based on the default property, but "path" remained the same
+ // since "driver.path" already existed and took precedence over "*.path".
+ //
+ perInstanceSubProperties = subProperties(properties, INSTANCE_REGEX)
+ if (perInstanceSubProperties.contains(DEFAULT_PREFIX)) {
+ val defaultSubProperties = perInstanceSubProperties(DEFAULT_PREFIX).asScala
+ for ((instance, prop) <- perInstanceSubProperties if (instance != DEFAULT_PREFIX);
+ (k, v) <- defaultSubProperties if (prop.get(k) == null)) {
prop.put(k, v)
}
}
}
+ /**
+ * Take a simple set of properties and a regex that the instance names (the part before the first
+ * dot) have to conform to, and return a map from the first-order prefix (before the first dot) to
+ * the sub-properties under that prefix.
+ *
+ * For example, if the properties sent were Properties("*.sink.servlet.class"->"class1",
+ * "*.sink.servlet.path"->"path1"), the returned map would be
+ * Map("*" -> Properties("sink.servlet.class" -> "class1", "sink.servlet.path" -> "path1")).
+ * Note that in the sub-properties (the values of the returned Map), only the suffixes are used as
+ * property keys.
+ * If, in the passed properties, there is only one property with a given prefix, it is still
+ * "unflattened". For example, if the input was Properties("*.sink.servlet.class" -> "class1"),
+ * the returned Map would contain one key-value pair:
+ * Map("*" -> Properties("sink.servlet.class" -> "class1"))
+ * Any passed-in properties that do not comply with the regex are ignored.
+ *
+ * @param prop the flat list of properties to "unflatten" based on prefixes
+ * @param regex the regex that the prefix has to comply with
+ * @return an unflattened map, mapping each prefix to the sub-properties under that prefix
+ */
def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = {
val subProperties = new mutable.HashMap[String, Properties]
prop.asScala.foreach { kv =>
@@ -79,9 +115,9 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging {
}
def getInstance(inst: String): Properties = {
- propertyCategories.get(inst) match {
+ perInstanceSubProperties.get(inst) match {
case Some(s) => s
- case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties)
+ case None => perInstanceSubProperties.getOrElse(DEFAULT_PREFIX, new Properties)
}
}
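
The behavior documented above can be reproduced in isolation. The sketch below re-implements the described unflattening and default-merging logic for illustration; it is not the MetricsConfig code itself, and the regex is only an approximation of the real INSTANCE_REGEX:

```scala
import java.util.Properties
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.matching.Regex

object SubPropsDemo {
  // Group "prefix.suffix" keys by prefix, keeping only the suffix as the key.
  def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = {
    val result = new mutable.HashMap[String, Properties]
    prop.asScala.foreach { case (k, v) =>
      k match {
        case regex(prefix, suffix) =>
          result.getOrElseUpdate(prefix, new Properties).setProperty(suffix, v)
        case _ => // keys not matching the regex are ignored
      }
    }
    result
  }

  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.setProperty("*.class", "default_class")
    props.setProperty("*.path", "default_path")
    props.setProperty("driver.path", "driver_path")

    val perInstance = subProperties(props, "^(\\*|[a-zA-Z]+)\\.(.+)".r)
    // Merge "*" defaults into each instance, without overriding existing keys.
    val defaults = perInstance("*").asScala
    for ((inst, p) <- perInstance if inst != "*"; (k, v) <- defaults if p.get(k) == null) {
      p.put(k, v)
    }
    // perInstance("driver") now contains path -> driver_path and class -> default_class
    println(perInstance("driver"))
  }
}
```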
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index fdf76d312db3..1d494500cdb5 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -20,36 +20,37 @@ package org.apache.spark.metrics
import java.util.Properties
import java.util.concurrent.TimeUnit
-import org.apache.spark.util.Utils
-
import scala.collection.mutable
import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
import org.eclipse.jetty.servlet.ServletContextHandler
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.internal.config._
+import org.apache.spark.internal.Logging
import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
-import org.apache.spark.metrics.source.Source
+import org.apache.spark.metrics.source.{Source, StaticSources}
+import org.apache.spark.util.Utils
/**
- * Spark Metrics System, created by specific "instance", combined by source,
- * sink, periodically poll source metrics data to sink destinations.
+ * Spark Metrics System, created by a specific "instance", combines sources and
+ * sinks, and periodically polls source metrics data out to sink destinations.
*
- * "instance" specify "who" (the role) use metrics system. In spark there are several roles
- * like master, worker, executor, client driver, these roles will create metrics system
- * for monitoring. So instance represents these roles. Currently in Spark, several instances
+ * "instance" specifies "who" (the role) uses the metrics system. In Spark, there are several roles
+ * like master, worker, executor, client driver. These roles will create a metrics system
+ * for monitoring. So, "instance" represents these roles. Currently in Spark, several instances
* have already implemented: master, worker, executor, driver, applications.
*
- * "source" specify "where" (source) to collect metrics data. In metrics system, there exists
+ * "source" specifies "where" (source) to collect metrics data from. In metrics system, there exists
* two kinds of source:
* 1. Spark internal source, like MasterSource, WorkerSource, etc, which will collect
* Spark component's internal state, these sources are related to instance and will be
- * added after specific metrics system is created.
+ * added after a specific metrics system is created.
* 2. Common source, like JvmSource, which will collect low level state, is configured by
* configuration and loaded through reflection.
*
- * "sink" specify "where" (destination) to output metrics data to. Several sinks can be
- * coexisted and flush metrics to all these sinks.
+ * "sink" specifies "where" (destination) to output metrics data to. Several sinks can
+ * coexist and metrics can be flushed to all these sinks.
*
* Metrics configuration format is like below:
* [instance].[sink|source].[name].[options] = xxxx
@@ -62,9 +63,9 @@ import org.apache.spark.metrics.source.Source
* [sink|source] means this property belongs to source or sink. This field can only be
* source or sink.
*
- * [name] specify the name of sink or source, it is custom defined.
+ * [name] specifies the name of the sink or source, if it is custom defined.
*
- * [options] is the specific property of this source or sink.
+ * [options] represents the specific properties of this source or sink.
*/
private[spark] class MetricsSystem private (
val instance: String,
@@ -96,6 +97,7 @@ private[spark] class MetricsSystem private (
def start() {
require(!running, "Attempting to start a MetricsSystem that is already running")
running = true
+ StaticSources.allSources.foreach(registerSource)
registerSources()
registerSinks()
sinks.foreach(_.start)
@@ -124,19 +126,25 @@ private[spark] class MetricsSystem private (
* application, executor/driver and metric source.
*/
private[spark] def buildRegistryName(source: Source): String = {
- val appId = conf.getOption("spark.app.id")
+ val metricsNamespace = conf.get(METRICS_NAMESPACE).orElse(conf.getOption("spark.app.id"))
+
val executorId = conf.getOption("spark.executor.id")
val defaultName = MetricRegistry.name(source.sourceName)
if (instance == "driver" || instance == "executor") {
- if (appId.isDefined && executorId.isDefined) {
- MetricRegistry.name(appId.get, executorId.get, source.sourceName)
+ if (metricsNamespace.isDefined && executorId.isDefined) {
+ MetricRegistry.name(metricsNamespace.get, executorId.get, source.sourceName)
} else {
// Only Driver and Executor set spark.app.id and spark.executor.id.
// Other instance types, e.g. Master and Worker, are not related to a specific application.
- val warningMsg = s"Using default name $defaultName for source because %s is not set."
- if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) }
- if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) }
+ if (metricsNamespace.isEmpty) {
+ logWarning(s"Using default name $defaultName for source because neither " +
+ s"${METRICS_NAMESPACE.key} nor spark.app.id is set.")
+ }
+ if (executorId.isEmpty) {
+ logWarning(s"Using default name $defaultName for source because spark.executor.id is " +
+ s"not set.")
+ }
defaultName
}
} else { defaultName }
@@ -196,10 +204,9 @@ private[spark] class MetricsSystem private (
sinks += sink.asInstanceOf[Sink]
}
} catch {
- case e: Exception => {
+ case e: Exception =>
logError("Sink class " + classPath + " cannot be instantiated")
throw e
- }
}
}
}
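
As a rough illustration of the naming scheme that buildRegistryName implements for driver and executor instances: the helper and sample values below are made up, and only MetricRegistry.name is the real Dropwizard API; the namespace-then-app-id fallback mirrors the logic above:

```scala
import com.codahale.metrics.MetricRegistry

object RegistryNameDemo {
  // Mirrors the documented behaviour: prefer an explicit metrics namespace,
  // fall back to the application id, and only qualify names for driver/executor.
  def registryName(
      instance: String,
      namespace: Option[String],
      appId: Option[String],
      executorId: Option[String],
      sourceName: String): String = {
    val prefix = namespace.orElse(appId)
    if ((instance == "driver" || instance == "executor") &&
        prefix.isDefined && executorId.isDefined) {
      MetricRegistry.name(prefix.get, executorId.get, sourceName)
    } else {
      MetricRegistry.name(sourceName) // default name when ids are missing
    }
  }

  def main(args: Array[String]): Unit = {
    // e.g. "app-20230101-0001.1.CodeGenerator"
    println(registryName("executor", None, Some("app-20230101-0001"), Some("1"), "CodeGenerator"))
  }
}
```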
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
index 81b9056b40fb..fce556fd0382 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
@@ -17,7 +17,7 @@
package org.apache.spark.metrics.sink
-import java.util.Properties
+import java.util.{Locale, Properties}
import java.util.concurrent.TimeUnit
import com.codahale.metrics.{ConsoleReporter, MetricRegistry}
@@ -39,7 +39,7 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR
}
val pollUnit: TimeUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match {
- case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT))
case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT)
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
index 9d5f2ae9328a..88bba2fdbd1c 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
@@ -42,7 +42,7 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis
}
val pollUnit: TimeUnit = Option(property.getProperty(CSV_KEY_UNIT)) match {
- case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT))
case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT)
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
index 2d25ebd66159..23e31823f493 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -18,11 +18,11 @@
package org.apache.spark.metrics.sink
import java.net.InetSocketAddress
-import java.util.Properties
+import java.util.{Locale, Properties}
import java.util.concurrent.TimeUnit
import com.codahale.metrics.MetricRegistry
-import com.codahale.metrics.graphite.{GraphiteUDP, Graphite, GraphiteReporter}
+import com.codahale.metrics.graphite.{Graphite, GraphiteReporter, GraphiteUDP}
import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
@@ -59,7 +59,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric
}
val pollUnit: TimeUnit = propertyToOption(GRAPHITE_KEY_UNIT) match {
- case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT))
case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT)
}
@@ -67,7 +67,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric
MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
- val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match {
+ val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase(Locale.ROOT)) match {
case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port))
case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port))
case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p")
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
index 2588fe2c9edb..1992b42ac7f6 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
@@ -20,6 +20,7 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import com.codahale.metrics.{JmxReporter, MetricRegistry}
+
import org.apache.spark.SecurityManager
private[spark] class JmxSink(val property: Properties, val registry: MetricRegistry,
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
index 4193e1d21d3c..68b58b849064 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
@@ -19,7 +19,6 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import java.util.concurrent.TimeUnit
-
import javax.servlet.http.HttpServletRequest
import com.codahale.metrics.MetricRegistry
@@ -27,7 +26,7 @@ import com.codahale.metrics.json.MetricsModule
import com.fasterxml.jackson.databind.ObjectMapper
import org.eclipse.jetty.servlet.ServletContextHandler
-import org.apache.spark.{SparkConf, SecurityManager}
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.ui.JettyUtils._
private[spark] class MetricsServlet(
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala
index 11dfcfe2f04e..7fa4ba762298 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala
@@ -17,10 +17,10 @@
package org.apache.spark.metrics.sink
-import java.util.Properties
+import java.util.{Locale, Properties}
import java.util.concurrent.TimeUnit
-import com.codahale.metrics.{Slf4jReporter, MetricRegistry}
+import com.codahale.metrics.{MetricRegistry, Slf4jReporter}
import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
@@ -42,7 +42,7 @@ private[spark] class Slf4jSink(
}
val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match {
- case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT))
case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT)
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala
new file mode 100644
index 000000000000..99ec78633ab7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.source
+
+import com.codahale.metrics.MetricRegistry
+
+import org.apache.spark.annotation.Experimental
+
+private[spark] object StaticSources {
+ /**
+ * The set of all static sources. These sources may be reported to from any class, including
+ * static classes, without requiring a reference to a SparkEnv.
+ */
+ val allSources = Seq(CodegenMetrics, HiveCatalogMetrics)
+}
+
+/**
+ * :: Experimental ::
+ * Metrics for code generation.
+ */
+@Experimental
+object CodegenMetrics extends Source {
+ override val sourceName: String = "CodeGenerator"
+ override val metricRegistry: MetricRegistry = new MetricRegistry()
+
+ /**
+ * Histogram of the length of source code text compiled by CodeGenerator (in characters).
+ */
+ val METRIC_SOURCE_CODE_SIZE = metricRegistry.histogram(MetricRegistry.name("sourceCodeSize"))
+
+ /**
+ * Histogram of the time it took to compile source code text (in milliseconds).
+ */
+ val METRIC_COMPILATION_TIME = metricRegistry.histogram(MetricRegistry.name("compilationTime"))
+
+ /**
+ * Histogram of the bytecode size of each class generated by CodeGenerator.
+ */
+ val METRIC_GENERATED_CLASS_BYTECODE_SIZE =
+ metricRegistry.histogram(MetricRegistry.name("generatedClassSize"))
+
+ /**
+ * Histogram of the bytecode size of each method in classes generated by CodeGenerator.
+ */
+ val METRIC_GENERATED_METHOD_BYTECODE_SIZE =
+ metricRegistry.histogram(MetricRegistry.name("generatedMethodSize"))
+}
+
+/**
+ * :: Experimental ::
+ * Metrics for access to the hive external catalog.
+ */
+@Experimental
+object HiveCatalogMetrics extends Source {
+ override val sourceName: String = "HiveExternalCatalog"
+ override val metricRegistry: MetricRegistry = new MetricRegistry()
+
+ /**
+ * Tracks the total number of partition metadata entries fetched via the client api.
+ */
+ val METRIC_PARTITIONS_FETCHED = metricRegistry.counter(MetricRegistry.name("partitionsFetched"))
+
+ /**
+ * Tracks the total number of files discovered off of the filesystem by InMemoryFileIndex.
+ */
+ val METRIC_FILES_DISCOVERED = metricRegistry.counter(MetricRegistry.name("filesDiscovered"))
+
+ /**
+ * Tracks the total number of files served from the file status cache instead of discovered.
+ */
+ val METRIC_FILE_CACHE_HITS = metricRegistry.counter(MetricRegistry.name("fileCacheHits"))
+
+ /**
+ * Tracks the total number of Hive client calls (e.g. to lookup a table).
+ */
+ val METRIC_HIVE_CLIENT_CALLS = metricRegistry.counter(MetricRegistry.name("hiveClientCalls"))
+
+ /**
+ * Tracks the total number of Spark jobs launched for parallel file listing.
+ */
+ val METRIC_PARALLEL_LISTING_JOB_COUNT = metricRegistry.counter(
+ MetricRegistry.name("parallelListingJobCount"))
+
+ /**
+ * Resets the values of all metrics to zero. This is useful in tests.
+ */
+ def reset(): Unit = {
+ METRIC_PARTITIONS_FETCHED.dec(METRIC_PARTITIONS_FETCHED.getCount())
+ METRIC_FILES_DISCOVERED.dec(METRIC_FILES_DISCOVERED.getCount())
+ METRIC_FILE_CACHE_HITS.dec(METRIC_FILE_CACHE_HITS.getCount())
+ METRIC_HIVE_CLIENT_CALLS.dec(METRIC_HIVE_CLIENT_CALLS.getCount())
+ METRIC_PARALLEL_LISTING_JOB_COUNT.dec(METRIC_PARALLEL_LISTING_JOB_COUNT.getCount())
+ }
+
+ // clients can use these to avoid classloader issues with the codahale classes
+ def incrementFetchedPartitions(n: Int): Unit = METRIC_PARTITIONS_FETCHED.inc(n)
+ def incrementFilesDiscovered(n: Int): Unit = METRIC_FILES_DISCOVERED.inc(n)
+ def incrementFileCacheHits(n: Int): Unit = METRIC_FILE_CACHE_HITS.inc(n)
+ def incrementHiveClientCalls(n: Int): Unit = METRIC_HIVE_CLIENT_CALLS.inc(n)
+ def incrementParallelListingJobCount(n: Int): Unit = METRIC_PARALLEL_LISTING_JOB_COUNT.inc(n)
+}
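
A short usage sketch for the static sources above, assuming the Spark core classes are on the classpath. It exercises the increment helpers and reset() in the way a test might:

```scala
import org.apache.spark.metrics.source.HiveCatalogMetrics

object StaticSourceDemo {
  def main(args: Array[String]): Unit = {
    HiveCatalogMetrics.reset()
    // Simulate a catalog interaction: one client call that fetched two partitions.
    HiveCatalogMetrics.incrementHiveClientCalls(1)
    HiveCatalogMetrics.incrementFetchedPartitions(2)

    assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() == 1)
    assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 2)

    // reset() zeroes all counters, which is handy between test cases.
    HiveCatalogMetrics.reset()
    assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount() == 0)
  }
}
```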
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index 1745d52c8192..b3f8bfe8b1d4 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -17,6 +17,8 @@
package org.apache.spark.network
+import scala.reflect.ClassTag
+
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.{BlockId, StorageLevel}
@@ -31,6 +33,18 @@ trait BlockDataManager {
/**
* Put the block locally, using the given storage level.
+ *
+ * Returns true if the block was stored and false if the put operation failed or the block
+ * already existed.
+ */
+ def putBlockData(
+ blockId: BlockId,
+ data: ManagedBuffer,
+ level: StorageLevel,
+ classTag: ClassTag[_]): Boolean
+
+ /**
+ * Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
*/
- def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit
+ def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit
}
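
Purely as an illustration of the documented contract (store a block, then release the lock the operation may have acquired), a hypothetical caller could look like the following; the helper name and the choice to always release are assumptions, not Spark code:

```scala
import scala.reflect.ClassTag

import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.{BlockId, StorageLevel}

// Hypothetical helper showing the intended call pattern per the documented contract:
// putBlockData returns whether the block was stored, and releaseLock frees the lock
// associated with the put.
object PutBlockExample {
  def storeAndRelease(
      manager: BlockDataManager,
      blockId: BlockId,
      data: ManagedBuffer,
      level: StorageLevel): Boolean = {
    val stored = manager.putBlockData(blockId, data, level, ClassTag.Any)
    // Release any lock held on this block for the current (non-task) caller.
    manager.releaseLock(blockId, taskAttemptId = None)
    stored
  }
}
```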
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index dcbda5a8515d..fe5fd2da039b 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -20,13 +20,15 @@ package org.apache.spark.network
import java.io.Closeable
import java.nio.ByteBuffer
-import scala.concurrent.{Promise, Await, Future}
+import scala.concurrent.{Future, Promise}
import scala.concurrent.duration.Duration
+import scala.reflect.ClassTag
-import org.apache.spark.Logging
-import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener}
-import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel}
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.util.ThreadUtils
private[spark]
abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {
@@ -35,7 +37,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
* local blocks or put local blocks.
*/
- def init(blockDataManager: BlockDataManager)
+ def init(blockDataManager: BlockDataManager): Unit
/**
* Tear down the transfer service.
@@ -65,7 +67,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
port: Int,
execId: String,
blockIds: Array[String],
- listener: BlockFetchingListener): Unit
+ listener: BlockFetchingListener,
+ tempShuffleFileManager: TempShuffleFileManager): Unit
/**
* Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -76,7 +79,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
- level: StorageLevel): Future[Unit]
+ level: StorageLevel,
+ classTag: ClassTag[_]): Future[Unit]
/**
* A special case of [[fetchBlocks]], as it fetches only one block and is blocking.
@@ -97,9 +101,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
ret.flip()
result.success(new NioManagedBuffer(ret))
}
- })
-
- Await.result(result.future, Duration.Inf)
+ }, tempShuffleFileManager = null)
+ ThreadUtils.awaitResult(result.future, Duration.Inf)
}
/**
@@ -114,7 +117,9 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
- level: StorageLevel): Unit = {
- Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
+ level: StorageLevel,
+ classTag: ClassTag[_]): Unit = {
+ val future = uploadBlock(hostname, port, execId, blockId, blockData, level, classTag)
+ ThreadUtils.awaitResult(future, Duration.Inf)
}
}
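
The fetchBlockSync change above follows a common pattern: adapt a callback-based API into a blocking call by completing a Promise from the callbacks and awaiting its Future. A generic, self-contained sketch of that pattern (not Spark's code; ThreadUtils.awaitResult is replaced here by a plain Await.result):

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

object CallbackToBlocking {
  // A callback-style API: delivers either a value or an error, asynchronously.
  def fetchAsync(id: String)(onSuccess: String => Unit, onFailure: Throwable => Unit): Unit = {
    val pool = Executors.newSingleThreadExecutor()
    pool.submit(new Runnable {
      override def run(): Unit = {
        if (id.nonEmpty) onSuccess(s"data-for-$id")
        else onFailure(new IllegalArgumentException("empty id"))
        pool.shutdown()
      }
    })
  }

  // Blocking adapter: complete a Promise from the callbacks and await its Future.
  def fetchSync(id: String): String = {
    val result = Promise[String]()
    fetchAsync(id)(
      onSuccess = data => result.success(data),
      onFailure = t => result.failure(t))
    Await.result(result.future, Duration.Inf)
  }

  def main(args: Array[String]): Unit = {
    println(fetchSync("block_0")) // prints "data-for-block_0"
  }
}
```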
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 76968249fb62..305fd9a6de10 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -20,8 +20,10 @@ package org.apache.spark.network.netty
import java.nio.ByteBuffer
import scala.collection.JavaConverters._
+import scala.language.existentials
+import scala.reflect.ClassTag
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
@@ -47,26 +49,32 @@ class NettyBlockRpcServer(
override def receive(
client: TransportClient,
- messageBytes: Array[Byte],
+ rpcMessage: ByteBuffer,
responseContext: RpcResponseCallback): Unit = {
- val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
+ val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
logTrace(s"Received request: $message")
message match {
case openBlocks: OpenBlocks =>
- val blocks: Seq[ManagedBuffer] =
- openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
+ val blocksNum = openBlocks.blockIds.length
+ val blocks = for (i <- (0 until blocksNum).view)
+ yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
- logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
- responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
+ logTrace(s"Registered streamId $streamId with $blocksNum buffers")
+ responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
case uploadBlock: UploadBlock =>
- // StorageLevel is serialized as bytes using our JavaSerializer.
- val level: StorageLevel =
- serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+ // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
+ val (level: StorageLevel, classTag: ClassTag[_]) = {
+ serializer
+ .newInstance()
+ .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+ .asInstanceOf[(StorageLevel, ClassTag[_])]
+ }
val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
- blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
- responseContext.onSuccess(new Array[Byte](0))
+ val blockId = BlockId(uploadBlock.blockId)
+ blockManager.putBlockData(blockId, data, level, classTag)
+ responseContext.onSuccess(ByteBuffer.allocate(0))
}
}
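
A minimal round-trip of the (StorageLevel, ClassTag) metadata encoding that the upload path now uses, based on the same JavaSerializer the RPC server deserializes with. This is a sketch with arbitrary values; default SparkConf settings are assumed:

```scala
import java.nio.ByteBuffer
import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.StorageLevel

object UploadMetadataDemo {
  def main(args: Array[String]): Unit = {
    val serializer = new JavaSerializer(new SparkConf())

    // Client side: serialize the storage level together with the element ClassTag.
    val metadata: ByteBuffer =
      serializer.newInstance().serialize((StorageLevel.MEMORY_ONLY, ClassTag(classOf[String])))

    // Server side: recover both values from the raw bytes carried in UploadBlock.metadata.
    val (level: StorageLevel, classTag: ClassTag[_]) =
      serializer.newInstance()
        .deserialize(metadata)
        .asInstanceOf[(StorageLevel, ClassTag[_])]

    println(s"level=$level, classTag=$classTag")
  }
}
```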
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 70a42f9045e6..30ff93897f98 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,31 +17,41 @@
package org.apache.spark.network.netty
+import java.nio.ByteBuffer
+
import scala.collection.JavaConverters._
import scala.concurrent.{Future, Promise}
+import scala.reflect.ClassTag
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
-import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
+import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager}
import org.apache.spark.network.shuffle.protocol.UploadBlock
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
/**
- * A BlockTransferService that uses Netty to fetch a set of blocks at at time.
+ * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
*/
-class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager, numCores: Int)
+private[spark] class NettyBlockTransferService(
+ conf: SparkConf,
+ securityManager: SecurityManager,
+ bindAddress: String,
+ override val hostName: String,
+ _port: Int,
+ numCores: Int)
extends BlockTransferService {
// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
private val serializer = new JavaSerializer(conf)
private val authEnabled = securityManager.isAuthenticationEnabled()
- private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores)
+ private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores)
private[this] var transportContext: TransportContext = _
private[this] var server: TransportServer = _
@@ -53,26 +63,24 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
var serverBootstrap: Option[TransportServerBootstrap] = None
var clientBootstrap: Option[TransportClientBootstrap] = None
if (authEnabled) {
- serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager))
- clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager,
- securityManager.isSaslEncryptionEnabled()))
+ serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager))
+ clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager))
}
transportContext = new TransportContext(transportConf, rpcHandler)
clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava)
server = createServer(serverBootstrap.toList)
appId = conf.getAppId
- logInfo("Server created on " + server.getPort)
+ logInfo(s"Server created on ${hostName}:${server.getPort}")
}
/** Creates and binds the TransportServer, possibly trying multiple ports. */
private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = {
def startService(port: Int): (TransportServer, Int) = {
- val server = transportContext.createServer(port, bootstraps.asJava)
+ val server = transportContext.createServer(bindAddress, port, bootstraps.asJava)
(server, server.getPort)
}
- val portToTry = conf.getInt("spark.blockManager.port", 0)
- Utils.startServiceOnPort(portToTry, startService, conf, getClass.getName)._1
+ Utils.startServiceOnPort(_port, startService, conf, getClass.getName)._1
}
override def fetchBlocks(
@@ -80,13 +88,15 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
port: Int,
execId: String,
blockIds: Array[String],
- listener: BlockFetchingListener): Unit = {
+ listener: BlockFetchingListener,
+ tempShuffleFileManager: TempShuffleFileManager): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
val client = clientFactory.createClient(host, port)
- new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
+ new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
+ transportConf, tempShuffleFileManager).start()
}
}
@@ -105,8 +115,6 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
}
}
- override def hostName: String = Utils.localHostName()
-
override def port: Int = server.getPort
override def uploadBlock(
@@ -115,27 +123,21 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
- level: StorageLevel): Future[Unit] = {
+ level: StorageLevel,
+ classTag: ClassTag[_]): Future[Unit] = {
val result = Promise[Unit]()
val client = clientFactory.createClient(hostname, port)
- // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
- // using our binary protocol.
- val levelBytes = serializer.newInstance().serialize(level).array()
+ // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
+ // Everything else is encoded using our binary protocol.
+ val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag)))
// Convert or copy nio buffer into array in order to serialize it.
- val nioBuffer = blockData.nioByteBuffer()
- val array = if (nioBuffer.hasArray) {
- nioBuffer.array()
- } else {
- val data = new Array[Byte](nioBuffer.remaining())
- nioBuffer.get(data)
- data
- }
+ val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())
- client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
+ client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array).toByteBuffer,
new RpcResponseCallback {
- override def onSuccess(response: Array[Byte]): Unit = {
+ override def onSuccess(response: ByteBuffer): Unit = {
logTrace(s"Successfully uploaded block $blockId")
result.success((): Unit)
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
index cef203006d68..25f7bcb9801b 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
@@ -17,8 +17,10 @@
package org.apache.spark.network.netty
+import scala.collection.JavaConverters._
+
import org.apache.spark.SparkConf
-import org.apache.spark.network.util.{TransportConf, ConfigProvider}
+import org.apache.spark.network.util.{ConfigProvider, TransportConf}
/**
* Provides a utility for transforming from a SparkConf inside a Spark JVM (e.g., Executor,
@@ -40,24 +42,28 @@ object SparkTransportConf {
/**
* Utility for creating a [[TransportConf]] from a [[SparkConf]].
+ * @param _conf the [[SparkConf]]
+ * @param module the module name
* @param numUsableCores if nonzero, this will restrict the server and client threads to only
* use the given number of cores, rather than all of the machine's cores.
* This restriction will only occur if these properties are not already set.
*/
- def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = {
+ def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = {
val conf = _conf.clone
// Specify thread configuration based on our JVM's allocation of cores (rather than necessarily
// assuming we have all the machine's cores).
// NB: Only set if serverThreads/clientThreads not already set.
val numThreads = defaultNumThreads(numUsableCores)
- conf.set("spark.shuffle.io.serverThreads",
- conf.get("spark.shuffle.io.serverThreads", numThreads.toString))
- conf.set("spark.shuffle.io.clientThreads",
- conf.get("spark.shuffle.io.clientThreads", numThreads.toString))
+ conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString)
+ conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString)
- new TransportConf(new ConfigProvider {
+ new TransportConf(module, new ConfigProvider {
override def get(name: String): String = conf.get(name)
+ override def get(name: String, defaultValue: String): String = conf.get(name, defaultValue)
+ override def getAll(): java.lang.Iterable[java.util.Map.Entry[String, String]] = {
+ conf.getAll.toMap.asJava.entrySet()
+ }
})
}
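
A brief usage sketch of the module-scoped factory. This assumes fromSparkConf is accessible from the calling code (it may be package-private, e.g. usable from test code under org.apache.spark) and that TransportConf exposes serverThreads()/clientThreads() accessors:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.network.netty.SparkTransportConf

object TransportConfDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      .set("spark.shuffle.io.serverThreads", "4") // explicit setting wins over the derived default

    // "shuffle" scopes the io settings under spark.shuffle.io.*;
    // numUsableCores caps the thread defaults derived for this JVM.
    val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2)
    println(transportConf.serverThreads()) // 4, from the explicit setting
    println(transportConf.clientThreads()) // defaulted from numUsableCores
  }
}
```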
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index 7515aad09db7..2610d6f6e45a 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -41,7 +41,58 @@ package org.apache
* level interfaces. These are subject to changes or removal in minor releases.
*/
+import java.util.Properties
+
package object spark {
- // For package docs only
- val SPARK_VERSION = "1.6.0-SNAPSHOT"
+
+ private object SparkBuildInfo {
+
+ val (
+ spark_version: String,
+ spark_branch: String,
+ spark_revision: String,
+ spark_build_user: String,
+ spark_repo_url: String,
+ spark_build_date: String) = {
+
+ val resourceStream = Thread.currentThread().getContextClassLoader.
+ getResourceAsStream("spark-version-info.properties")
+
+ try {
+ val unknownProp = ""
+ val props = new Properties()
+ props.load(resourceStream)
+ (
+ props.getProperty("version", unknownProp),
+ props.getProperty("branch", unknownProp),
+ props.getProperty("revision", unknownProp),
+ props.getProperty("user", unknownProp),
+ props.getProperty("url", unknownProp),
+ props.getProperty("date", unknownProp)
+ )
+ } catch {
+ case npe: NullPointerException =>
+ throw new SparkException("Error while locating file spark-version-info.properties", npe)
+ case e: Exception =>
+ throw new SparkException("Error loading properties from spark-version-info.properties", e)
+ } finally {
+ if (resourceStream != null) {
+ try {
+ resourceStream.close()
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Error closing spark build info resource stream", e)
+ }
+ }
+ }
+ }
+ }
+
+ val SPARK_VERSION = SparkBuildInfo.spark_version
+ val SPARK_BRANCH = SparkBuildInfo.spark_branch
+ val SPARK_REVISION = SparkBuildInfo.spark_revision
+ val SPARK_BUILD_USER = SparkBuildInfo.spark_build_user
+ val SPARK_REPO_URL = SparkBuildInfo.spark_repo_url
+ val SPARK_BUILD_DATE = SparkBuildInfo.spark_build_date
}
+
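
Since the build information is now exposed as package-level constants, callers can read it directly. A trivial example; the output depends entirely on the spark-version-info.properties bundled with the build, with empty strings for any missing keys:

```scala
import org.apache.spark.{SPARK_BUILD_DATE, SPARK_REVISION, SPARK_VERSION}

object BuildInfoDemo {
  def main(args: Array[String]): Unit = {
    // Values come from spark-version-info.properties on the classpath.
    println(s"Spark $SPARK_VERSION (rev $SPARK_REVISION, built $SPARK_BUILD_DATE)")
  }
}
```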
diff --git a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala
index d25452daf760..b089bbd7e972 100644
--- a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala
+++ b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala
@@ -38,7 +38,7 @@ private[spark] class ApproximateActionListener[T, U, R](
extends JobListener {
val startTime = System.currentTimeMillis()
- val totalTasks = rdd.partitions.size
+ val totalTasks = rdd.partitions.length
var finishedTasks = 0
var failure: Option[Exception] = None // Set if the job has failed (permanently)
var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult
diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
index 48b943415317..8f579c5a3033 100644
--- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
+++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
@@ -21,5 +21,22 @@ package org.apache.spark.partial
* A Double value with error bars and associated confidence.
*/
class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
+
override def toString(): String = "[%.3f, %.3f]".format(low, high)
+
+ override def hashCode: Int =
+ this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode
+
+ /**
+ * @note Consistent with Double, any NaN value will make equality false
+ */
+ override def equals(that: Any): Boolean =
+ that match {
+ case that: BoundedDouble =>
+ this.mean == that.mean &&
+ this.confidence == that.confidence &&
+ this.low == that.low &&
+ this.high == that.high
+ case _ => false
+ }
}
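
A quick sketch of the equality semantics documented above, including the NaN caveat; the numeric values are made up:

```scala
import org.apache.spark.partial.BoundedDouble

object BoundedDoubleDemo {
  def main(args: Array[String]): Unit = {
    val a = new BoundedDouble(10.0, 0.95, 8.0, 12.0)
    val b = new BoundedDouble(10.0, 0.95, 8.0, 12.0)
    assert(a == b)                    // field-wise equality
    assert(a.hashCode == b.hashCode)  // consistent with equals

    // As with Double itself, NaN is never equal to anything, including another NaN.
    val nan = new BoundedDouble(Double.NaN, 0.95, 8.0, 12.0)
    assert(nan != new BoundedDouble(Double.NaN, 0.95, 8.0, 12.0))
  }
}
```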
diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
index 637492a97551..5a5bd7fbbe2f 100644
--- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
@@ -17,21 +17,18 @@
package org.apache.spark.partial
-import org.apache.commons.math3.distribution.NormalDistribution
+import org.apache.commons.math3.distribution.{PascalDistribution, PoissonDistribution}
/**
* An ApproximateEvaluator for counts.
- *
- * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might
- * be best to make this a special case of GroupedCountEvaluator with one group.
*/
private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[Long, BoundedDouble] {
- var outputsMerged = 0
- var sum: Long = 0
+ private var outputsMerged = 0
+ private var sum: Long = 0
- override def merge(outputId: Int, taskResult: Long) {
+ override def merge(outputId: Int, taskResult: Long): Unit = {
outputsMerged += 1
sum += taskResult
}
@@ -39,18 +36,40 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
override def currentResult(): BoundedDouble = {
if (outputsMerged == totalOutputs) {
new BoundedDouble(sum, 1.0, sum, sum)
- } else if (outputsMerged == 0) {
- new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else if (outputsMerged == 0 || sum == 0) {
+ new BoundedDouble(0, 0.0, 0.0, Double.PositiveInfinity)
} else {
val p = outputsMerged.toDouble / totalOutputs
- val mean = (sum + 1 - p) / p
- val variance = (sum + 1) * (1 - p) / (p * p)
- val stdev = math.sqrt(variance)
- val confFactor = new NormalDistribution().
- inverseCumulativeProbability(1 - (1 - confidence) / 2)
- val low = mean - confFactor * stdev
- val high = mean + confFactor * stdev
- new BoundedDouble(mean, confidence, low, high)
+ CountEvaluator.bound(confidence, sum, p)
}
}
}
+
+private[partial] object CountEvaluator {
+
+ def bound(confidence: Double, sum: Long, p: Double): BoundedDouble = {
+ // Let the total count be N. A fraction p has been counted already, with sum 'sum',
+ // as if each element from the total data set had been seen with probability p.
+ val dist =
+ if (sum <= 10000) {
+ // The remaining count, k=N-sum, may be modeled as negative binomial (aka Pascal),
+ // where there have been 'sum' successes of probability p already. (There are several
+ // conventions, but this is the one followed by Commons Math3.)
+ new PascalDistribution(sum.toInt, p)
+ } else {
+ // For large 'sum' (certainly, > Int.MaxValue!), use a Poisson approximation, which has
+ // a different interpretation. "sum" elements have been observed having scanned a fraction
+ // p of the data. This suggests data is counted at a rate of sum / p across the whole data
+ // set. The total expected count from the rest is distributed as
+ // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p)
+ new PoissonDistribution(sum * (1 - p) / p)
+ }
+ // Not quite symmetric; calculate interval straight from discrete distribution
+ val low = dist.inverseCumulativeProbability((1 - confidence) / 2)
+ val high = dist.inverseCumulativeProbability((1 + confidence) / 2)
+ // Add 'sum' to each because distribution is just of remaining count, not observed
+ new BoundedDouble(sum + dist.getNumericalMean, confidence, sum + low, sum + high)
+ }
+
+
+}
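
To see the negative-binomial reasoning in action, here is a standalone computation with Commons Math3 that mirrors the small-sum branch of CountEvaluator.bound; the sum, fraction, and confidence values are illustrative only:

```scala
import org.apache.commons.math3.distribution.PascalDistribution

object CountBoundDemo {
  def main(args: Array[String]): Unit = {
    val confidence = 0.95
    val sum = 1000L  // elements counted so far
    val p = 0.10     // fraction of the output merged so far

    // Remaining count k = N - sum, modeled as negative binomial with 'sum' successes of prob p.
    val dist = new PascalDistribution(sum.toInt, p)
    val low = dist.inverseCumulativeProbability((1 - confidence) / 2)
    val high = dist.inverseCumulativeProbability((1 + confidence) / 2)

    // Shift by 'sum' because the distribution describes only the unobserved remainder.
    println(s"mean = ${sum + dist.getNumericalMean}")
    println(s"interval = [${sum + low}, ${sum + high}]")
  }
}
```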
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
index 5afce75680f9..d2b4187df5d5 100644
--- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
@@ -17,15 +17,10 @@
package org.apache.spark.partial
-import java.util.{HashMap => JHashMap}
-
-import scala.collection.JavaConverters._
import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
-import org.apache.commons.math3.distribution.NormalDistribution
-
import org.apache.spark.util.collection.OpenHashMap
/**
@@ -34,10 +29,10 @@ import org.apache.spark.util.collection.OpenHashMap
private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] {
- var outputsMerged = 0
- var sums = new OpenHashMap[T, Long]() // Sum of counts for each key
+ private var outputsMerged = 0
+ private val sums = new OpenHashMap[T, Long]() // Sum of counts for each key
- override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) {
+ override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]): Unit = {
outputsMerged += 1
taskResult.foreach { case (key, value) =>
sums.changeValue(key, value, _ + value)
@@ -46,27 +41,12 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf
override def currentResult(): Map[T, BoundedDouble] = {
if (outputsMerged == totalOutputs) {
- val result = new JHashMap[T, BoundedDouble](sums.size)
- sums.foreach { case (key, sum) =>
- result.put(key, new BoundedDouble(sum, 1.0, sum, sum))
- }
- result.asScala
+ sums.map { case (key, sum) => (key, new BoundedDouble(sum, 1.0, sum, sum)) }.toMap
} else if (outputsMerged == 0) {
new HashMap[T, BoundedDouble]
} else {
val p = outputsMerged.toDouble / totalOutputs
- val confFactor = new NormalDistribution().
- inverseCumulativeProbability(1 - (1 - confidence) / 2)
- val result = new JHashMap[T, BoundedDouble](sums.size)
- sums.foreach { case (key, sum) =>
- val mean = (sum + 1 - p) / p
- val variance = (sum + 1) * (1 - p) / (p * p)
- val stdev = math.sqrt(variance)
- val low = mean - confFactor * stdev
- val high = mean + confFactor * stdev
- result.put(key, new BoundedDouble(mean, confidence, low, high))
- }
- result.asScala
+ sums.map { case (key, sum) => (key, CountEvaluator.bound(confidence, sum, p)) }.toMap
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala
deleted file mode 100644
index a16404068480..000000000000
--- a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.partial
-
-import java.util.{HashMap => JHashMap}
-
-import scala.collection.JavaConverters._
-import scala.collection.Map
-import scala.collection.mutable.HashMap
-
-import org.apache.spark.util.StatCounter
-
-/**
- * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval.
- */
-private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
- extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
-
- var outputsMerged = 0
- var sums = new JHashMap[T, StatCounter] // Sum of counts for each key
-
- override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
- outputsMerged += 1
- val iter = taskResult.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val old = sums.get(entry.getKey)
- if (old != null) {
- old.merge(entry.getValue)
- } else {
- sums.put(entry.getKey, entry.getValue)
- }
- }
- }
-
- override def currentResult(): Map[T, BoundedDouble] = {
- if (outputsMerged == totalOutputs) {
- val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val mean = entry.getValue.mean
- result.put(entry.getKey, new BoundedDouble(mean, 1.0, mean, mean))
- }
- result.asScala
- } else if (outputsMerged == 0) {
- new HashMap[T, BoundedDouble]
- } else {
- val studentTCacher = new StudentTCacher(confidence)
- val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val counter = entry.getValue
- val mean = counter.mean
- val stdev = math.sqrt(counter.sampleVariance / counter.count)
- val confFactor = studentTCacher.get(counter.count)
- val low = mean - confFactor * stdev
- val high = mean + confFactor * stdev
- result.put(entry.getKey, new BoundedDouble(mean, confidence, low, high))
- }
- result.asScala
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala
deleted file mode 100644
index 54a1beab3514..000000000000
--- a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.partial
-
-import java.util.{HashMap => JHashMap}
-
-import scala.collection.JavaConverters._
-import scala.collection.Map
-import scala.collection.mutable.HashMap
-
-import org.apache.spark.util.StatCounter
-
-/**
- * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
- */
-private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
- extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
-
- var outputsMerged = 0
- var sums = new JHashMap[T, StatCounter] // Sum of counts for each key
-
- override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
- outputsMerged += 1
- val iter = taskResult.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val old = sums.get(entry.getKey)
- if (old != null) {
- old.merge(entry.getValue)
- } else {
- sums.put(entry.getKey, entry.getValue)
- }
- }
- }
-
- override def currentResult(): Map[T, BoundedDouble] = {
- if (outputsMerged == totalOutputs) {
- val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val sum = entry.getValue.sum
- result.put(entry.getKey, new BoundedDouble(sum, 1.0, sum, sum))
- }
- result.asScala
- } else if (outputsMerged == 0) {
- new HashMap[T, BoundedDouble]
- } else {
- val p = outputsMerged.toDouble / totalOutputs
- val studentTCacher = new StudentTCacher(confidence)
- val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.entrySet.iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val counter = entry.getValue
- val meanEstimate = counter.mean
- val meanVar = counter.sampleVariance / counter.count
- val countEstimate = (counter.count + 1 - p) / p
- val countVar = (counter.count + 1) * (1 - p) / (p * p)
- val sumEstimate = meanEstimate * countEstimate
- val sumVar = (meanEstimate * meanEstimate * countVar) +
- (countEstimate * countEstimate * meanVar) +
- (meanVar * countVar)
- val sumStdev = math.sqrt(sumVar)
- val confFactor = studentTCacher.get(counter.count)
- val low = sumEstimate - confFactor * sumStdev
- val high = sumEstimate + confFactor * sumStdev
- result.put(entry.getKey, new BoundedDouble(sumEstimate, confidence, low, high))
- }
- result.asScala
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
index 787a21a61fdc..3fb2d30a800b 100644
--- a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
@@ -27,10 +27,10 @@ import org.apache.spark.util.StatCounter
private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
- var outputsMerged = 0
- var counter = new StatCounter
+ private var outputsMerged = 0
+ private val counter = new StatCounter()
- override def merge(outputId: Int, taskResult: StatCounter) {
+ override def merge(outputId: Int, taskResult: StatCounter): Unit = {
outputsMerged += 1
counter.merge(taskResult)
}
@@ -38,19 +38,24 @@ private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double)
override def currentResult(): BoundedDouble = {
if (outputsMerged == totalOutputs) {
new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean)
- } else if (outputsMerged == 0) {
+ } else if (outputsMerged == 0 || counter.count == 0) {
new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else if (counter.count == 1) {
+ new BoundedDouble(counter.mean, confidence, Double.NegativeInfinity, Double.PositiveInfinity)
} else {
val mean = counter.mean
val stdev = math.sqrt(counter.sampleVariance / counter.count)
- val confFactor = {
- if (counter.count > 100) {
- new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
+ val confFactor = if (counter.count > 100) {
+ // For large n, the normal distribution is a good approximation to t-distribution
+ new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2)
} else {
+ // t-distribution describes distribution of actual population mean
+ // note that if this goes to 0, TDistribution will throw an exception.
+ // Hence special casing 1 above.
val degreesOfFreedom = (counter.count - 1).toInt
- new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
+ new TDistribution(degreesOfFreedom).inverseCumulativeProbability((1 + confidence) / 2)
}
- }
+ // Symmetric, so confidence interval is symmetric about mean of distribution
val low = mean - confFactor * stdev
val high = mean + confFactor * stdev
new BoundedDouble(mean, confidence, low, high)
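
A short worked sketch of the interval computation above, choosing between the t and normal critical values exactly as described; the sample count, mean, and variance are made up:

```scala
import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}

object MeanIntervalDemo {
  def main(args: Array[String]): Unit = {
    val confidence = 0.95
    val count = 30L
    val mean = 5.0
    val sampleVariance = 4.0

    val stdev = math.sqrt(sampleVariance / count)
    val confFactor =
      if (count > 100) {
        // Large n: the normal distribution approximates the t-distribution well.
        new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2)
      } else {
        // Small n (but > 1): use the t-distribution with n - 1 degrees of freedom.
        new TDistribution((count - 1).toInt).inverseCumulativeProbability((1 + confidence) / 2)
      }

    // The interval is symmetric about the sample mean.
    println(s"mean = $mean, interval = [${mean - confFactor * stdev}, ${mean + confFactor * stdev}]")
  }
}
```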
diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
deleted file mode 100644
index 828bf96c2c0b..000000000000
--- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.partial
-
-import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution}
-
-/**
- * A utility class for caching Student's T distribution values for a given confidence level
- * and various sample sizes. This is used by the MeanEvaluator to efficiently calculate
- * confidence intervals for many keys.
- */
-private[spark] class StudentTCacher(confidence: Double) {
-
- val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
-
- val normalApprox = new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
- val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
-
- def get(sampleSize: Long): Double = {
- if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) {
- normalApprox
- } else {
- val size = sampleSize.toInt
- if (cache(size) < 0) {
- val tDist = new TDistribution(size - 1)
- cache(size) = tDist.inverseCumulativeProbability(1 - (1 - confidence) / 2)
- }
- cache(size)
- }
- }
-}
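The deleted StudentTCacher memoized one quantile per sample size for a fixed confidence level, falling back to the normal quantile for large samples. For reference, a minimal sketch of that memoization idea; `QuantileCache` is hypothetical, not the removed Spark class, and it also guards the degrees-of-freedom edge case the old code left open.

```scala
// Hypothetical re-creation of the caching idea behind the deleted StudentTCacher.
// Not Spark code; requires commons-math3.
import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}
import scala.collection.mutable

class QuantileCache(confidence: Double, normalApproxSampleSize: Int = 100) {
  private val normalApprox =
    new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2)
  private val cache = mutable.Map.empty[Int, Double]

  /** Quantile factor for a symmetric interval at the given sample size. */
  def get(sampleSize: Long): Double = {
    require(sampleSize >= 2, "t-distribution needs at least 2 samples (df >= 1)")
    if (sampleSize >= normalApproxSampleSize) {
      normalApprox
    } else {
      cache.getOrElseUpdate(sampleSize.toInt,
        new TDistribution((sampleSize - 1).toDouble)
          .inverseCumulativeProbability((1 + confidence) / 2))
    }
  }
}
```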
diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
index 1753c2561b67..1988052b733e 100644
--- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution}
+import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}
import org.apache.spark.util.StatCounter
@@ -29,10 +29,11 @@ import org.apache.spark.util.StatCounter
private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
- var outputsMerged = 0
- var counter = new StatCounter
+ // modified in merge
+ private var outputsMerged = 0
+ private val counter = new StatCounter()
- override def merge(outputId: Int, taskResult: StatCounter) {
+ override def merge(outputId: Int, taskResult: StatCounter): Unit = {
outputsMerged += 1
counter.merge(taskResult)
}
@@ -40,30 +41,50 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
override def currentResult(): BoundedDouble = {
if (outputsMerged == totalOutputs) {
new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
- } else if (outputsMerged == 0) {
+ } else if (outputsMerged == 0 || counter.count == 0) {
new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
} else {
val p = outputsMerged.toDouble / totalOutputs
+ // Expected value of unobserved is presumed equal to that of the observed data
val meanEstimate = counter.mean
- val meanVar = counter.sampleVariance / counter.count
- val countEstimate = (counter.count + 1 - p) / p
- val countVar = (counter.count + 1) * (1 - p) / (p * p)
+ // Expected size of rest of the data is proportional
+ val countEstimate = counter.count * (1 - p) / p
+ // Expected sum is simply their product
val sumEstimate = meanEstimate * countEstimate
- val sumVar = (meanEstimate * meanEstimate * countVar) +
- (countEstimate * countEstimate * meanVar) +
- (meanVar * countVar)
- val sumStdev = math.sqrt(sumVar)
- val confFactor = {
- if (counter.count > 100) {
- new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
+
+ // Variance of unobserved data is presumed equal to that of the observed data
+ val meanVar = counter.sampleVariance / counter.count
+
+ // branch at this point because count == 1 implies counter.sampleVariance == NaN
+ // and we don't want to ever return a bound of NaN
+ if (meanVar.isNaN || counter.count == 1) {
+ // add sum because estimate is of unobserved data sum
+ new BoundedDouble(
+ counter.sum + sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ // See CountEvaluator. Variance of population count here follows from negative binomial
+ val countVar = counter.count * (1 - p) / (p * p)
+ // Var(Sum) = Var(Mean*Count) =
+ // [E(Mean)]^2 * Var(Count) + [E(Count)]^2 * Var(Mean) + Var(Mean) * Var(Count)
+ val sumVar = (meanEstimate * meanEstimate * countVar) +
+ (countEstimate * countEstimate * meanVar) +
+ (meanVar * countVar)
+ val sumStdev = math.sqrt(sumVar)
+ val confFactor = if (counter.count > 100) {
+ new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2)
} else {
+ // note that if this goes to 0, TDistribution will throw an exception.
+ // Hence special casing 1 above.
val degreesOfFreedom = (counter.count - 1).toInt
- new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
+ new TDistribution(degreesOfFreedom).inverseCumulativeProbability((1 + confidence) / 2)
}
+ // Symmetric, so confidence interval is symmetric about mean of distribution
+ val low = sumEstimate - confFactor * sumStdev
+ val high = sumEstimate + confFactor * sumStdev
+ // add sum because estimate is of unobserved data sum
+ new BoundedDouble(
+ counter.sum + sumEstimate, confidence, counter.sum + low, counter.sum + high)
}
- val low = sumEstimate - confFactor * sumStdev
- val high = sumEstimate + confFactor * sumStdev
- new BoundedDouble(sumEstimate, confidence, low, high)
}
}
}
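The rewritten SumEvaluator estimates the unobserved count from a negative binomial model and combines the (assumed independent) mean and count estimates through the variance-of-a-product formula quoted in the comment above. A numeric sketch of that formula with made-up inputs:

```scala
// Numeric sketch of Var(Sum) = Var(Mean * Count) for independent Mean and Count,
// mirroring the SumEvaluator hunk above. All inputs are made-up illustration values.
object SumVarianceSketch {
  def main(args: Array[String]): Unit = {
    val p = 0.4                      // fraction of partitions observed
    val observedCount = 1000L        // counter.count
    val observedMean = 12.5          // counter.mean
    val sampleVariance = 9.0         // counter.sampleVariance

    val meanVar = sampleVariance / observedCount
    val countEstimate = observedCount * (1 - p) / p
    val countVar = observedCount * (1 - p) / (p * p)  // negative-binomial variance

    // Var(XY) = E[X]^2 Var(Y) + E[Y]^2 Var(X) + Var(X) Var(Y) for independent X, Y
    val sumVar = observedMean * observedMean * countVar +
      countEstimate * countEstimate * meanVar +
      meanVar * countVar
    println(s"estimated unobserved sum = ${observedMean * countEstimate}, " +
      s"stdev = ${math.sqrt(sumVar)}")
  }
}
```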
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index ca1eb1f4e4a9..c9ed12f4e1bd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -19,13 +19,13 @@ package org.apache.spark.rdd
import java.util.concurrent.atomic.AtomicLong
-import org.apache.spark.util.ThreadUtils
-
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
+import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag
-import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter}
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.ThreadUtils
/**
* A set of asynchronous RDD actions available through an implicit conversion.
@@ -65,18 +65,26 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
* Returns a future for retrieving the first num elements of the RDD.
*/
def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
- val f = new ComplexFutureAction[Seq[T]]
-
- f.run {
- // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which
- // is a cached thread pool.
- val results = new ArrayBuffer[T](num)
- val totalParts = self.partitions.length
- var partsScanned = 0
- while (results.size < num && partsScanned < totalParts) {
+ val callSite = self.context.getCallSite
+ val localProperties = self.context.getLocalProperties
+ // Cached thread pool to handle aggregation of subtasks.
+ implicit val executionContext = AsyncRDDActions.futureExecutionContext
+ val results = new ArrayBuffer[T]
+ val totalParts = self.partitions.length
+
+ /*
+ Recursively triggers jobs to scan partitions until either the requested
+ number of elements are retrieved, or the partitions to scan are exhausted.
+ This implementation is non-blocking, asynchronously handling the
+ results of each job and triggering the next job using callbacks on futures.
+ */
+ def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] =
+ if (results.size >= num || partsScanned >= totalParts) {
+ Future.successful(results.toSeq)
+ } else {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -92,22 +100,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
}
val left = num - results.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val buf = new Array[Array[T]](p.size)
- f.runJob(self,
+ self.context.setCallSite(callSite)
+ self.context.setLocalProperties(localProperties)
+ val job = jobSubmitter.submitJob(self,
(it: Iterator[T]) => it.take(left).toArray,
p,
(index: Int, data: Array[T]) => buf(index) = data,
Unit)
-
- buf.foreach(results ++= _.take(num - results.size))
- partsScanned += numPartsToTry
+ job.flatMap { _ =>
+ buf.foreach(results ++= _.take(num - results.size))
+ continue(partsScanned + p.size)
+ }
}
- results.toSeq
- }(AsyncRDDActions.futureExecutionContext)
- f
+ new ComplexFutureAction[Seq[T]](continue(0)(_))
}
/**
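The takeAsync rewrite above replaces the blocking loop inside ComplexFutureAction.run with a recursion that chains each submitted job to the next via Future.flatMap. A standalone sketch of the same non-blocking pattern using plain Scala futures; `fetchPartition` stands in for submitting one job and none of this is Spark API.

```scala
// Standalone sketch of the non-blocking "scan until enough results" recursion used
// in takeAsync above, with plain Scala Futures instead of Spark jobs.
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object TakeAsyncSketch {
  implicit val ec: ExecutionContext = ExecutionContext.global

  // Stand-in for running one job over one partition.
  def fetchPartition(i: Int): Future[Seq[Int]] = Future(Seq(i * 2, i * 2 + 1))

  def take(num: Int, totalParts: Int): Future[Seq[Int]] = {
    val results = new ArrayBuffer[Int]
    def continue(partsScanned: Int): Future[Seq[Int]] =
      if (results.size >= num || partsScanned >= totalParts) {
        Future.successful(results.toSeq.take(num))
      } else {
        // fetch one more "partition", then recurse via flatMap (no blocking)
        fetchPartition(partsScanned).flatMap { data =>
          results ++= data
          continue(partsScanned + 1)
        }
      }
    continue(0)
  }

  def main(args: Array[String]): Unit =
    println(Await.result(take(5, 10), 10.seconds))   // e.g. 0, 1, 2, 3, 4
}
```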
diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
index aedced7408cd..50d977a92da5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
@@ -17,14 +17,16 @@
package org.apache.spark.rdd
-import org.apache.hadoop.conf.{ Configurable, Configuration }
+import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.task.JobContextImpl
+
+import org.apache.spark.{Partition, SparkContext}
import org.apache.spark.input.StreamFileInputFormat
-import org.apache.spark.{ Partition, SparkContext }
private[spark] class BinaryFileRDD[T](
- sc: SparkContext,
+ @transient private val sc: SparkContext,
inputFormatClass: Class[_ <: StreamFileInputFormat[T]],
keyClass: Class[String],
valueClass: Class[T],
@@ -40,8 +42,8 @@ private[spark] class BinaryFileRDD[T](
configurable.setConf(conf)
case _ =>
}
- val jobContext = newJobContext(conf, jobId)
- inputFormat.setMinPartitions(jobContext, minPartitions)
+ val jobContext = new JobContextImpl(conf, jobId)
+ inputFormat.setMinPartitions(sc, jobContext, minPartitions)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Partition](rawSplits.size)
for (i <- 0 until rawSplits.size) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index fc1710fbad0a..4e036c2ed49b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
import org.apache.spark._
import org.apache.spark.storage.{BlockId, BlockManager}
-import scala.Some
private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
val index = idx
@@ -36,19 +35,19 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
override def getPartitions: Array[Partition] = {
assertValid()
- (0 until blockIds.length).map(i => {
+ (0 until blockIds.length).map { i =>
new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
- }).toArray
+ }.toArray
}
override def compute(split: Partition, context: TaskContext): Iterator[T] = {
assertValid()
val blockManager = SparkEnv.get.blockManager
val blockId = split.asInstanceOf[BlockRDDPartition].blockId
- blockManager.get(blockId) match {
+ blockManager.get[T](blockId) match {
case Some(block) => block.data.asInstanceOf[Iterator[T]]
case None =>
- throw new Exception("Could not compute split, block " + blockId + " not found")
+ throw new Exception(s"Could not compute split, block $blockId of RDD $id not found")
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 18e8cddbc40d..57108dcedcf0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -50,7 +50,7 @@ class CartesianRDD[T: ClassTag, U: ClassTag](
sc: SparkContext,
var rdd1 : RDD[T],
var rdd2 : RDD[U])
- extends RDD[Pair[T, U]](sc, Nil)
+ extends RDD[(T, U)](sc, Nil)
with Serializable {
val numPartitionsInRdd2 = rdd2.partitions.length
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 935c3babd8ea..a091f06b4ed7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -17,23 +17,24 @@
package org.apache.spark.rdd
-import scala.language.existentials
-
import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
+import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap}
import org.apache.spark.util.Utils
-import org.apache.spark.serializer.Serializer
-/** The references to rdd and splitIndex are transient because redundant information is stored
- * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from
- * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the
- * task closure. */
+/**
+ * The references to rdd and splitIndex are transient because redundant information is stored
+ * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from
+ * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the
+ * task closure.
+ */
private[spark] case class NarrowCoGroupSplitDep(
@transient rdd: RDD[_],
@transient splitIndex: Int,
@@ -57,22 +58,22 @@ private[spark] case class NarrowCoGroupSplitDep(
* narrowDeps should always be equal to the number of parents.
*/
private[spark] class CoGroupPartition(
- idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]])
+ override val index: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]])
extends Partition with Serializable {
- override val index: Int = idx
- override def hashCode(): Int = idx
+ override def hashCode(): Int = index
+ override def equals(other: Any): Boolean = super.equals(other)
}
/**
* :: DeveloperApi ::
- * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
+ * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
* tuple with the list of values for that key.
*
- * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of
- * instantiating this directly.
-
* @param rdds parent RDDs.
* @param part partitioner used to partition the shuffle output
+ *
+ * @note This is an internal API. We recommend users use RDD.cogroup(...) instead of
+ * instantiating this directly.
*/
@DeveloperApi
class CoGroupedRDD[K: ClassTag](
@@ -87,11 +88,11 @@ class CoGroupedRDD[K: ClassTag](
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Array[CoGroup]
- private var serializer: Option[Serializer] = None
+ private var serializer: Serializer = SparkEnv.get.serializer
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
- this.serializer = Option(serializer)
+ this.serializer = serializer
this
}
@@ -154,8 +155,7 @@ class CoGroupedRDD[K: ClassTag](
}
context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
- context.internalMetricsToAccumulators(
- InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
+ context.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes)
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
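The scaladoc change above keeps the guidance that CoGroupedRDD is internal and that RDD.cogroup should be used instead. A small usage sketch of that public entry point, assuming a throwaway local SparkContext purely for illustration:

```scala
// Sketch of the recommended public entry point (RDD.cogroup) rather than
// constructing CoGroupedRDD directly.
import org.apache.spark.{SparkConf, SparkContext}

object CogroupSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("cogroup-sketch").setMaster("local[2]"))
    val sales = sc.parallelize(Seq(("apple", 3), ("pear", 1), ("apple", 2)))
    val prices = sc.parallelize(Seq(("apple", 0.5), ("pear", 0.8)))

    // For each key, the values from both RDDs are grouped into two Iterables.
    val grouped = sales.cogroup(prices)   // RDD[(String, (Iterable[Int], Iterable[Double]))]
    grouped.collect().foreach { case (k, (qty, price)) =>
      println(s"$k -> quantities=${qty.toList}, prices=${price.toList}")
    }
    sc.stop()
  }
}
```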
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 90d9735cb3f6..2cba1febe875 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -70,23 +70,27 @@ private[spark] case class CoalescedRDDPartition(
* parent partitions
* @param prev RDD to be coalesced
* @param maxPartitions number of desired partitions in the coalesced RDD (must be positive)
- * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
+ * @param partitionCoalescer [[PartitionCoalescer]] implementation to use for coalescing
*/
private[spark] class CoalescedRDD[T: ClassTag](
@transient var prev: RDD[T],
maxPartitions: Int,
- balanceSlack: Double = 0.10)
+ partitionCoalescer: Option[PartitionCoalescer] = None)
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
require(maxPartitions > 0 || maxPartitions == prev.partitions.length,
s"Number of partitions ($maxPartitions) must be positive.")
+ if (partitionCoalescer.isDefined) {
+ require(partitionCoalescer.get.isInstanceOf[Serializable],
+ "The partition coalescer passed in must be serializable.")
+ }
override def getPartitions: Array[Partition] = {
- val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack)
+ val pc = partitionCoalescer.getOrElse(new DefaultPartitionCoalescer())
- pc.run().zipWithIndex.map {
+ pc.coalesce(maxPartitions, prev).zipWithIndex.map {
case (pg, i) =>
- val ids = pg.arr.map(_.index).toArray
+ val ids = pg.partitions.map(_.index).toArray
new CoalescedRDDPartition(i, prev, ids, pg.prefLoc)
}
}
@@ -144,15 +148,15 @@ private[spark] class CoalescedRDD[T: ClassTag](
* desired partitions is greater than the number of preferred machines (can happen), it needs to
* start picking duplicate preferred machines. This is determined using coupon collector estimation
* (2n log(n)). The load balancing is done using power-of-two randomized bins-balls with one twist:
- * it tries to also achieve locality. This is done by allowing a slack (balanceSlack) between two
- * bins. If two bins are within the slack in terms of balance, the algorithm will assign partitions
- * according to locality. (contact alig for questions)
- *
+ * it tries to also achieve locality. This is done by allowing a slack (balanceSlack, where
+ * 1.0 is all locality, 0 is all balance) between two bins. If two bins are within the slack
+ * in terms of balance, the algorithm will assign partitions according to locality.
+ * (contact alig for questions)
*/
-private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
-
- def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size
+private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
+ extends PartitionCoalescer {
+ def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.numPartitions < o2.numPartitions
def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean =
if (o1 == None) false else if (o2 == None) true else compare(o1.get, o2.get)
@@ -167,47 +171,43 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
// hash used for the first maxPartitions (to avoid duplicates)
val initialHash = mutable.Set[Partition]()
- // determines the tradeoff between load-balancing the partitions sizes and their locality
- // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality
- val slack = (balanceSlack * prev.partitions.length).toInt
-
var noLocality = true // if true if no preferredLocations exists for parent RDD
// gets the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
- def currPrefLocs(part: Partition): Seq[String] = {
+ def currPrefLocs(part: Partition, prev: RDD[_]): Seq[String] = {
prev.context.getPreferredLocs(prev, part.index).map(tl => tl.host)
}
- // this class just keeps iterating and rotating infinitely over the partitions of the RDD
- // next() returns the next preferred machine that a partition is replicated on
- // the rotator first goes through the first replica copy of each partition, then second, third
- // the iterators return type is a tuple: (replicaString, partition)
- class LocationIterator(prev: RDD[_]) extends Iterator[(String, Partition)] {
-
- var it: Iterator[(String, Partition)] = resetIterator()
-
- override val isEmpty = !it.hasNext
-
- // initializes/resets to start iterating from the beginning
- def resetIterator(): Iterator[(String, Partition)] = {
- val iterators = (0 to 2).map( x =>
- prev.partitions.iterator.flatMap(p => {
- if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None
- } )
+ class PartitionLocations(prev: RDD[_]) {
+
+ // contains all the partitions from the previous RDD that don't have preferred locations
+ val partsWithoutLocs = ArrayBuffer[Partition]()
+ // contains all the partitions from the previous RDD that have preferred locations
+ val partsWithLocs = ArrayBuffer[(String, Partition)]()
+
+ getAllPrefLocs(prev)
+
+ // gets all the preferred locations of the previous RDD and splits them into partitions
+ // with preferred locations and ones without
+ def getAllPrefLocs(prev: RDD[_]): Unit = {
+ val tmpPartsWithLocs = mutable.LinkedHashMap[Partition, Seq[String]]()
+ // first get the locations for each partition, only do this once since it can be expensive
+ prev.partitions.foreach(p => {
+ val locs = currPrefLocs(p, prev)
+ if (locs.nonEmpty) {
+ tmpPartsWithLocs.put(p, locs)
+ } else {
+ partsWithoutLocs += p
+ }
+ }
)
- iterators.reduceLeft((x, y) => x ++ y)
- }
-
- // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD
- override def hasNext: Boolean = { !isEmpty }
-
- // return the next preferredLocation of some partition of the RDD
- override def next(): (String, Partition) = {
- if (it.hasNext) {
- it.next()
- } else {
- it = resetIterator() // ran out of preferred locations, reset and rotate to the beginning
- it.next()
+ // convert it into an array of host to partition
+ for (x <- 0 to 2) {
+ tmpPartsWithLocs.foreach { parts =>
+ val p = parts._1
+ val locs = parts._2
+ if (locs.size > x) partsWithLocs += ((locs(x), p))
+ }
}
}
}
@@ -215,8 +215,9 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
/**
* Sorts and gets the least element of the list associated with key in groupHash
* The returned PartitionGroup is the least loaded of all groups that represent the machine "key"
+ *
* @param key string representing a partitioned group on preferred machine key
- * @return Option of PartitionGroup that has least elements for key
+ * @return Option of [[PartitionGroup]] that has least elements for key
*/
def getLeastGroupHash(key: String): Option[PartitionGroup] = {
groupHash.get(key).map(_.sortWith(compare).head)
@@ -224,78 +225,91 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
def addPartToPGroup(part: Partition, pgroup: PartitionGroup): Boolean = {
if (!initialHash.contains(part)) {
- pgroup.arr += part // already assign this element
+ pgroup.partitions += part // already assign this element
initialHash += part // needed to avoid assigning partitions to multiple buckets
true
} else { false }
}
/**
- * Initializes targetLen partition groups and assigns a preferredLocation
- * This uses coupon collector to estimate how many preferredLocations it must rotate through
- * until it has seen most of the preferred locations (2 * n log(n))
+ * Initializes targetLen partition groups. If there are preferred locations, each group
+ * is assigned a preferredLocation. This uses coupon collector to estimate how many
+ * preferredLocations it must rotate through until it has seen most of the preferred
+ * locations (2 * n log(n))
* @param targetLen
*/
- def setupGroups(targetLen: Int) {
- val rotIt = new LocationIterator(prev)
-
+ def setupGroups(targetLen: Int, partitionLocs: PartitionLocations) {
// deal with empty case, just create targetLen partition groups with no preferred location
- if (!rotIt.hasNext) {
- (1 to targetLen).foreach(x => groupArr += PartitionGroup())
+ if (partitionLocs.partsWithLocs.isEmpty) {
+ (1 to targetLen).foreach(x => groupArr += new PartitionGroup())
return
}
noLocality = false
-
// number of iterations needed to be certain that we've seen most preferred locations
val expectedCoupons2 = 2 * (math.log(targetLen)*targetLen + targetLen + 0.5).toInt
var numCreated = 0
var tries = 0
// rotate through until either targetLen unique/distinct preferred locations have been created
- // OR we've rotated expectedCoupons2, in which case we have likely seen all preferred locations,
- // i.e. likely targetLen >> number of preferred locations (more buckets than there are machines)
- while (numCreated < targetLen && tries < expectedCoupons2) {
+ // OR (we have gone through either all partitions OR we've rotated expectedCoupons2 - in
+ // which case we have likely seen all preferred locations)
+ val numPartsToLookAt = math.min(expectedCoupons2, partitionLocs.partsWithLocs.length)
+ while (numCreated < targetLen && tries < numPartsToLookAt) {
+ val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries)
tries += 1
- val (nxt_replica, nxt_part) = rotIt.next()
if (!groupHash.contains(nxt_replica)) {
- val pgroup = PartitionGroup(nxt_replica)
+ val pgroup = new PartitionGroup(Some(nxt_replica))
groupArr += pgroup
addPartToPGroup(nxt_part, pgroup)
groupHash.put(nxt_replica, ArrayBuffer(pgroup)) // list in case we have multiple
numCreated += 1
}
}
-
- while (numCreated < targetLen) { // if we don't have enough partition groups, create duplicates
- var (nxt_replica, nxt_part) = rotIt.next()
- val pgroup = PartitionGroup(nxt_replica)
+ tries = 0
+ // if we don't have enough partition groups, create duplicates
+ while (numCreated < targetLen) {
+ var (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries)
+ tries += 1
+ val pgroup = new PartitionGroup(Some(nxt_replica))
groupArr += pgroup
groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup
- var tries = 0
- while (!addPartToPGroup(nxt_part, pgroup) && tries < targetLen) { // ensure at least one part
- nxt_part = rotIt.next()._2
- tries += 1
- }
+ addPartToPGroup(nxt_part, pgroup)
numCreated += 1
+ if (tries >= partitionLocs.partsWithLocs.length) tries = 0
}
-
}
/**
* Takes a parent RDD partition and decides which of the partition groups to put it in
* Takes locality into account, but also uses power of 2 choices to load balance
- * It strikes a balance between the two use the balanceSlack variable
+ * It strikes a balance between the two using the balanceSlack variable
* @param p partition (ball to be thrown)
+ * @param balanceSlack determines the trade-off between load-balancing the partitions sizes and
+ * their locality. e.g., balanceSlack=0.10 means that it allows up to 10%
+ * imbalance in favor of locality
* @return partition group (bin to be put in)
*/
- def pickBin(p: Partition): PartitionGroup = {
- val pref = currPrefLocs(p).map(getLeastGroupHash(_)).sortWith(compare) // least loaded pref locs
+ def pickBin(
+ p: Partition,
+ prev: RDD[_],
+ balanceSlack: Double,
+ partitionLocs: PartitionLocations): PartitionGroup = {
+ val slack = (balanceSlack * prev.partitions.length).toInt
+ // least loaded pref locs
+ val pref = currPrefLocs(p, prev).map(getLeastGroupHash(_)).sortWith(compare)
val prefPart = if (pref == Nil) None else pref.head
val r1 = rnd.nextInt(groupArr.size)
val r2 = rnd.nextInt(groupArr.size)
- val minPowerOfTwo = if (groupArr(r1).size < groupArr(r2).size) groupArr(r1) else groupArr(r2)
+ val minPowerOfTwo = {
+ if (groupArr(r1).numPartitions < groupArr(r2).numPartitions) {
+ groupArr(r1)
+ }
+ else {
+ groupArr(r2)
+ }
+ }
if (prefPart.isEmpty) {
// if no preferred locations, just use basic power of two
return minPowerOfTwo
@@ -303,55 +317,82 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
val prefPartActual = prefPart.get
- if (minPowerOfTwo.size + slack <= prefPartActual.size) { // more imbalance than the slack allows
+ // more imbalance than the slack allows
+ if (minPowerOfTwo.numPartitions + slack <= prefPartActual.numPartitions) {
minPowerOfTwo // prefer balance over locality
} else {
prefPartActual // prefer locality over balance
}
}
- def throwBalls() {
+ def throwBalls(
+ maxPartitions: Int,
+ prev: RDD[_],
+ balanceSlack: Double, partitionLocs: PartitionLocations) {
if (noLocality) { // no preferredLocations in parent RDD, no randomization needed
if (maxPartitions > groupArr.size) { // just return prev.partitions
for ((p, i) <- prev.partitions.zipWithIndex) {
- groupArr(i).arr += p
+ groupArr(i).partitions += p
}
} else { // no locality available, then simply split partitions based on positions in array
for (i <- 0 until maxPartitions) {
val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt
val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt
- (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) }
+ (rangeStart until rangeEnd).foreach{ j => groupArr(i).partitions += prev.partitions(j) }
}
}
} else {
+ // It is possible to have unionRDD where one rdd has preferred locations and another rdd
+ // that doesn't. To make sure we end up with the requested number of partitions,
+ // make sure to put a partition in every group.
+
+ // if we don't have a partition assigned to every group first try to fill them
+ // with the partitions with preferred locations
+ val partIter = partitionLocs.partsWithLocs.iterator
+ groupArr.filter(pg => pg.numPartitions == 0).foreach { pg =>
+ while (partIter.hasNext && pg.numPartitions == 0) {
+ var (nxt_replica, nxt_part) = partIter.next()
+ if (!initialHash.contains(nxt_part)) {
+ pg.partitions += nxt_part
+ initialHash += nxt_part
+ }
+ }
+ }
+
+ // if we didn't get one partition per group from partitions with preferred locations
+ // use partitions without preferred locations
+ val partNoLocIter = partitionLocs.partsWithoutLocs.iterator
+ groupArr.filter(pg => pg.numPartitions == 0).foreach { pg =>
+ while (partNoLocIter.hasNext && pg.numPartitions == 0) {
+ var nxt_part = partNoLocIter.next()
+ if (!initialHash.contains(nxt_part)) {
+ pg.partitions += nxt_part
+ initialHash += nxt_part
+ }
+ }
+ }
+
+ // finally pick bin for the rest
for (p <- prev.partitions if (!initialHash.contains(p))) { // throw every partition into group
- pickBin(p).arr += p
+ pickBin(p, prev, balanceSlack, partitionLocs).partitions += p
}
}
}
- def getPartitions: Array[PartitionGroup] = groupArr.filter( pg => pg.size > 0).toArray
+ def getPartitions: Array[PartitionGroup] = groupArr.filter( pg => pg.numPartitions > 0).toArray
/**
* Runs the packing algorithm and returns an array of PartitionGroups that if possible are
* load balanced and grouped by locality
- * @return array of partition groups
+ *
+ * @return array of partition groups
*/
- def run(): Array[PartitionGroup] = {
- setupGroups(math.min(prev.partitions.length, maxPartitions)) // setup the groups (bins)
- throwBalls() // assign partitions (balls) to each group (bins)
+ def coalesce(maxPartitions: Int, prev: RDD[_]): Array[PartitionGroup] = {
+ val partitionLocs = new PartitionLocations(prev)
+ // setup the groups (bins)
+ setupGroups(math.min(prev.partitions.length, maxPartitions), partitionLocs)
+ // assign partitions (balls) to each group (bins)
+ throwBalls(maxPartitions, prev, balanceSlack, partitionLocs)
getPartitions
}
}
-
-private case class PartitionGroup(prefLoc: Option[String] = None) {
- var arr = mutable.ArrayBuffer[Partition]()
- def size: Int = arr.size
-}
-
-private object PartitionGroup {
- def apply(prefLoc: String): PartitionGroup = {
- require(prefLoc != "", "Preferred location must not be empty")
- PartitionGroup(Some(prefLoc))
- }
-}
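DefaultPartitionCoalescer keeps the old power-of-two-choices balancing, now with the slack passed through pickBin. A standalone sketch of that balancing rule, detached from Spark's Partition/PartitionGroup types; the bin contents and the "preferred bin" rule here are made up for illustration.

```scala
// Standalone sketch of the power-of-two-choices rule used by DefaultPartitionCoalescer:
// pick two random bins and use the less loaded one, unless a "preferred" bin is
// within `slack` of it. Hypothetical, not Spark API.
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object PowerOfTwoSketch {
  def main(args: Array[String]): Unit = {
    val rnd = new Random(7)
    val numBins = 8
    val bins = Array.fill(numBins)(ArrayBuffer.empty[Int])
    val slack = 2

    for (ball <- 0 until 100) {
      val r1 = rnd.nextInt(numBins)
      val r2 = rnd.nextInt(numBins)
      val minOfTwo = if (bins(r1).size < bins(r2).size) r1 else r2
      // pretend every third ball has a preferred bin (stand-in for a preferred location)
      val preferred = if (ball % 3 == 0) Some(ball % numBins) else None
      val target = preferred match {
        case Some(p) if bins(p).size < bins(minOfTwo).size + slack => p  // locality wins
        case _ => minOfTwo                                               // balance wins
      }
      bins(target) += ball
    }
    println(bins.map(_.size).mkString("bin sizes: ", ", ", ""))
  }
}
```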
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index 7fbaadcea3a3..14331dfd0c98 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -17,8 +17,9 @@
package org.apache.spark.rdd
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.{TaskContext, Logging}
+import org.apache.spark.annotation.Since
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.MeanEvaluator
import org.apache.spark.partial.PartialResult
@@ -47,12 +48,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
stats().mean
}
- /** Compute the variance of this RDD's elements. */
+ /** Compute the population variance of this RDD's elements. */
def variance(): Double = self.withScope {
stats().variance
}
- /** Compute the standard deviation of this RDD's elements. */
+ /** Compute the population standard deviation of this RDD's elements. */
def stdev(): Double = self.withScope {
stats().stdev
}
@@ -73,6 +74,22 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
stats().sampleVariance
}
+ /**
+ * Compute the population standard deviation of this RDD's elements.
+ */
+ @Since("2.1.0")
+ def popStdev(): Double = self.withScope {
+ stats().popStdev
+ }
+
+ /**
+ * Compute the population variance of this RDD's elements.
+ */
+ @Since("2.1.0")
+ def popVariance(): Double = self.withScope {
+ stats().popVariance
+ }
+
/**
* Approximate operation to return the mean within a timeout.
*/
@@ -103,7 +120,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
* If the RDD contains infinity, NaN throws an exception
* If the elements in RDD do not vary (max == min) always returns a single bucket.
*/
- def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = self.withScope {
+ def histogram(bucketCount: Int): (Array[Double], Array[Long]) = self.withScope {
// Scala's built-in range has issues. See #SI-8782
def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = {
val span = max - min
@@ -112,7 +129,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
// Compute the minimum and the maximum
val (max: Double, min: Double) = self.mapPartitions { items =>
Iterator(items.foldRight(Double.NegativeInfinity,
- Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) =>
+ Double.PositiveInfinity)((e: Double, x: (Double, Double)) =>
(x._1.max(e), x._2.min(e))))
}.reduce { (maxmin1, maxmin2) =>
(maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
@@ -135,14 +152,14 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
/**
* Compute a histogram using the provided buckets. The buckets are all open
- * to the right except for the last which is closed
+ * to the right except for the last which is closed.
* e.g. for the array
* [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50]
- * e.g 1<=x<10 , 10<=x<20, 20<=x<=50
+ * e.g. {@code 1<=x<10, 10<=x<20, 20<=x<=50}
* And on the input of 1 and 50 we would have a histogram of 1, 0, 1
*
- * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
- * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets
* to true.
* buckets must be sorted and not contain any duplicates.
* buckets array must be at least two elements
@@ -166,8 +183,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
val counters = new Array[Long](buckets.length - 1)
while (iter.hasNext) {
bucketFunction(iter.next()) match {
- case Some(x: Int) => {counters(x) += 1}
- case _ => {}
+ case Some(x: Int) => counters(x) += 1
+ case _ => // No-Op
}
}
Iterator(counters)
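For reference, a usage sketch of the double-RDD statistics touched in this file, again on an assumed local SparkContext. `popStdev()`/`popVariance()` are the explicit population aliases introduced by this hunk, while `stdev()`/`variance()` remain population statistics and `sampleStdev()`/`sampleVariance()` divide by N - 1.

```scala
// Usage sketch for DoubleRDDFunctions on a local SparkContext.
import org.apache.spark.{SparkConf, SparkContext}

object DoubleStatsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("double-stats").setMaster("local[2]"))
    val xs = sc.parallelize(Seq(1.0, 2.0, 2.0, 3.0, 7.0, 9.0))

    println(s"mean = ${xs.mean()}")
    println(s"population stdev = ${xs.stdev()}, sample stdev = ${xs.sampleStdev()}")

    // histogram(bucketCount) returns (bucket boundaries, counts); buckets are
    // right-open except the last one, which is closed.
    val (buckets, counts) = xs.histogram(3)
    println(buckets.mkString("buckets: ", ", ", ""))
    println(counts.mkString("counts:  ", ", ", ""))
    sc.stop()
  }
}
```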
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index d841f05ec52c..23b344230e49 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -17,48 +17,41 @@
package org.apache.spark.rdd
+import java.io.IOException
import java.text.SimpleDateFormat
-import java.util.Date
-import java.io.EOFException
+import java.util.{Date, Locale}
import scala.collection.immutable.Map
import scala.reflect.ClassTag
-import scala.collection.mutable.ListBuffer
import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.mapred.FileSplit
-import org.apache.hadoop.mapred.InputFormat
-import org.apache.hadoop.mapred.InputSplit
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.RecordReader
-import org.apache.hadoop.mapred.Reporter
-import org.apache.hadoop.mapred.JobID
-import org.apache.hadoop.mapred.TaskAttemptID
-import org.apache.hadoop.mapred.TaskID
+import org.apache.hadoop.mapred._
import org.apache.hadoop.mapred.lib.CombineFileSplit
+import org.apache.hadoop.mapreduce.TaskType
import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES
import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
-import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils}
-import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
+import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation}
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.{NextIterator, SerializableConfiguration, ShutdownHookManager}
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
-private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit)
+private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: InputSplit)
extends Partition {
val inputSplit = new SerializableWritable[InputSplit](s)
- override def hashCode(): Int = 41 * (41 + rddId) + idx
+ override def hashCode(): Int = 31 * (31 + rddId) + index
- override val index: Int = idx
+ override def equals(other: Any): Boolean = super.equals(other)
/**
* Get any environment variables that should be added to the users environment when running pipes
@@ -68,7 +61,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit)
val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]
// map_input_file is deprecated in favor of mapreduce_map_input_file but set both
- // since its not removed yet
+ // since it's not removed yet
Map("map_input_file" -> is.getPath().toString(),
"mapreduce_map_input_file" -> is.getPath().toString())
} else {
@@ -83,19 +76,19 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit)
* An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
* sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`).
*
- * Note: Instantiating this class directly is not recommended, please use
- * [[org.apache.spark.SparkContext.hadoopRDD()]]
- *
* @param sc The SparkContext to associate the RDD with.
* @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
- * variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job.
- * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration.
+ * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job.
+ * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration.
* @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD
* creates.
* @param inputFormatClass Storage format of the data to be read.
* @param keyClass Class of the key associated with the inputFormatClass.
* @param valueClass Class of the value associated with the inputFormatClass.
* @param minPartitions Minimum number of HadoopRDD partitions (Hadoop Splits) to generate.
+ *
+ * @note Instantiating this class directly is not recommended, please use
+ * `org.apache.spark.SparkContext.hadoopRDD()`
*/
@DeveloperApi
class HadoopRDD[K, V](
@@ -123,22 +116,24 @@ class HadoopRDD[K, V](
sc,
sc.broadcast(new SerializableConfiguration(conf))
.asInstanceOf[Broadcast[SerializableConfiguration]],
- None /* initLocalJobConfFuncOpt */,
+ initLocalJobConfFuncOpt = None,
inputFormatClass,
keyClass,
valueClass,
minPartitions)
}
- protected val jobConfCacheKey = "rdd_%d_job_conf".format(id)
+ protected val jobConfCacheKey: String = "rdd_%d_job_conf".format(id)
- protected val inputFormatCacheKey = "rdd_%d_input_format".format(id)
+ protected val inputFormatCacheKey: String = "rdd_%d_input_format".format(id)
// used to build JobTracker ID
private val createTime = new Date()
private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false)
+ private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES)
+
// Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads.
protected def getJobConf(): JobConf = {
val conf: Configuration = broadcastedConf.value.value
@@ -154,7 +149,7 @@ class HadoopRDD[K, V](
logDebug("Cloning Hadoop Configuration")
val newJobConf = new JobConf(conf)
if (!conf.isInstanceOf[JobConf]) {
- initLocalJobConfFuncOpt.map(f => f(newJobConf))
+ initLocalJobConfFuncOpt.foreach(f => f(newJobConf))
}
newJobConf
}
@@ -162,20 +157,25 @@ class HadoopRDD[K, V](
if (conf.isInstanceOf[JobConf]) {
logDebug("Re-using user-broadcasted JobConf")
conf.asInstanceOf[JobConf]
- } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
- logDebug("Re-using cached JobConf")
- HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
} else {
- // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the
- // local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
- // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
- // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456).
- HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
- logDebug("Creating new JobConf and caching it for later re-use")
- val newJobConf = new JobConf(conf)
- initLocalJobConfFuncOpt.map(f => f(newJobConf))
- HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
- newJobConf
+ Option(HadoopRDD.getCachedMetadata(jobConfCacheKey))
+ .map { conf =>
+ logDebug("Re-using cached JobConf")
+ conf.asInstanceOf[JobConf]
+ }
+ .getOrElse {
+ // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in
+ // the local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
+ // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary
+ // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097,
+ // HADOOP-10456).
+ HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
+ logDebug("Creating new JobConf and caching it for later re-use")
+ val newJobConf = new JobConf(conf)
+ initLocalJobConfFuncOpt.foreach(f => f(newJobConf))
+ HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
+ newJobConf
+ }
}
}
}
@@ -184,8 +184,9 @@ class HadoopRDD[K, V](
protected def getInputFormat(conf: JobConf): InputFormat[K, V] = {
val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf)
.asInstanceOf[InputFormat[K, V]]
- if (newInputFormat.isInstanceOf[Configurable]) {
- newInputFormat.asInstanceOf[Configurable].setConf(conf)
+ newInputFormat match {
+ case c: Configurable => c.setConf(conf)
+ case _ =>
}
newInputFormat
}
@@ -195,9 +196,6 @@ class HadoopRDD[K, V](
// add the credentials here as this can be called before SparkContext initialized
SparkHadoopUtil.get.addCredentials(jobConf)
val inputFormat = getInputFormat(jobConf)
- if (inputFormat.isInstanceOf[Configurable]) {
- inputFormat.asInstanceOf[Configurable].setConf(jobConf)
- }
val inputSplits = inputFormat.getSplits(jobConf, minPartitions)
val array = new Array[Partition](inputSplits.size)
for (i <- 0 until inputSplits.size) {
@@ -209,53 +207,85 @@ class HadoopRDD[K, V](
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new NextIterator[(K, V)] {
- val split = theSplit.asInstanceOf[HadoopPartition]
+ private val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
- val jobConf = getJobConf()
+ private val jobConf = getJobConf()
- val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+ private val inputMetrics = context.taskMetrics().inputMetrics
+ private val existingBytesRead = inputMetrics.bytesRead
+
+ // Sets InputFileBlockHolder for the file block's information
+ split.inputSplit.value match {
+ case fs: FileSplit =>
+ InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength)
+ case _ =>
+ InputFileBlockHolder.unset()
+ }
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
- split.inputSplit.value match {
- case _: FileSplit | _: CombineFileSplit =>
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
- case _ => None
+ private val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
+ case _: FileSplit | _: CombineFileSplit =>
+ Some(SparkHadoopUtil.get.getFSBytesReadOnThreadCallback())
+ case _ => None
+ }
+
+ // We get our input bytes from thread-local Hadoop FileSystem statistics.
+ // If we do a coalesce, however, we are likely to compute multiple partitions in the same
+ // task and in the same thread, in which case we need to avoid overriding values written by
+ // previous partitions (SPARK-13071).
+ private def updateBytesRead(): Unit = {
+ getBytesReadCallback.foreach { getBytesRead =>
+ inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
}
}
- inputMetrics.setBytesReadCallback(bytesReadCallback)
- var reader: RecordReader[K, V] = null
- val inputFormat = getInputFormat(jobConf)
- HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
+ private var reader: RecordReader[K, V] = null
+ private val inputFormat = getInputFormat(jobConf)
+ HadoopRDD.addLocalConfiguration(
+ new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime),
context.stageId, theSplit.index, context.attemptNumber, jobConf)
- reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
+ reader =
+ try {
+ inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
+ } catch {
+ case e: IOException if ignoreCorruptFiles =>
+ logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e)
+ finished = true
+ null
+ }
// Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener{ context => closeIfNeeded() }
- val key: K = reader.createKey()
- val value: V = reader.createValue()
+ context.addTaskCompletionListener { context =>
+ // Update the bytes read before closing to make sure lingering bytesRead statistics in
+ // this thread get correctly added.
+ updateBytesRead()
+ closeIfNeeded()
+ }
+
+ private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
+ private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
override def getNext(): (K, V) = {
try {
finished = !reader.next(key, value)
} catch {
- case eof: EOFException =>
+ case e: IOException if ignoreCorruptFiles =>
+ logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e)
finished = true
}
if (!finished) {
inputMetrics.incRecordsRead(1)
}
+ if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+ updateBytesRead()
+ }
(key, value)
}
- override def close() {
+ override def close(): Unit = {
if (reader != null) {
- // Close the reader and release it. Note: it's very important that we don't close the
- // reader more than once, since that exposes us to MAPREDUCE-5918 when running against
- // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
- // corruption issues when reading compressed input.
+ InputFileBlockHolder.unset()
try {
reader.close()
} catch {
@@ -266,8 +296,8 @@ class HadoopRDD[K, V](
} finally {
reader = null
}
- if (bytesReadCallback.isDefined) {
- inputMetrics.updateBytesRead()
+ if (getBytesReadCallback.isDefined) {
+ updateBytesRead()
} else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
@@ -295,18 +325,10 @@ class HadoopRDD[K, V](
override def getPreferredLocations(split: Partition): Seq[String] = {
val hsplit = split.asInstanceOf[HadoopPartition].inputSplit.value
- val locs: Option[Seq[String]] = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
- case Some(c) =>
- try {
- val lsplit = c.inputSplitWithLocationInfo.cast(hsplit)
- val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]]
- Some(HadoopRDD.convertSplitLocationInfo(infos))
- } catch {
- case e: Exception =>
- logDebug("Failed to use InputSplitWithLocations.", e)
- None
- }
- case None => None
+ val locs = hsplit match {
+ case lsplit: InputSplitWithLocationInfo =>
+ HadoopRDD.convertSplitLocationInfo(lsplit.getLocationInfo)
+ case _ => None
}
locs.getOrElse(hsplit.getLocations.filter(_ != "localhost"))
}
@@ -317,7 +339,7 @@ class HadoopRDD[K, V](
override def persist(storageLevel: StorageLevel): this.type = {
if (storageLevel.deserialized) {
- logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" +
+ logWarning("Caching HadoopRDDs as deserialized objects usually leads to undesired" +
" behavior because Hadoop's RecordReader reuses the same Writable object for all records." +
" Use a map transformation to make copies of the records.")
}
@@ -343,8 +365,6 @@ private[spark] object HadoopRDD extends Logging {
*/
def getCachedMetadata(key: String): Any = SparkEnv.get.hadoopJobMetadata.get(key)
- def containsCachedMetadata(key: String): Boolean = SparkEnv.get.hadoopJobMetadata.containsKey(key)
-
private def putCachedMetadata(key: String, value: Any): Unit =
SparkEnv.get.hadoopJobMetadata.put(key, value)
@@ -352,13 +372,13 @@ private[spark] object HadoopRDD extends Logging {
def addLocalConfiguration(jobTrackerId: String, jobId: Int, splitId: Int, attemptId: Int,
conf: JobConf) {
val jobID = new JobID(jobTrackerId, jobId)
- val taId = new TaskAttemptID(new TaskID(jobID, true, splitId), attemptId)
+ val taId = new TaskAttemptID(new TaskID(jobID, TaskType.MAP, splitId), attemptId)
- conf.set("mapred.tip.id", taId.getTaskID.toString)
- conf.set("mapred.task.id", taId.toString)
- conf.setBoolean("mapred.task.is.map", true)
- conf.setInt("mapred.task.partition", splitId)
- conf.set("mapred.job.id", jobID.toString)
+ conf.set("mapreduce.task.id", taId.getTaskID.toString)
+ conf.set("mapreduce.task.attempt.id", taId.toString)
+ conf.setBoolean("mapreduce.task.ismap", true)
+ conf.setInt("mapreduce.task.partition", splitId)
+ conf.set("mapreduce.job.id", jobID.toString)
}
/**
@@ -382,41 +402,20 @@ private[spark] object HadoopRDD extends Logging {
}
}
- private[spark] class SplitInfoReflections {
- val inputSplitWithLocationInfo =
- Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
- val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
- val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit")
- val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
- val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo")
- val isInMemory = splitLocationInfo.getMethod("isInMemory")
- val getLocation = splitLocationInfo.getMethod("getLocation")
- }
-
- private[spark] val SPLIT_INFO_REFLECTIONS: Option[SplitInfoReflections] = try {
- Some(new SplitInfoReflections)
- } catch {
- case e: Exception =>
- logDebug("SplitLocationInfo and other new Hadoop classes are " +
- "unavailable. Using the older Hadoop location info code.", e)
- None
- }
-
- private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = {
- val out = ListBuffer[String]()
- infos.foreach { loc => {
- val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get.
- getLocation.invoke(loc).asInstanceOf[String]
+ private[spark] def convertSplitLocationInfo(
+ infos: Array[SplitLocationInfo]): Option[Seq[String]] = {
+ Option(infos).map(_.flatMap { loc =>
+ val locationStr = loc.getLocation
if (locationStr != "localhost") {
- if (HadoopRDD.SPLIT_INFO_REFLECTIONS.get.isInMemory.
- invoke(loc).asInstanceOf[Boolean]) {
- logDebug("Partition " + locationStr + " is cached by Hadoop.")
- out += new HDFSCacheTaskLocation(locationStr).toString
+ if (loc.isInMemory) {
+ logDebug(s"Partition $locationStr is cached by Hadoop.")
+ Some(HDFSCacheTaskLocation(locationStr).toString)
} else {
- out += new HostTaskLocation(locationStr).toString
+ Some(HostTaskLocation(locationStr).toString)
}
+ } else {
+ None
}
- }}
- out.seq
+ })
}
}
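HadoopRDD.compute now catches IOException both when opening the RecordReader and on each next() call, finishing the iterator early when IGNORE_CORRUPT_FILES is enabled. A standalone sketch of that skip-corrupt-input pattern; `openRecords` and the paths are stand-ins, not Hadoop or Spark API.

```scala
// Standalone sketch of the "ignore corrupt files" pattern used in HadoopRDD.compute:
// failures at open time and at read time are both swallowed (with a warning) when the
// flag is on, and reading simply finishes early.
import java.io.IOException

object IgnoreCorruptSketch {
  def readAll(paths: Seq[String], ignoreCorruptFiles: Boolean,
              openRecords: String => Iterator[String]): Seq[String] = {
    val out = Seq.newBuilder[String]
    paths.foreach { path =>
      val it =
        try openRecords(path)
        catch {
          case e: IOException if ignoreCorruptFiles =>
            println(s"WARN: skipping corrupted file $path: ${e.getMessage}")
            Iterator.empty
        }
      try out ++= it
      catch {
        case e: IOException if ignoreCorruptFiles =>
          println(s"WARN: skipping rest of corrupted file $path: ${e.getMessage}")
      }
    }
    out.result()
  }

  def main(args: Array[String]): Unit = {
    val records = readAll(
      Seq("good", "bad"),
      ignoreCorruptFiles = true,
      {
        case "good" => Iterator("a", "b")
        case _      => throw new IOException("corrupt header")
      })
    println(records.mkString(", "))   // a, b
  }
}
```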
diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
new file mode 100644
index 000000000000..ff2f58d81142
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This holds file names of the current Spark task. This is used in HadoopRDD,
+ * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL.
+ */
+private[spark] object InputFileBlockHolder {
+ /**
+ * A wrapper around some input file information.
+ *
+ * @param filePath path of the file read, or empty string if not available.
+ * @param startOffset starting offset, in bytes, or -1 if not available.
+ * @param length size of the block, in bytes, or -1 if not available.
+ */
+ private class FileBlock(val filePath: UTF8String, val startOffset: Long, val length: Long) {
+ def this() {
+ this(UTF8String.fromString(""), -1, -1)
+ }
+ }
+
+ /**
+ * The thread variable for the name of the current file being read. This is used by
+ * the InputFileName function in Spark SQL.
+ */
+ private[this] val inputBlock: InheritableThreadLocal[FileBlock] =
+ new InheritableThreadLocal[FileBlock] {
+ override protected def initialValue(): FileBlock = new FileBlock
+ }
+
+ /**
+ * Returns the holding file name or empty string if it is unknown.
+ */
+ def getInputFilePath: UTF8String = inputBlock.get().filePath
+
+ /**
+ * Returns the starting offset of the block currently being read, or -1 if it is unknown.
+ */
+ def getStartOffset: Long = inputBlock.get().startOffset
+
+ /**
+ * Returns the length of the block being read, or -1 if it is unknown.
+ */
+ def getLength: Long = inputBlock.get().length
+
+ /**
+ * Sets the thread-local input block.
+ */
+ def set(filePath: String, startOffset: Long, length: Long): Unit = {
+ require(filePath != null, "filePath cannot be null")
+ require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative")
+ require(length >= 0, s"length ($length) cannot be negative")
+ inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length))
+ }
+
+ /**
+ * Clears the input file block, resetting it to its default value.
+ */
+ def unset(): Unit = inputBlock.remove()
+}
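As a reading aid for the new holder: a minimal sketch of how Spark-internal code (such as the
NewHadoopRDD changes later in this patch) is expected to drive it around a file split. The path
and sizes are illustrative values, and InputFileBlockHolder is private[spark], so this only shows
the contract, not a user-facing API.

  import org.apache.spark.rdd.InputFileBlockHolder

  // Record which block this thread is about to read (illustrative values).
  InputFileBlockHolder.set("hdfs://nn/data/part-00000", 0L, 128L * 1024 * 1024)

  // Downstream code, e.g. the InputFileName expression in Spark SQL, can now query it.
  val path = InputFileBlockHolder.getInputFilePath   // UTF8String of the path
  val offset = InputFileBlockHolder.getStartOffset   // 0
  val length = InputFileBlockHolder.getLength        // 134217728

  // Reset the thread-local to its defaults once the split is fully consumed.
  InputFileBlockHolder.unset()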
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 0c28f045e46e..aab46b8954bf 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -17,15 +17,16 @@
package org.apache.spark.rdd
-import java.sql.{PreparedStatement, Connection, ResultSet}
+import java.sql.{Connection, ResultSet}
import scala.reflect.ClassTag
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.function.{Function => JFunction}
-import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.internal.Logging
import org.apache.spark.util.NextIterator
-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
override def index: Int = idx
@@ -33,14 +34,17 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e
// TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private
/**
- * An RDD that executes an SQL query on a JDBC connection and reads results.
+ * An RDD that executes a SQL query on a JDBC connection and reads results.
* For usage example, see test case JdbcRDDSuite.
*
* @param getConnection a function that returns an open Connection.
* The RDD takes care of closing the connection.
* @param sql the text of the query.
* The query must contain two ? placeholders for parameters used to partition the results.
- * E.g. "select title, author from books where ? <= id and id <= ?"
+ * For example,
+ * {{{
+ * select title, author from books where ? <= id and id <= ?
+ * }}}
* @param lowerBound the minimum value of the first placeholder
* @param upperBound the maximum value of the second placeholder
* The lower and upper bounds are inclusive.
@@ -64,11 +68,11 @@ class JdbcRDD[T: ClassTag](
override def getPartitions: Array[Partition] = {
// bounds are inclusive, hence the + 1 here and - 1 on end
val length = BigInt(1) + upperBound - lowerBound
- (0 until numPartitions).map(i => {
+ (0 until numPartitions).map { i =>
val start = lowerBound + ((i * length) / numPartitions)
val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
new JdbcPartition(i, start.toLong, end.toLong)
- }).toArray
+ }.toArray
}
override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T]
@@ -78,14 +82,20 @@ class JdbcRDD[T: ClassTag](
val conn = getConnection()
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
- // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results,
- // rather than pulling entire resultset into memory.
- // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
- if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
+ val url = conn.getMetaData.getURL
+ if (url.startsWith("jdbc:mysql:")) {
+ // setFetchSize(Integer.MIN_VALUE) is a MySQL-driver-specific way to force
+ // streaming results, rather than pulling the entire result set into memory.
+ // See the URL below:
+ // dev.mysql.com/doc/connector-j/5.1/en/connector-j-reference-implementation-notes.html
+
stmt.setFetchSize(Integer.MIN_VALUE)
- logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
+ } else {
+ stmt.setFetchSize(100)
}
+ logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")
+
stmt.setLong(1, part.lower)
stmt.setLong(2, part.upper)
val rs = stmt.executeQuery()
@@ -137,14 +147,17 @@ object JdbcRDD {
}
/**
- * Create an RDD that executes an SQL query on a JDBC connection and reads results.
+ * Create an RDD that executes a SQL query on a JDBC connection and reads results.
* For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
*
* @param connectionFactory a factory that returns an open Connection.
* The RDD takes care of closing the connection.
* @param sql the text of the query.
* The query must contain two ? placeholders for parameters used to partition the results.
- * E.g. "select title, author from books where ? <= id and id <= ?"
+ * For example,
+ * {{{
+ * select title, author from books where ? <= id and id <= ?
+ * }}}
* @param lowerBound the minimum value of the first placeholder
* @param upperBound the maximum value of the second placeholder
* The lower and upper bounds are inclusive.
@@ -177,14 +190,17 @@ object JdbcRDD {
}
/**
- * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
+ * Create an RDD that executes a SQL query on a JDBC connection and reads results. Each row is
* converted into a `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
*
* @param connectionFactory a factory that returns an open Connection.
* The RDD takes care of closing the connection.
* @param sql the text of the query.
* The query must contain two ? placeholders for parameters used to partition the results.
- * E.g. "select title, author from books where ? <= id and id <= ?"
+ * For example,
+ * {{{
+ * select title, author from books where ? <= id and id <= ?
+ * }}}
* @param lowerBound the minimum value of the first placeholder
* @param upperBound the maximum value of the second placeholder
* The lower and upper bounds are inclusive.
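To make the partitioning contract of the two ? placeholders concrete, here is the bound
arithmetic from getPartitions above, reproduced as a standalone sketch (plain Scala, editorial
example values; both ends of each range are inclusive):

  val lowerBound = 1L
  val upperBound = 100L
  val numPartitions = 3
  val length = BigInt(1) + upperBound - lowerBound
  val bounds = (0 until numPartitions).map { i =>
    val start = lowerBound + ((i * length) / numPartitions)
    val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
    (start.toLong, end.toLong)
  }
  // bounds == Vector((1,33), (34,66), (67,100)); each pair is bound to the two
  // ? placeholders of one partition's query.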
diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala
index bfe19195fcd3..503aa0dffc9f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
-import org.apache.spark.{Partition, SparkContext, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext}
import org.apache.spark.storage.RDDBlockId
/**
@@ -41,7 +41,7 @@ private[spark] class LocalCheckpointRDD[T: ClassTag](
extends CheckpointRDD[T](sc) {
def this(rdd: RDD[T]) {
- this(rdd.context, rdd.id, rdd.partitions.size)
+ this(rdd.context, rdd.id, rdd.partitions.length)
}
protected override def getPartitions: Array[Partition] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala
index c115e0ff74d3..56f53714cbe3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala
@@ -19,7 +19,8 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
-import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -72,12 +73,6 @@ private[spark] object LocalRDDCheckpointData {
* This method is idempotent.
*/
def transformStorageLevel(level: StorageLevel): StorageLevel = {
- // If this RDD is to be cached off-heap, fail fast since we cannot provide any
- // correctness guarantees about subsequent computations after the first one
- if (level.useOffHeap) {
- throw new SparkException("Local checkpointing is not compatible with off-heap caching.")
- }
-
StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index 4312d3a41775..15128f0913af 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -23,11 +23,22 @@ import org.apache.spark.{Partition, TaskContext}
/**
* An RDD that applies the provided function to every partition of the parent RDD.
+ *
+ * @param prev the parent RDD.
+ * @param f The function used to map a tuple of (TaskContext, partition index, input iterator) to
+ * an output iterator.
+ * @param preservesPartitioning Whether the input function preserves the partitioner, which should
+ * be `false` unless `prev` is a pair RDD and the input function
+ * doesn't modify the keys.
+ * @param isOrderSensitive whether or not the function is order-sensitive. If it is order
+ * sensitive, it may return a totally different result when the input order
+ * is changed. Stateful functions are typically order-sensitive.
*/
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
- prev: RDD[T],
+ var prev: RDD[T],
f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
- preservesPartitioning: Boolean = false)
+ preservesPartitioning: Boolean = false,
+ isOrderSensitive: Boolean = false)
extends RDD[U](prev) {
override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
@@ -36,4 +47,17 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
override def compute(split: Partition, context: TaskContext): Iterator[U] =
f(context, split.index, firstParent[T].iterator(split, context))
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ prev = null
+ }
+
+ override protected def getOutputDeterministicLevel = {
+ if (isOrderSensitive && prev.outputDeterministicLevel == DeterministicLevel.UNORDERED) {
+ DeterministicLevel.INDETERMINATE
+ } else {
+ super.getOutputDeterministicLevel
+ }
+ }
}
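The new isOrderSensitive flag interacts with the parent RDD's determinism as encoded in
getOutputDeterministicLevel above. A condensed, self-contained restatement of that rule follows;
it is a sketch only, assumed to run inside the org.apache.spark.rdd package (DeterministicLevel
may not be public API), and defaultLevel stands for whatever super.getOutputDeterministicLevel
would return:

  def outputLevel(
      isOrderSensitive: Boolean,
      parentLevel: DeterministicLevel.Value,
      defaultLevel: DeterministicLevel.Value): DeterministicLevel.Value = {
    // An order-sensitive function over an UNORDERED parent can produce entirely
    // different output across retries, so the result is INDETERMINATE.
    if (isOrderSensitive && parentLevel == DeterministicLevel.UNORDERED) {
      DeterministicLevel.INDETERMINATE
    } else {
      defaultLevel
    }
  }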
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 9c4b70844bdb..482875e6c1ac 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -17,25 +17,27 @@
package org.apache.spark.rdd
+import java.io.IOException
import java.text.SimpleDateFormat
-import java.util.Date
+import java.util.{Date, Locale}
import scala.reflect.ClassTag
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit}
+import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl}
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.input.WholeTextFileInputFormat
import org.apache.spark._
-import org.apache.spark.executor.DataReadMethod
-import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
-import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
-import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES
+import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager}
private[spark] class NewHadoopPartition(
rddId: Int,
@@ -44,7 +46,10 @@ private[spark] class NewHadoopPartition(
extends Partition {
val serializableHadoopSplit = new SerializableWritable(rawSplit)
- override def hashCode(): Int = 41 * (41 + rddId) + index
+
+ override def hashCode(): Int = 31 * (31 + rddId) + index
+
+ override def equals(other: Any): Boolean = super.equals(other)
}
/**
@@ -52,14 +57,13 @@ private[spark] class NewHadoopPartition(
* An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
* sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`).
*
- * Note: Instantiating this class directly is not recommended, please use
- * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]]
- *
* @param sc The SparkContext to associate the RDD with.
* @param inputFormatClass Storage format of the data to be read.
* @param keyClass Class of the key associated with the inputFormatClass.
* @param valueClass Class of the value associated with the inputFormatClass.
- * @param conf The Hadoop configuration.
+ *
+ * @note Instantiating this class directly is not recommended; please use
+ * `org.apache.spark.SparkContext.newAPIHadoopRDD()`
*/
@DeveloperApi
class NewHadoopRDD[K, V](
@@ -68,16 +72,14 @@ class NewHadoopRDD[K, V](
keyClass: Class[K],
valueClass: Class[V],
@transient private val _conf: Configuration)
- extends RDD[(K, V)](sc, Nil)
- with SparkHadoopMapReduceUtil
- with Logging {
+ extends RDD[(K, V)](sc, Nil) with Logging {
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
private val confBroadcast = sc.broadcast(new SerializableConfiguration(_conf))
// private val serializableConf = new SerializableWritable(_conf)
private val jobTrackerId: String = {
- val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+ val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
formatter.format(new Date())
}
@@ -85,6 +87,8 @@ class NewHadoopRDD[K, V](
private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false)
+ private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES)
+
def getConf: Configuration = {
val conf: Configuration = confBroadcast.value.value
if (shouldCloneJobConf) {
@@ -97,7 +101,13 @@ class NewHadoopRDD[K, V](
// issues, this cloning is disabled by default.
NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
logDebug("Cloning Hadoop Configuration")
- new Configuration(conf)
+ // The Configuration passed in is actually a JobConf and possibly contains credentials.
+ // To keep those credentials properly we have to create a new JobConf not a Configuration.
+ if (conf.isInstanceOf[JobConf]) {
+ new JobConf(conf)
+ } else {
+ new Configuration(conf)
+ }
}
} else {
conf
@@ -111,7 +121,7 @@ class NewHadoopRDD[K, V](
configurable.setConf(_conf)
case _ =>
}
- val jobContext = newJobContext(_conf, jobId)
+ val jobContext = new JobContextImpl(_conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Partition](rawSplits.size)
for (i <- 0 until rawSplits.size) {
@@ -122,45 +132,86 @@ class NewHadoopRDD[K, V](
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new Iterator[(K, V)] {
- val split = theSplit.asInstanceOf[NewHadoopPartition]
+ private val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
- val conf = getConf
+ private val conf = getConf
- val inputMetrics = context.taskMetrics
- .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+ private val inputMetrics = context.taskMetrics().inputMetrics
+ private val existingBytesRead = inputMetrics.bytesRead
+
+ // Sets InputFileBlockHolder for the file block's information
+ split.serializableHadoopSplit.value match {
+ case fs: FileSplit =>
+ InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength)
+ case _ =>
+ InputFileBlockHolder.unset()
+ }
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
+ private val getBytesReadCallback: Option[() => Long] =
split.serializableHadoopSplit.value match {
case _: FileSplit | _: CombineFileSplit =>
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+ Some(SparkHadoopUtil.get.getFSBytesReadOnThreadCallback())
case _ => None
}
+
+ // We get our input bytes from thread-local Hadoop FileSystem statistics.
+ // If we do a coalesce, however, we are likely to compute multiple partitions in the same
+ // task and in the same thread, in which case we need to avoid overriding values written by
+ // previous partitions (SPARK-13071).
+ private def updateBytesRead(): Unit = {
+ getBytesReadCallback.foreach { getBytesRead =>
+ inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
+ }
}
- inputMetrics.setBytesReadCallback(bytesReadCallback)
- val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
- val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
- val format = inputFormatClass.newInstance
+ private val format = inputFormatClass.newInstance
format match {
case configurable: Configurable =>
configurable.setConf(conf)
case _ =>
}
- private var reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+ private val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0)
+ private val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+ private var finished = false
+ private var reader =
+ try {
+ val _reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+ _reader
+ } catch {
+ case e: IOException if ignoreCorruptFiles =>
+ logWarning(
+ s"Skipped the rest of the content in the corrupted file: ${split.serializableHadoopSplit}",
+ e)
+ finished = true
+ null
+ }
// Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener(context => close())
- var havePair = false
- var finished = false
- var recordsSinceMetricsUpdate = 0
+ context.addTaskCompletionListener { context =>
+ // Update bytesRead before closing so that lingering bytesRead statistics in
+ // this thread are correctly added.
+ updateBytesRead()
+ close()
+ }
+
+ private var havePair = false
+ private var recordsSinceMetricsUpdate = 0
override def hasNext: Boolean = {
if (!finished && !havePair) {
- finished = !reader.nextKeyValue
+ try {
+ finished = !reader.nextKeyValue
+ } catch {
+ case e: IOException if ignoreCorruptFiles =>
+ logWarning(
+ s"Skipped the rest of the content in the corrupted file: ${split.serializableHadoopSplit}",
+ e)
+ finished = true
+ }
if (finished) {
// Close and release the reader here; close() will also be called when the task
// completes, but for tasks that read from many files, it helps to release the
@@ -180,15 +231,15 @@ class NewHadoopRDD[K, V](
if (!finished) {
inputMetrics.incRecordsRead(1)
}
+ if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+ updateBytesRead()
+ }
(reader.getCurrentKey, reader.getCurrentValue)
}
- private def close() {
+ private def close(): Unit = {
if (reader != null) {
- // Close the reader and release it. Note: it's very important that we don't close the
- // reader more than once, since that exposes us to MAPREDUCE-5918 when running against
- // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
- // corruption issues when reading compressed input.
+ InputFileBlockHolder.unset()
try {
reader.close()
} catch {
@@ -199,8 +250,8 @@ class NewHadoopRDD[K, V](
} finally {
reader = null
}
- if (bytesReadCallback.isDefined) {
- inputMetrics.updateBytesRead()
+ if (getBytesReadCallback.isDefined) {
+ updateBytesRead()
} else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
@@ -228,18 +279,7 @@ class NewHadoopRDD[K, V](
override def getPreferredLocations(hsplit: Partition): Seq[String] = {
val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value
- val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
- case Some(c) =>
- try {
- val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]]
- Some(HadoopRDD.convertSplitLocationInfo(infos))
- } catch {
- case e : Exception =>
- logDebug("Failed to use InputSplit#getLocationInfo.", e)
- None
- }
- case None => None
- }
+ val locs = HadoopRDD.convertSplitLocationInfo(split.getLocationInfo)
locs.getOrElse(split.getLocations.filter(_ != "localhost"))
}
@@ -282,32 +322,3 @@ private[spark] object NewHadoopRDD {
}
}
}
-
-private[spark] class WholeTextFileRDD(
- sc : SparkContext,
- inputFormatClass: Class[_ <: WholeTextFileInputFormat],
- keyClass: Class[String],
- valueClass: Class[String],
- conf: Configuration,
- minPartitions: Int)
- extends NewHadoopRDD[String, String](sc, inputFormatClass, keyClass, valueClass, conf) {
-
- override def getPartitions: Array[Partition] = {
- val inputFormat = inputFormatClass.newInstance
- val conf = getConf
- inputFormat match {
- case configurable: Configurable =>
- configurable.setConf(conf)
- case _ =>
- }
- val jobContext = newJobContext(conf, jobId)
- inputFormat.setMinPartitions(jobContext, minPartitions)
- val rawSplits = inputFormat.getSplits(jobContext).toArray
- val result = new Array[Partition](rawSplits.size)
- for (i <- 0 until rawSplits.size) {
- result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
- }
- result
- }
-}
-
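The corrupt-file handling added above follows one pattern in both places (reader creation and
nextKeyValue): an IOException under IGNORE_CORRUPT_FILES marks the iterator as finished instead
of failing the task. A minimal, self-contained sketch of that pattern; readNext is a hypothetical
stand-in for reader.nextKeyValue:

  import java.io.IOException

  // Returns true if a record was read, false if the (possibly corrupt) input is exhausted.
  def safeAdvance(readNext: () => Boolean, ignoreCorruptFiles: Boolean): Boolean = {
    try {
      readNext()
    } catch {
      case _: IOException if ignoreCorruptFiles =>
        // Log in real code; here we simply treat the rest of the corrupt file as consumed.
        false
    }
  }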
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index d71bb6300090..a5992022d083 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -19,8 +19,9 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
-import org.apache.spark.{Logging, Partitioner, RangePartitioner}
+import org.apache.spark.{Partitioner, RangePartitioner}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
/**
* Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -45,8 +46,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
V: ClassTag,
P <: Product2[K, V] : ClassTag] @DeveloperApi() (
self: RDD[P])
- extends Logging with Serializable
-{
+ extends Logging with Serializable {
private val ordering = implicitly[Ordering[K]]
/**
@@ -76,7 +76,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
}
/**
- * Returns an RDD containing only the elements in the the inclusive range `lower` to `upper`.
+ * Returns an RDD containing only the elements in the inclusive range `lower` to `upper`.
* If the RDD has been partitioned using a `RangePartitioner`, then this operation can be
* performed efficiently by only scanning the partitions that might contain matching elements.
* Otherwise, a standard `filter` is applied to all partitions.
@@ -86,12 +86,11 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper)
val rddToFilter: RDD[P] = self.partitioner match {
- case Some(rp: RangePartitioner[K, V]) => {
+ case Some(rp: RangePartitioner[K, V]) =>
val partitionIndicies = (rp.getPartition(lower), rp.getPartition(upper)) match {
case (l, u) => Math.min(l, u) to Math.max(l, u)
}
PartitionPruningRDD.create(self, partitionIndicies.contains)
- }
case _ =>
self
}
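The range-pruning path above is what filterByRange takes when the RDD already carries a
RangePartitioner; otherwise it falls back to a plain filter over all partitions. A hedged usage
sketch, assuming an existing SparkContext named sc:

  // sortByKey installs a RangePartitioner, so filterByRange(10, 20) only scans
  // the partitions whose key ranges can overlap [10, 20]; both bounds inclusive.
  val pairs = sc.parallelize((1 to 100).map(i => (i, i.toString)))
  val sorted = pairs.sortByKey()
  val inRange = sorted.filterByRange(10, 20)
  inRange.count() // 11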
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index c6181902ace6..58762cc0838c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -18,33 +18,31 @@
package org.apache.spark.rdd
import java.nio.ByteBuffer
-import java.text.SimpleDateFormat
-import java.util.{Date, HashMap => JHashMap}
+import java.util.{HashMap => JHashMap}
-import scala.collection.{Map, mutable}
+import scala.collection.{mutable, Map}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
-import scala.util.DynamicVariable
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
-import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
-import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
- RecordWriter => NewRecordWriter}
+import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat}
import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.annotation.Experimental
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.executor.{DataWriteMethod, OutputMetrics}
-import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
+import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriter,
+ SparkHadoopWriterUtils}
+import org.apache.spark.internal.Logging
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.CompactBuffer
import org.apache.spark.util.random.StratifiedSamplingUtils
@@ -53,24 +51,24 @@ import org.apache.spark.util.random.StratifiedSamplingUtils
*/
class PairRDDFunctions[K, V](self: RDD[(K, V)])
(implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null)
- extends Logging
- with SparkHadoopMapReduceUtil
- with Serializable
-{
+ extends Logging with Serializable {
/**
* :: Experimental ::
* Generic function to combine the elements for each key using a custom set of aggregation
* functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C
- * Note that V and C can be different -- for example, one might group an RDD of type
- * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions:
*
- * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
- * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
- * - `mergeCombiners`, to combine two C's into a single one.
+ * Users provide three functions:
+ *
+ * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
+ * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
+ * - `mergeCombiners`, to combine two C's into a single one.
*
* In addition, users can control the partitioning of the output RDD, and whether to perform
* map-side aggregation (if a mapper can produce multiple items with the same key).
+ *
+ * @note V and C can be different -- for example, one might group an RDD of type
+ * (Int, Int) into an RDD of type (Int, Seq[Int]).
*/
@Experimental
def combineByKeyWithClassTag[C](
@@ -86,7 +84,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
throw new SparkException("Cannot use map-side combining with array keys.")
}
if (partitioner.isInstanceOf[HashPartitioner]) {
- throw new SparkException("Default partitioner cannot partition array keys.")
+ throw new SparkException("HashPartitioner cannot partition array keys.")
}
}
val aggregator = new Aggregator[K, V, C](
@@ -111,7 +109,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* functions. This method is here for backward compatibility. It does not provide combiner
* classtag information to the shuffle.
*
- * @see [[combineByKeyWithClassTag]]
+ * @see `combineByKeyWithClassTag`
*/
def combineByKey[C](
createCombiner: V => C,
@@ -129,7 +127,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* This method is here for backward compatibility. It does not provide combiner
* classtag information to the shuffle.
*
- * @see [[combineByKeyWithClassTag]]
+ * @see `combineByKeyWithClassTag`
*/
def combineByKey[C](
createCombiner: V => C,
@@ -304,27 +302,27 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
/**
- * Merge the values for each key using an associative reduce function. This will also perform
- * the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce.
+ * Merge the values for each key using an associative and commutative reduce function. This will
+ * also perform the merging locally on each mapper before sending results to a reducer, similarly
+ * to a "combiner" in MapReduce.
*/
def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope {
combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner)
}
/**
- * Merge the values for each key using an associative reduce function. This will also perform
- * the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
+ * Merge the values for each key using an associative and commutative reduce function. This will
+ * also perform the merging locally on each mapper before sending results to a reducer, similarly
+ * to a "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
*/
def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = self.withScope {
reduceByKey(new HashPartitioner(numPartitions), func)
}
/**
- * Merge the values for each key using an associative reduce function. This will also perform
- * the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
+ * Merge the values for each key using an associative and commutative reduce function. This will
+ * also perform the merging locally on each mapper before sending results to a reducer, similarly
+ * to a "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
* parallelism level.
*/
def reduceByKey(func: (V, V) => V): RDD[(K, V)] = self.withScope {
@@ -332,9 +330,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
/**
- * Merge the values for each key using an associative reduce function, but return the results
- * immediately to the master as a Map. This will also perform the merging locally on each mapper
- * before sending results to a reducer, similarly to a "combiner" in MapReduce.
+ * Merge the values for each key using an associative and commutative reduce function, but return
+ * the results immediately to the master as a Map. This will also perform the merging locally on
+ * each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = self.withScope {
val cleanedF = self.sparkContext.clean(func)
@@ -363,16 +361,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
self.mapPartitions(reducePartition).reduce(mergeMaps).asScala
}
- /** Alias for reduceByKeyLocally */
- @deprecated("Use reduceByKeyLocally", "1.0.0")
- def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = self.withScope {
- reduceByKeyLocally(func)
- }
-
/**
* Count the number of elements for each key, collecting the results to a local Map.
*
- * Note that this method should only be used if the resulting map is expected to be small, as
+ * @note This method should only be used if the resulting map is expected to be small, as
* the whole thing is loaded into the driver's memory.
* To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which
 * returns an RDD[(K, Long)] instead of a map.
@@ -384,6 +376,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
/**
* Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
+ *
+ * The confidence is the probability that the error bounds of the result will
+ * contain the true value. That is, if countApprox were called repeatedly
+ * with confidence 0.9, we would expect 90% of the results to contain the
+ * true count. The confidence must be in the range [0,1] or an exception will
+ * be thrown.
+ *
+ * @param timeout maximum time to wait for the job, in milliseconds
+ * @param confidence the desired statistical confidence in the result
+ * @return a potentially incomplete result, with error bounds
*/
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[Map[K, BoundedDouble]] = self.withScope {
@@ -397,9 +399,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available
* here.
*
- * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p`
- * would trigger sparse representation of registers, which may reduce the memory consumption
- * and increase accuracy when the cardinality is small.
+ * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp` (greater
+ * than `p`) would trigger sparse representation of registers, which may reduce the
+ * memory consumption and increase accuracy when the cardinality is small.
*
* @param p The precision value for the normal set.
* `p` must be a value between 4 and `sp` if `sp` is not zero (32 max).
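To make the accuracy figure above concrete (an editorial illustration, not part of the patch):
for example, p = 14 gives a bound of 1.054 / sqrt(2^14) = 1.054 / 128 ≈ 0.0082, roughly a 0.8%
relative error, while the minimum p = 4 gives 1.054 / sqrt(2^4) = 1.054 / 4 ≈ 0.26, about 26%.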
@@ -489,12 +491,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* The ordering of elements within each group is not guaranteed, and may even differ
* each time the resulting RDD is evaluated.
*
- * Note: This operation may be very expensive. If you are grouping in order to perform an
- * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
- * or [[PairRDDFunctions.reduceByKey]] will provide much better performance.
+ * @note This operation may be very expensive. If you are grouping in order to perform an
+ * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey`
+ * or `PairRDDFunctions.reduceByKey` will provide much better performance.
*
- * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any
- * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]].
+ * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any
+ * key in memory. If a key has too many values, it can result in an `OutOfMemoryError`.
*/
def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope {
// groupByKey shouldn't use map side combine because map side combine does not
@@ -513,12 +515,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
 * resulting RDD into `numPartitions` partitions. The ordering of elements within
* each group is not guaranteed, and may even differ each time the resulting RDD is evaluated.
*
- * Note: This operation may be very expensive. If you are grouping in order to perform an
- * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
- * or [[PairRDDFunctions.reduceByKey]] will provide much better performance.
+ * @note This operation may be very expensive. If you are grouping in order to perform an
+ * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey`
+ * or `PairRDDFunctions.reduceByKey` will provide much better performance.
*
- * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any
- * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]].
+ * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any
+ * key in memory. If a key has too many values, it can result in an `OutOfMemoryError`.
*/
def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = self.withScope {
groupByKey(new HashPartitioner(numPartitions))
@@ -529,7 +531,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
*/
def partitionBy(partitioner: Partitioner): RDD[(K, V)] = self.withScope {
if (keyClass.isArray && partitioner.isInstanceOf[HashPartitioner]) {
- throw new SparkException("Default partitioner cannot partition array keys.")
+ throw new SparkException("HashPartitioner cannot partition array keys.")
}
if (self.partitioner == Some(partitioner)) {
self
@@ -606,7 +608,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* existing partitioner/parallelism level. This method is here for backward compatibility. It
* does not provide combiner classtag information to the shuffle.
*
- * @see [[combineByKeyWithClassTag]]
+ * @see `combineByKeyWithClassTag`
*/
def combineByKey[C](
createCombiner: V => C,
@@ -634,9 +636,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* within each group is not guaranteed, and may even differ each time the resulting RDD is
* evaluated.
*
- * Note: This operation may be very expensive. If you are grouping in order to perform an
- * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
- * or [[PairRDDFunctions.reduceByKey]] will provide much better performance.
+ * @note This operation may be very expensive. If you are grouping in order to perform an
+ * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey`
+ * or `PairRDDFunctions.reduceByKey` will provide much better performance.
*/
def groupByKey(): RDD[(K, Iterable[V])] = self.withScope {
groupByKey(defaultPartitioner(self))
@@ -736,6 +738,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
*
 * Warning: this doesn't return a multimap (so if you have multiple values for the same key, only
 * one value per key is preserved in the map returned)
+ *
+ * @note This method should only be used if the resulting data is expected to be small, as
+ * all the data is loaded into the driver's memory.
*/
def collectAsMap(): Map[K, V] = self.withScope {
val data = self.collect()
@@ -780,7 +785,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
partitioner: Partitioner)
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = self.withScope {
if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
- throw new SparkException("Default partitioner cannot partition array keys.")
+ throw new SparkException("HashPartitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner)
cg.mapValues { case Array(vs, w1s, w2s, w3s) =>
@@ -798,7 +803,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner)
: RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope {
if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
- throw new SparkException("Default partitioner cannot partition array keys.")
+ throw new SparkException("HashPartitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
cg.mapValues { case Array(vs, w1s) =>
@@ -813,7 +818,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = self.withScope {
if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
- throw new SparkException("Default partitioner cannot partition array keys.")
+ throw new SparkException("HashPartitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner)
cg.mapValues { case Array(vs, w1s, w2s) =>
@@ -903,20 +908,24 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Return an RDD with the pairs from `this` whose keys are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be no larger than this one.
*/
def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = self.withScope {
subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length)))
}
- /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+ /**
+ * Return an RDD with the pairs from `this` whose keys are not in `other`.
+ */
def subtractByKey[W: ClassTag](
other: RDD[(K, W)],
numPartitions: Int): RDD[(K, V)] = self.withScope {
subtractByKey(other, new HashPartitioner(numPartitions))
}
- /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+ /**
+ * Return an RDD with the pairs from `this` whose keys are not in `other`.
+ */
def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = self.withScope {
new SubtractedRDD[K, V, W](self, other, p)
}
@@ -985,12 +994,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
conf: Configuration = self.context.hadoopConfiguration): Unit = self.withScope {
// Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038).
val hadoopConf = conf
- val job = new NewAPIHadoopJob(hadoopConf)
+ val job = NewAPIHadoopJob.getInstance(hadoopConf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
job.setOutputFormatClass(outputFormatClass)
- val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job)
- jobConfiguration.set("mapred.output.dir", path)
+ val jobConfiguration = job.getConfiguration
+ jobConfiguration.set("mapreduce.output.fileoutputformat.outputdir", path)
saveAsNewAPIHadoopDataset(jobConfiguration)
}
@@ -1012,7 +1021,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
*
- * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do
+ * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do
 * not use an output committer that writes data directly.
 * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 showing the bad
 * result of using a direct output committer with speculation enabled.
@@ -1031,10 +1040,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
conf.setOutputFormat(outputFormatClass)
for (c <- codec) {
hadoopConf.setCompressMapOutput(true)
- hadoopConf.set("mapred.output.compress", "true")
+ hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true")
hadoopConf.setMapOutputCompressorClass(c)
- hadoopConf.set("mapred.output.compression.codec", c.getCanonicalName)
- hadoopConf.set("mapred.output.compression.type", CompressionType.BLOCK.toString)
+ hadoopConf.set("mapreduce.output.fileoutputformat.compress.codec", c.getCanonicalName)
+ hadoopConf.set("mapreduce.output.fileoutputformat.compress.type",
+ CompressionType.BLOCK.toString)
}
// Use configured output committer if already set
@@ -1050,13 +1060,13 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val warningMessage =
s"$outputCommitterClass may be an output committer that writes data directly to " +
"the final location. Because speculation is enabled, this output committer may " +
- "cause data loss (see the case in SPARK-10063). If possible, please use a output " +
+ "cause data loss (see the case in SPARK-10063). If possible, please use an output " +
"committer that does not have this behavior (e.g. FileOutputCommitter)."
logWarning(warningMessage)
}
FileOutputFormat.setOutputPath(hadoopConf,
- SparkHadoopWriter.createPathFromString(path, hadoopConf))
+ SparkHadoopWriterUtils.createPathFromString(path, hadoopConf))
saveAsHadoopDataset(hadoopConf)
}
@@ -1066,85 +1076,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* output paths required (e.g. a table name to write to) in the same way as it would be
* configured for a Hadoop MapReduce job.
*
- * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do
+ * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do
 * not use an output committer that writes data directly.
 * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 showing the bad
 * result of using a direct output committer with speculation enabled.
*/
def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope {
- // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038).
- val hadoopConf = conf
- val job = new NewAPIHadoopJob(hadoopConf)
- val formatter = new SimpleDateFormat("yyyyMMddHHmm")
- val jobtrackerID = formatter.format(new Date())
- val stageId = self.id
- val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job)
- val wrappedConf = new SerializableConfiguration(jobConfiguration)
- val outfmt = job.getOutputFormatClass
- val jobFormat = outfmt.newInstance
-
- if (isOutputSpecValidationEnabled) {
- // FileOutputFormat ignores the filesystem parameter
- jobFormat.checkOutputSpecs(job)
- }
-
- val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => {
- val config = wrappedConf.value
- /* "reduce task" */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
- context.attemptNumber)
- val hadoopContext = newTaskAttemptContext(config, attemptId)
- val format = outfmt.newInstance
- format match {
- case c: Configurable => c.setConf(config)
- case _ => ()
- }
- val committer = format.getOutputCommitter(hadoopContext)
- committer.setupTask(hadoopContext)
-
- val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
-
- val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
- require(writer != null, "Unable to obtain RecordWriter")
- var recordsWritten = 0L
- Utils.tryWithSafeFinally {
- while (iter.hasNext) {
- val pair = iter.next()
- writer.write(pair._1, pair._2)
-
- // Update bytes written metric every few records
- maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
- recordsWritten += 1
- }
- } {
- writer.close(hadoopContext)
- }
- committer.commitTask(hadoopContext)
- bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
- outputMetrics.setRecordsWritten(recordsWritten)
- 1
- } : Int
-
- val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0)
- val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
- val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
-
- // When speculation is on and output committer class name contains "Direct", we should warn
- // users that they may loss data if they are using a direct output committer.
- val speculationEnabled = self.conf.getBoolean("spark.speculation", false)
- val outputCommitterClass = jobCommitter.getClass.getSimpleName
- if (speculationEnabled && outputCommitterClass.contains("Direct")) {
- val warningMessage =
- s"$outputCommitterClass may be an output committer that writes data directly to " +
- "the final location. Because speculation is enabled, this output committer may " +
- "cause data loss (see the case in SPARK-10063). If possible, please use a output " +
- "committer that does not have this behavior (e.g. FileOutputCommitter)."
- logWarning(warningMessage)
- }
-
- jobCommitter.setupJob(jobTaskContext)
- self.context.runJob(self, writeShard)
- jobCommitter.commitJob(jobTaskContext)
+ SparkHadoopMapReduceWriter.write(
+ rdd = self,
+ hadoopConf = conf)
}
/**
@@ -1173,7 +1113,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
valueClass.getSimpleName + ")")
- if (isOutputSpecValidationEnabled) {
+ if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) {
// FileOutputFormat ignores the filesystem parameter
val ignoredFs = FileSystem.get(hadoopConf)
hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf)
@@ -1187,26 +1127,24 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
- val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
+ val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context)
writer.setup(context.stageId, context.partitionId, taskAttemptId)
writer.open()
var recordsWritten = 0L
- Utils.tryWithSafeFinally {
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iter.hasNext) {
val record = iter.next()
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
// Update bytes written metric every few records
- maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+ SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten)
recordsWritten += 1
}
- } {
- writer.close()
- }
+ }(finallyBlock = writer.close())
writer.commit()
- bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
+ outputMetrics.setBytesWritten(callback())
outputMetrics.setRecordsWritten(recordsWritten)
}
@@ -1214,23 +1152,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.commitJob()
}
- private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = {
- val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
- val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
- if (bytesWrittenCallback.isDefined) {
- context.taskMetrics.outputMetrics = Some(outputMetrics)
- }
- (outputMetrics, bytesWrittenCallback)
- }
-
- private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long],
- outputMetrics: OutputMetrics, recordsWritten: Long): Unit = {
- if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) {
- bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
- outputMetrics.setRecordsWritten(recordsWritten)
- }
- }
-
/**
* Return an RDD with the keys of each tuple.
*/
@@ -1246,22 +1167,4 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
private[spark] def valueClass: Class[_] = vt.runtimeClass
private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord)
-
- // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation
- // setting can take effect:
- private def isOutputSpecValidationEnabled: Boolean = {
- val validationDisabled = PairRDDFunctions.disableOutputSpecValidation.value
- val enabledInConf = self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)
- enabledInConf && !validationDisabled
- }
-}
-
-private[spark] object PairRDDFunctions {
- val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256
-
- /**
- * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case
- * basis; see SPARK-4835 for more details.
- */
- val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
}
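To ground the createCombiner / mergeValue / mergeCombiners contract and the reduceByKey wording
touched above, a hedged usage sketch assuming an existing SparkContext named sc:

  val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

  // combineByKey: start a one-element list per key, append within a partition,
  // and concatenate the partial lists across partitions.
  val grouped = pairs.combineByKey[List[Int]](
    (v: Int) => List(v),
    (acc: List[Int], v: Int) => v :: acc,
    (left: List[Int], right: List[Int]) => left ::: right)

  // For a plain aggregation, the associative and commutative function given to
  // reduceByKey is combined map-side before the shuffle, which is why the
  // scaladoc above recommends it over groupByKey for sums and averages.
  val sums = pairs.reduceByKey(_ + _)   // ("a", 4), ("b", 2)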
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 582fa93afe34..9f8019b80a4d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -32,8 +32,8 @@ import org.apache.spark.util.Utils
private[spark] class ParallelCollectionPartition[T: ClassTag](
var rddId: Long,
var slice: Int,
- var values: Seq[T])
- extends Partition with Serializable {
+ var values: Seq[T]
+ ) extends Partition with Serializable {
def iterator: Iterator[T] = values.iterator
@@ -116,20 +116,20 @@ private object ParallelCollectionRDD {
*/
def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
if (numSlices < 1) {
- throw new IllegalArgumentException("Positive number of slices required")
+ throw new IllegalArgumentException("Positive number of partitions required")
}
// Sequences need to be sliced at the same set of index positions for operations
// like RDD.zip() to behave as expected
def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
- (0 until numSlices).iterator.map(i => {
+ (0 until numSlices).iterator.map { i =>
val start = ((i * length) / numSlices).toInt
val end = (((i + 1) * length) / numSlices).toInt
(start, end)
- })
+ }
}
seq match {
- case r: Range => {
- positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) =>
+ case r: Range =>
+ positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
// If the range is inclusive, use inclusive range for the last slice
if (r.isInclusive && index == numSlices - 1) {
new Range.Inclusive(r.start + start * r.step, r.end, r.step)
@@ -137,9 +137,8 @@ private object ParallelCollectionRDD {
else {
new Range(r.start + start * r.step, r.start + end * r.step, r.step)
}
- }).toSeq.asInstanceOf[Seq[Seq[T]]]
- }
- case nr: NumericRange[_] => {
+ }.toSeq.asInstanceOf[Seq[Seq[T]]]
+ case nr: NumericRange[_] =>
// For ranges of Long, Double, BigInteger, etc
val slices = new ArrayBuffer[Seq[T]](numSlices)
var r = nr
@@ -149,14 +148,11 @@ private object ParallelCollectionRDD {
r = r.drop(sliceSize)
}
slices
- }
- case _ => {
+ case _ =>
val array = seq.toArray // To prevent O(n^2) operations for List etc
- positions(array.length, numSlices).map({
- case (start, end) =>
+ positions(array.length, numSlices).map { case (start, end) =>
array.slice(start, end).toSeq
- }).toSeq
- }
+ }.toSeq
}
}
}
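The positions helper rewritten above fixes every slice boundary up front so that operations like
RDD.zip() see the same cut points; a standalone reproduction with editorial example values:

  // Same arithmetic as positions(length, numSlices) in ParallelCollectionRDD.slice.
  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
    (0 until numSlices).iterator.map { i =>
      val start = ((i * length) / numSlices).toInt
      val end = (((i + 1) * length) / numSlices).toInt
      (start, end)
    }

  positions(10, 3).toList // List((0,3), (3,6), (6,10)): slice sizes 3, 3 and 4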
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
index d6a37e8cc5da..ce75a16031a3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
@@ -48,7 +48,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int =>
/**
* :: DeveloperApi ::
- * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on
+ * An RDD used to prune RDD partitions so we can avoid launching tasks on
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
@@ -65,7 +65,7 @@ class PartitionPruningRDD[T: ClassTag](
}
override protected def getPartitions: Array[Partition] =
- getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
+ dependencies.head.asInstanceOf[PruneDependency[T]].partitions
}
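A hedged usage sketch of the pruning described above, assuming an existing SparkContext named sc;
the create factory takes a predicate over partition indices:

  import org.apache.spark.rdd.PartitionPruningRDD

  val data = sc.parallelize(1 to 100, numSlices = 8)
  // Keep only the first two partitions; no tasks are ever launched for the rest.
  val pruned = PartitionPruningRDD.create(data, index => index < 2)
  pruned.partitions.length // 2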
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index 9e3880714a79..d744d6759254 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -31,12 +31,13 @@ import org.apache.spark.util.Utils
private[spark]
class PartitionerAwareUnionRDDPartition(
@transient val rdds: Seq[RDD[_]],
- val idx: Int
+ override val index: Int
) extends Partition {
- var parents = rdds.map(_.partitions(idx)).toArray
+ var parents = rdds.map(_.partitions(index)).toArray
- override val index = idx
- override def hashCode(): Int = idx
+ override def hashCode(): Int = index
+
+ override def equals(other: Any): Boolean = super.equals(other)
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
@@ -59,7 +60,7 @@ class PartitionerAwareUnionRDD[T: ClassTag](
sc: SparkContext,
var rdds: Seq[RDD[T]]
) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
- require(rdds.length > 0)
+ require(rdds.nonEmpty)
require(rdds.forall(_.partitioner.isDefined))
require(rdds.flatMap(_.partitioner).toSet.size == 1,
"Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))
@@ -68,9 +69,9 @@ class PartitionerAwareUnionRDD[T: ClassTag](
override def getPartitions: Array[Partition] = {
val numPartitions = partitioner.get.numPartitions
- (0 until numPartitions).map(index => {
+ (0 until numPartitions).map { index =>
new PartitionerAwareUnionRDDPartition(rdds, index)
- }).toArray
+ }.toArray
}
// Get the location where most of the partitions of parent RDDs are located
@@ -78,11 +79,10 @@ class PartitionerAwareUnionRDD[T: ClassTag](
logDebug("Finding preferred location for " + this + ", partition " + s.index)
val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
val locations = rdds.zip(parentPartitions).flatMap {
- case (rdd, part) => {
+ case (rdd, part) =>
val parentLocations = currPrefLocs(rdd, part)
logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
parentLocations
- }
}
val location = if (locations.isEmpty) {
None
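The partitioner-aware union only applies when every parent shares the same partitioner. A usage sketch, assuming `sc: SparkContext`:

{{{
import org.apache.spark.HashPartitioner

val p = new HashPartitioner(4)
val a = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(p)
val b = sc.parallelize(Seq(3 -> "c", 4 -> "d")).partitionBy(p)

// Both parents use the same HashPartitioner, so the union can be computed without a
// shuffle and the result keeps a defined partitioner when the partitioner-aware path
// is taken, as in the preferred-location logic above.
val unioned = a.union(b)
unioned.partitioner  // Some(p)
}}}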
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
index 3b1acacf409b..6a89ea878646 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -32,7 +32,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
}
/**
- * A RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD,
+ * An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD,
* a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain
* a random sample of the records in the partition. The random seeds assigned to the samplers
* are guaranteed to have different values.
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index afbe566b7656..02b28b72fb0e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -17,11 +17,14 @@
package org.apache.spark.rdd
+import java.io.BufferedWriter
import java.io.File
import java.io.FilenameFilter
import java.io.IOException
+import java.io.OutputStreamWriter
import java.io.PrintWriter
import java.util.StringTokenizer
+import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
import scala.collection.Map
@@ -43,22 +46,11 @@ private[spark] class PipedRDD[T: ClassTag](
envVars: Map[String, String],
printPipeContext: (String => Unit) => Unit,
printRDDElement: (T, String => Unit) => Unit,
- separateWorkingDir: Boolean)
+ separateWorkingDir: Boolean,
+ bufferSize: Int,
+ encoding: String)
extends RDD[String](prev) {
- // Similar to Runtime.exec(), if we are given a single string, split it into words
- // using a standard StringTokenizer (i.e. by spaces)
- def this(
- prev: RDD[T],
- command: String,
- envVars: Map[String, String] = Map(),
- printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null,
- separateWorkingDir: Boolean = false) =
- this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement,
- separateWorkingDir)
-
-
override def getPartitions: Array[Partition] = firstParent[T].partitions
/**
@@ -118,63 +110,99 @@ private[spark] class PipedRDD[T: ClassTag](
val proc = pb.start()
val env = SparkEnv.get
+ val childThreadException = new AtomicReference[Throwable](null)
// Start a thread to print the process's stderr to ours
- new Thread("stderr reader for " + command) {
- override def run() {
- for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
- // scalastyle:off println
- System.err.println(line)
- // scalastyle:on println
+ new Thread(s"stderr reader for $command") {
+ override def run(): Unit = {
+ val err = proc.getErrorStream
+ try {
+ for (line <- Source.fromInputStream(err)(encoding).getLines) {
+ // scalastyle:off println
+ System.err.println(line)
+ // scalastyle:on println
+ }
+ } catch {
+ case t: Throwable => childThreadException.set(t)
+ } finally {
+ err.close()
}
}
}.start()
// Start a thread to feed the process input from our parent's iterator
- new Thread("stdin writer for " + command) {
- override def run() {
+ new Thread(s"stdin writer for $command") {
+ override def run(): Unit = {
TaskContext.setTaskContext(context)
- val out = new PrintWriter(proc.getOutputStream)
-
- // scalastyle:off println
- // input the pipe context firstly
- if (printPipeContext != null) {
- printPipeContext(out.println(_))
- }
- for (elem <- firstParent[T].iterator(split, context)) {
- if (printRDDElement != null) {
- printRDDElement(elem, out.println(_))
- } else {
- out.println(elem)
+ val out = new PrintWriter(new BufferedWriter(
+ new OutputStreamWriter(proc.getOutputStream, encoding), bufferSize))
+ try {
+ // scalastyle:off println
+ // feed the pipe context to the process first
+ if (printPipeContext != null) {
+ printPipeContext(out.println)
}
+ for (elem <- firstParent[T].iterator(split, context)) {
+ if (printRDDElement != null) {
+ printRDDElement(elem, out.println)
+ } else {
+ out.println(elem)
+ }
+ }
+ // scalastyle:on println
+ } catch {
+ case t: Throwable => childThreadException.set(t)
+ } finally {
+ out.close()
}
- // scalastyle:on println
- out.close()
}
}.start()
// Return an iterator that read lines from the process's stdout
- val lines = Source.fromInputStream(proc.getInputStream).getLines()
+ val lines = Source.fromInputStream(proc.getInputStream)(encoding).getLines
new Iterator[String] {
- def next(): String = lines.next()
- def hasNext: Boolean = {
- if (lines.hasNext) {
+ def next(): String = {
+ if (!hasNext()) {
+ throw new NoSuchElementException()
+ }
+ lines.next()
+ }
+
+ def hasNext(): Boolean = {
+ val result = if (lines.hasNext) {
true
} else {
val exitStatus = proc.waitFor()
+ cleanup()
if (exitStatus != 0) {
- throw new Exception("Subprocess exited with status " + exitStatus)
+ throw new IllegalStateException(s"Subprocess exited with status $exitStatus. " +
+ s"Command ran: " + command.mkString(" "))
}
+ false
+ }
+ propagateChildException()
+ result
+ }
- // cleanup task working directory if used
- if (workInTaskDirectory) {
- scala.util.control.Exception.ignoring(classOf[IOException]) {
- Utils.deleteRecursively(new File(taskDirectory))
- }
- logDebug("Removed task working directory " + taskDirectory)
+ private def cleanup(): Unit = {
+ // cleanup task working directory if used
+ if (workInTaskDirectory) {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ Utils.deleteRecursively(new File(taskDirectory))
}
+ logDebug(s"Removed task working directory $taskDirectory")
+ }
+ }
- false
+ private def propagateChildException(): Unit = {
+ val t = childThreadException.get()
+ if (t != null) {
+ val commandRan = command.mkString(" ")
+ logError(s"Caught exception while running pipe() operator. Command ran: $commandRan. " +
+ s"Exception: ${t.getMessage}")
+ proc.destroy()
+ cleanup()
+ throw t
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 800ef53cbef0..4ff0f83263db 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -21,6 +21,7 @@ import java.util.Random
import scala.collection.{mutable, Map}
import scala.collection.mutable.ArrayBuffer
+import scala.io.Codec
import scala.language.implicitConversions
import scala.reflect.{classTag, ClassTag}
@@ -31,16 +32,17 @@ import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.spark._
import org.apache.spark.Partitioner._
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.internal.Logging
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
-import org.apache.spark.util.collection.OpenHashMap
-import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler,
+import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils}
+import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
SamplingUtils}
/**
@@ -68,8 +70,8 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli
* All of the scheduling and execution in Spark is done based on these methods, allowing each RDD
* to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for
* reading data from a new storage system) by overriding these functions. Please refer to the
- * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
- * on RDD internals.
+ * Spark paper
+ * for more details on RDD internals.
*/
abstract class RDD[T: ClassTag](
@transient private var _sc: SparkContext,
@@ -85,17 +87,21 @@ abstract class RDD[T: ClassTag](
private def sc: SparkContext = {
if (_sc == null) {
throw new SparkException(
- "RDD transformations and actions can only be invoked by the driver, not inside of other " +
- "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " +
- "the values transformation and count action cannot be performed inside of the rdd1.map " +
- "transformation. For more information, see SPARK-5063.")
+ "This RDD lacks a SparkContext. It could happen in the following cases: \n(1) RDD " +
+ "transformations and actions are NOT invoked by the driver, but inside of other " +
+ "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid " +
+ "because the values transformation and count action cannot be performed inside of the " +
+ "rdd1.map transformation. For more information, see SPARK-5063.\n(2) When a Spark " +
+ "Streaming job recovers from checkpoint, this exception will be hit if a reference to " +
+ "an RDD not defined by the streaming job is used in DStream operations. For more " +
+ "information, see SPARK-13758.")
}
_sc
}
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
- this(oneParent.context , List(new OneToOneDependency(oneParent)))
+ this(oneParent.context, List(new OneToOneDependency(oneParent)))
private[spark] def conf = sc.conf
// =======================================================================
@@ -112,6 +118,9 @@ abstract class RDD[T: ClassTag](
/**
* Implemented by subclasses to return the set of partitions in this RDD. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
+ *
+ * The partitions in this array must satisfy the following property:
+ * `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }`
*/
protected def getPartitions: Array[Partition]
@@ -186,10 +195,14 @@ abstract class RDD[T: ClassTag](
}
}
- /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ /**
+ * Persist this RDD with the default storage level (`MEMORY_ONLY`).
+ */
def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)
- /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ /**
+ * Persist this RDD with the default storage level (`MEMORY_ONLY`).
+ */
def cache(): this.type = persist()
/**
@@ -237,11 +250,21 @@ abstract class RDD[T: ClassTag](
checkpointRDD.map(_.partitions).getOrElse {
if (partitions_ == null) {
partitions_ = getPartitions
+ partitions_.zipWithIndex.foreach { case (partition, index) =>
+ require(partition.index == index,
+ s"partitions($index).index == ${partition.index}, but it should equal $index")
+ }
}
partitions_
}
}
+ /**
+ * Returns the number of partitions of this RDD.
+ */
+ @Since("1.6.0")
+ final def getNumPartitions: Int = partitions.length
+
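A quick illustration of the new accessor and of the partition-index invariant enforced above (a sketch, assuming `sc: SparkContext`):

{{{
val rdd = sc.parallelize(1 to 100, 8)
rdd.getNumPartitions  // 8

// The contract checked in getPartitions: partition i of the array reports index i.
rdd.partitions.zipWithIndex.forall { case (p, i) => p.index == i }  // true
}}}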
/**
* Get the preferred locations of a partition, taking into account whether the
* RDD is checkpointed.
@@ -259,7 +282,7 @@ abstract class RDD[T: ClassTag](
*/
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
- SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
+ getOrCompute(split, context)
} else {
computeOrReadCheckpoint(split, context)
}
@@ -301,6 +324,35 @@ abstract class RDD[T: ClassTag](
}
}
+ /**
+ * Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached.
+ */
+ private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
+ val blockId = RDDBlockId(id, partition.index)
+ var readCachedBlock = true
+ // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
+ SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
+ readCachedBlock = false
+ computeOrReadCheckpoint(partition, context)
+ }) match {
+ case Left(blockResult) =>
+ if (readCachedBlock) {
+ val existingMetrics = context.taskMetrics().inputMetrics
+ existingMetrics.incBytesRead(blockResult.bytes)
+ new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
+ override def next(): T = {
+ existingMetrics.incRecordsRead(1)
+ delegate.next()
+ }
+ }
+ } else {
+ new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
+ }
+ case Right(iter) =>
+ new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
+ }
+ }
+
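The effect of this block-manager path can be observed by persisting an RDD and running two actions: the first computes and stores the blocks, the second is served from the cache. A sketch, assuming `sc: SparkContext`:

{{{
import org.apache.spark.storage.StorageLevel

val rdd = sc.parallelize(1 to 1000000, 4).map(_ * 2)
rdd.persist(StorageLevel.MEMORY_ONLY)

rdd.count()  // computes each partition and stores it under RDDBlockId(id, partition.index)
rdd.count()  // reads the cached blocks back through getOrCompute; input metrics are updated
}}}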
/**
* Execute a block of code in a scope such that all new RDDs created in this body will
* be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}.
@@ -361,6 +413,8 @@ abstract class RDD[T: ClassTag](
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
+ *
+ * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207.
*/
def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope {
coalesce(numPartitions, shuffle = true)
@@ -371,7 +425,8 @@ abstract class RDD[T: ClassTag](
*
* This results in a narrow dependency, e.g. if you go from 1000 partitions
* to 100 partitions, there will not be a shuffle, instead each of the 100
- * new partitions will claim 10 of the current partitions.
+ * new partitions will claim 10 of the current partitions. If a larger number
+ * of partitions is requested, it will stay at the current number of partitions.
*
* However, if you're doing a drastic coalesce, e.g. to numPartitions = 1,
* this may result in your computation taking place on fewer nodes than
@@ -380,14 +435,18 @@ abstract class RDD[T: ClassTag](
* current upstream partitions will be executed in parallel (per whatever
* the current partitioning is).
*
- * Note: With shuffle = true, you can actually coalesce to a larger number
+ * @note With shuffle = true, you can actually coalesce to a larger number
* of partitions. This is useful if you have a small number of partitions,
* say 100, potentially with a few partitions being abnormally large. Calling
* coalesce(1000, shuffle = true) will result in 1000 partitions with the
- * data distributed using a hash partitioner.
+ * data distributed using a hash partitioner. The optional partition coalescer
+ * passed in must be serializable.
*/
- def coalesce(numPartitions: Int, shuffle: Boolean = false)(implicit ord: Ordering[T] = null)
+ def coalesce(numPartitions: Int, shuffle: Boolean = false,
+ partitionCoalescer: Option[PartitionCoalescer] = Option.empty)
+ (implicit ord: Ordering[T] = null)
: RDD[T] = withScope {
+ require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
if (shuffle) {
/** Distributes elements evenly across output partitions, starting from a random partition. */
val distributePartition = (index: Int, items: Iterator[T]) => {
@@ -402,11 +461,13 @@ abstract class RDD[T: ClassTag](
// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(
- new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
- new HashPartitioner(numPartitions)),
- numPartitions).values
+ new ShuffledRDD[Int, T, T](
+ mapPartitionsWithIndexInternal(distributePartition, isOrderSensitive = true),
+ new HashPartitioner(numPartitions)),
+ numPartitions,
+ partitionCoalescer).values
} else {
- new CoalescedRDD(this, numPartitions)
+ new CoalescedRDD(this, numPartitions, partitionCoalescer)
}
}
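A short sketch of the difference between the shuffle and non-shuffle paths, assuming `sc: SparkContext`:

{{{
val rdd = sc.parallelize(1 to 1000, 100)

rdd.coalesce(10).getNumPartitions                    // 10: narrow dependency, no shuffle
rdd.coalesce(1000).getNumPartitions                  // 100: cannot grow without shuffle = true
rdd.coalesce(1000, shuffle = true).getNumPartitions  // 1000: data redistributed by hash
rdd.repartition(10).getNumPartitions                 // 10: always shuffles
}}}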
@@ -416,18 +477,27 @@ abstract class RDD[T: ClassTag](
* @param withReplacement can elements be sampled multiple times (replaced when sampled out)
* @param fraction expected size of the sample as a fraction of this RDD's size
* without replacement: probability that each element is chosen; fraction must be [0, 1]
- * with replacement: expected number of times each element is chosen; fraction must be >= 0
+ * with replacement: expected number of times each element is chosen; fraction must be greater
+ * than or equal to 0
* @param seed seed for the random number generator
+ *
+ * @note This is NOT guaranteed to provide exactly the fraction of the count
+ * of the given [[RDD]].
*/
def sample(
withReplacement: Boolean,
fraction: Double,
- seed: Long = Utils.random.nextLong): RDD[T] = withScope {
- require(fraction >= 0.0, "Negative fraction value: " + fraction)
- if (withReplacement) {
- new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
- } else {
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
+ seed: Long = Utils.random.nextLong): RDD[T] = {
+ require(fraction >= 0,
+ s"Fraction must be nonnegative, but got ${fraction}")
+
+ withScope {
+ require(fraction >= 0.0, "Negative fraction value: " + fraction)
+ if (withReplacement) {
+ new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
+ } else {
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
+ }
}
}
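A usage sketch for `sample`, assuming `sc: SparkContext`; the fraction is an expected value, not an exact count:

{{{
val rdd = sc.parallelize(1 to 10000)

// Without replacement: each element is kept with probability 0.1 (Bernoulli sampling).
val s1 = rdd.sample(withReplacement = false, fraction = 0.1, seed = 42L)

// With replacement: each element is drawn a Poisson(0.1)-distributed number of times.
val s2 = rdd.sample(withReplacement = true, fraction = 0.1, seed = 42L)

s1.count()  // roughly 1000, but not guaranteed to be exactly 10% of the input
}}}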
@@ -441,14 +511,22 @@ abstract class RDD[T: ClassTag](
*/
def randomSplit(
weights: Array[Double],
- seed: Long = Utils.random.nextLong): Array[RDD[T]] = withScope {
- val sum = weights.sum
- val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
- normalizedCumWeights.sliding(2).map { x =>
- randomSampleWithRange(x(0), x(1), seed)
- }.toArray
+ seed: Long = Utils.random.nextLong): Array[RDD[T]] = {
+ require(weights.forall(_ >= 0),
+ s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}")
+ require(weights.sum > 0,
+ s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}")
+
+ withScope {
+ val sum = weights.sum
+ val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+ normalizedCumWeights.sliding(2).map { x =>
+ randomSampleWithRange(x(0), x(1), seed)
+ }.toArray
+ }
}
+
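A typical use of `randomSplit` is a train/test split; the weights are normalized by their sum, so they need not add up to 1. A sketch, assuming `sc: SparkContext`:

{{{
val rdd = sc.parallelize(1 to 10000)

// Weights 8.0 and 2.0 are normalized to 0.8 and 0.2.
val Array(train, test) = rdd.randomSplit(Array(8.0, 2.0), seed = 42L)

// Every element lands in exactly one split, so the counts add up to the original count.
train.count() + test.count()  // == rdd.count()
}}}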
/**
* Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
* range.
@@ -472,6 +550,9 @@ abstract class RDD[T: ClassTag](
* @param num size of the returned sample
* @param seed seed for the random number generator
* @return sample of specified size in an array
+ *
+ * @note This method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
*/
def takeSample(
withReplacement: Boolean,
@@ -518,11 +599,7 @@ abstract class RDD[T: ClassTag](
* times (use `.distinct()` to eliminate them).
*/
def union(other: RDD[T]): RDD[T] = withScope {
- if (partitioner.isDefined && other.partitioner == partitioner) {
- new PartitionerAwareUnionRDD(sc, Array(this, other))
- } else {
- new UnionRDD(sc, Array(this, other))
- }
+ sc.union(this, other)
}
/**
@@ -550,7 +627,7 @@ abstract class RDD[T: ClassTag](
* Return the intersection of this RDD and another one. The output will not contain any duplicate
* elements, even if the input RDDs did.
*
- * Note that this method performs a shuffle internally.
+ * @note This method performs a shuffle internally.
*/
def intersection(other: RDD[T]): RDD[T] = withScope {
this.map(v => (v, null)).cogroup(other.map(v => (v, null)))
@@ -562,7 +639,7 @@ abstract class RDD[T: ClassTag](
* Return the intersection of this RDD and another one. The output will not contain any duplicate
* elements, even if the input RDDs did.
*
- * Note that this method performs a shuffle internally.
+ * @note This method performs a shuffle internally.
*
* @param partitioner Partitioner to use for the resulting RDD
*/
@@ -578,7 +655,7 @@ abstract class RDD[T: ClassTag](
* Return the intersection of this RDD and another one. The output will not contain any duplicate
* elements, even if the input RDDs did. Performs a hash partition across the cluster
*
- * Note that this method performs a shuffle internally.
+ * @note This method performs a shuffle internally.
*
* @param numPartitions How many partitions to use in the resulting RDD
*/
@@ -606,9 +683,9 @@ abstract class RDD[T: ClassTag](
* mapping to that key. The ordering of elements within each group is not guaranteed, and
* may even differ each time the resulting RDD is evaluated.
*
- * Note: This operation may be very expensive. If you are grouping in order to perform an
- * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
- * or [[PairRDDFunctions.reduceByKey]] will provide much better performance.
+ * @note This operation may be very expensive. If you are grouping in order to perform an
+ * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey`
+ * or `PairRDDFunctions.reduceByKey` will provide much better performance.
*/
def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = withScope {
groupBy[K](f, defaultPartitioner(this))
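To make the note above concrete, a per-key count via `reduceByKey` combines partial results map-side instead of materializing every value for a key, as `groupBy` would. A sketch, assuming `sc: SparkContext`:

{{{
val words = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"))

// Expensive: all values for a key are collected into an Iterable before being counted.
val viaGroupBy = words.groupBy(identity).mapValues(_.size)

// Cheaper: partial counts are merged before the shuffle.
val viaReduceByKey = words.map(w => (w, 1)).reduceByKey(_ + _)
}}}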
@@ -619,9 +696,9 @@ abstract class RDD[T: ClassTag](
* mapping to that key. The ordering of elements within each group is not guaranteed, and
* may even differ each time the resulting RDD is evaluated.
*
- * Note: This operation may be very expensive. If you are grouping in order to perform an
- * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
- * or [[PairRDDFunctions.reduceByKey]] will provide much better performance.
+ * @note This operation may be very expensive. If you are grouping in order to perform an
+ * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey`
+ * or `PairRDDFunctions.reduceByKey` will provide much better performance.
*/
def groupBy[K](
f: T => K,
@@ -634,9 +711,9 @@ abstract class RDD[T: ClassTag](
* mapping to that key. The ordering of elements within each group is not guaranteed, and
* may even differ each time the resulting RDD is evaluated.
*
- * Note: This operation may be very expensive. If you are grouping in order to perform an
- * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
- * or [[PairRDDFunctions.reduceByKey]] will provide much better performance.
+ * @note This operation may be very expensive. If you are grouping in order to perform an
+ * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey`
+ * or `PairRDDFunctions.reduceByKey` will provide much better performance.
*/
def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null)
: RDD[(K, Iterable[T])] = withScope {
@@ -648,18 +725,28 @@ abstract class RDD[T: ClassTag](
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: String): RDD[String] = withScope {
- new PipedRDD(this, command)
+ // Similar to Runtime.exec(), if we are given a single string, split it into words
+ // using a standard StringTokenizer (i.e. by spaces)
+ pipe(PipedRDD.tokenize(command))
}
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: String, env: Map[String, String]): RDD[String] = withScope {
- new PipedRDD(this, command, env)
+ // Similar to Runtime.exec(), if we are given a single string, split it into words
+ // using a standard StringTokenizer (i.e. by spaces)
+ pipe(PipedRDD.tokenize(command), env)
}
/**
- * Return an RDD created by piping elements to a forked external process.
+ * Return an RDD created by piping elements to a forked external process. The resulting RDD
+ * is computed by executing the given process once per partition. All elements
+ * of each input partition are written to a process's stdin as lines of input separated
+ * by a newline. The resulting partition consists of the process's stdout output, with
+ * each line of stdout resulting in one element of the output partition. A process is invoked
+ * even for empty partitions.
+ *
* The print behavior can be customized by providing two functions.
*
* @param command command to run in forked process.
@@ -672,9 +759,14 @@ abstract class RDD[T: ClassTag](
* print line function (like out.println()) as the 2nd parameter.
* An example of pipe the RDD data of groupBy() in a streaming way,
* instead of constructing a huge String to concat all the elements:
- * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
- * for (e <- record._2){f(e)}
+ * {{{
+ * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
+ * for (e <- record._2) {f(e)}
+ * }}}
* @param separateWorkingDir Use separate working directories for each task.
+ * @param bufferSize Buffer size for the stdin writer for the piped process.
+ * @param encoding Char encoding used for interacting (via stdin, stdout and stderr) with
+ * the piped process
* @return the result RDD
*/
def pipe(
@@ -682,11 +774,15 @@ abstract class RDD[T: ClassTag](
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
printRDDElement: (T, String => Unit) => Unit = null,
- separateWorkingDir: Boolean = false): RDD[String] = withScope {
+ separateWorkingDir: Boolean = false,
+ bufferSize: Int = 8192,
+ encoding: String = Codec.defaultCharsetCodec.name): RDD[String] = withScope {
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null,
- separateWorkingDir)
+ separateWorkingDir,
+ bufferSize,
+ encoding)
}
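A usage sketch of the extended signature, assuming `sc: SparkContext` and a `tr` binary available on the executors:

{{{
val lines = sc.parallelize(Seq("spark", "hadoop", "flink"), 2)

// One external process per partition; each element becomes one line on the process's stdin,
// and each line of the process's stdout becomes one element of the resulting partition.
val piped = lines.pipe(
  command = Seq("tr", "a-z", "A-Z"),
  env = Map("LC_ALL" -> "C"),
  separateWorkingDir = false,
  bufferSize = 8192,
  encoding = "UTF-8")

piped.collect()  // Array("SPARK", "HADOOP", "FLINK")
}}}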
/**
@@ -706,113 +802,55 @@ abstract class RDD[T: ClassTag](
}
/**
- * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
- * of the original partition.
+ * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning.
+ * It is a performance API to be used carefully only if we are sure that the RDD elements are
+ * serializable and don't require closure cleaning.
*
- * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
- * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
- */
- def mapPartitionsWithIndex[U: ClassTag](
+ * @param preservesPartitioning indicates whether the input function preserves the partitioner,
+ * which should be `false` unless this is a pair RDD and the input
+ * function doesn't modify the keys.
+ * @param isOrderSensitive whether or not the function is order-sensitive. If it's order
+ * sensitive, it may return totally different result when the input order
+ * is changed. Mostly stateful functions are order-sensitive.
+ */
+ private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] = withScope {
- val cleanedF = sc.clean(f)
+ preservesPartitioning: Boolean = false,
+ isOrderSensitive: Boolean = false): RDD[U] = withScope {
new MapPartitionsRDD(
this,
- (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
- preservesPartitioning)
+ (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
+ preservesPartitioning = preservesPartitioning,
+ isOrderSensitive = isOrderSensitive)
}
/**
- * :: DeveloperApi ::
- * Return a new RDD by applying a function to each partition of this RDD. This is a variant of
- * mapPartitions that also passes the TaskContext into the closure.
- *
- * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
- * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
+ * [performance] Spark's internal mapPartitions method that skips closure cleaning.
*/
- @DeveloperApi
- @deprecated("use TaskContext.get", "1.2.0")
- def mapPartitionsWithContext[U: ClassTag](
- f: (TaskContext, Iterator[T]) => Iterator[U],
+ private[spark] def mapPartitionsInternal[U: ClassTag](
+ f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = withScope {
- val cleanF = sc.clean(f)
- val func = (context: TaskContext, index: Int, iter: Iterator[T]) => cleanF(context, iter)
- new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
+ new MapPartitionsRDD(
+ this,
+ (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter),
+ preservesPartitioning)
}
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
- @deprecated("use mapPartitionsWithIndex", "0.7.0")
- def mapPartitionsWithSplit[U: ClassTag](
+ def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = withScope {
- mapPartitionsWithIndex(f, preservesPartitioning)
- }
-
- /**
- * Maps f over this RDD, where f takes an additional parameter of type A. This
- * additional parameter is produced by constructA, which is called in each
- * partition with the index of that partition.
- */
- @deprecated("use mapPartitionsWithIndex", "1.0.0")
- def mapWith[A, U: ClassTag]
- (constructA: Int => A, preservesPartitioning: Boolean = false)
- (f: (T, A) => U): RDD[U] = withScope {
- val cleanF = sc.clean(f)
- val cleanA = sc.clean(constructA)
- mapPartitionsWithIndex((index, iter) => {
- val a = cleanA(index)
- iter.map(t => cleanF(t, a))
- }, preservesPartitioning)
- }
-
- /**
- * FlatMaps f over this RDD, where f takes an additional parameter of type A. This
- * additional parameter is produced by constructA, which is called in each
- * partition with the index of that partition.
- */
- @deprecated("use mapPartitionsWithIndex and flatMap", "1.0.0")
- def flatMapWith[A, U: ClassTag]
- (constructA: Int => A, preservesPartitioning: Boolean = false)
- (f: (T, A) => Seq[U]): RDD[U] = withScope {
- val cleanF = sc.clean(f)
- val cleanA = sc.clean(constructA)
- mapPartitionsWithIndex((index, iter) => {
- val a = cleanA(index)
- iter.flatMap(t => cleanF(t, a))
- }, preservesPartitioning)
- }
-
- /**
- * Applies f to each element of this RDD, where f takes an additional parameter of type A.
- * This additional parameter is produced by constructA, which is called in each
- * partition with the index of that partition.
- */
- @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0")
- def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope {
- val cleanF = sc.clean(f)
- val cleanA = sc.clean(constructA)
- mapPartitionsWithIndex { (index, iter) =>
- val a = cleanA(index)
- iter.map(t => {cleanF(t, a); t})
- }
- }
-
- /**
- * Filters this RDD with p, where p takes an additional parameter of type A. This
- * additional parameter is produced by constructA, which is called in each
- * partition with the index of that partition.
- */
- @deprecated("use mapPartitionsWithIndex and filter", "1.0.0")
- def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = withScope {
- val cleanP = sc.clean(p)
- val cleanA = sc.clean(constructA)
- mapPartitionsWithIndex((index, iter) => {
- val a = cleanA(index)
- iter.filter(t => cleanP(t, a))
- }, preservesPartitioning = true)
+ val cleanedF = sc.clean(f)
+ new MapPartitionsRDD(
+ this,
+ (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
+ preservesPartitioning)
}
/**
@@ -898,6 +936,9 @@ abstract class RDD[T: ClassTag](
/**
* Return an array that contains all of the elements in this RDD.
+ *
+ * @note This method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
*/
def collect(): Array[T] = withScope {
val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
@@ -909,7 +950,7 @@ abstract class RDD[T: ClassTag](
*
* The iterator will consume as much memory as the largest partition in this RDD.
*
- * Note: this results in multiple Spark jobs, and if the input RDD is the result
+ * @note This results in multiple Spark jobs, and if the input RDD is the result
* of a wide transformation (e.g. join with different partitioners), to avoid
* recomputing the input RDD should be cached first.
*/
@@ -920,14 +961,6 @@ abstract class RDD[T: ClassTag](
(0 until partitions.length).iterator.flatMap(i => collectPartition(i))
}
- /**
- * Return an array that contains all of the elements in this RDD.
- */
- @deprecated("use collect", "1.0.0")
- def toArray(): Array[T] = withScope {
- collect()
- }
-
/**
* Return an RDD that contains all matching values by applying `f`.
*/
@@ -1037,7 +1070,7 @@ abstract class RDD[T: ClassTag](
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
- * given associative and commutative function and a neutral "zero value". The function
+ * given associative function and a neutral "zero value". The function
* op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object
* allocation; however, it should not modify t2.
*
@@ -1047,6 +1080,13 @@ abstract class RDD[T: ClassTag](
* apply the fold to each element sequentially in some defined ordering. For functions
* that are not commutative, the result may differ from that of a fold applied to a
* non-distributed collection.
+ *
+ * @param zeroValue the initial value for the accumulated result of each partition for the `op`
+ * operator, and also the initial value for the combine results from different
+ * partitions for the `op` operator - this will typically be the neutral
+ * element (e.g. `Nil` for list concatenation or `0` for summation)
+ * @param op an operator used to both accumulate results within a partition and combine results
+ * from different partitions
*/
def fold(zeroValue: T)(op: (T, T) => T): T = withScope {
// Clone the zero value since we will also be serializing it as part of tasks
@@ -1065,6 +1105,13 @@ abstract class RDD[T: ClassTag](
* and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
* allowed to modify and return their first argument instead of creating a new U to avoid memory
* allocation.
+ *
+ * @param zeroValue the initial value for the accumulated result of each partition for the
+ * `seqOp` operator, and also the initial value for the combine results from
+ * different partitions for the `combOp` operator - this will typically be the
+ * neutral element (e.g. `Nil` for list concatenation or `0` for summation)
+ * @param seqOp an operator used to accumulate results within a partition
+ * @param combOp an associative operator used to combine results from different partitions
*/
def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope {
// Clone the zero value since we will also be serializing it as part of tasks
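A worked example of `aggregate` computing a sum and a count in a single pass, assuming `sc: SparkContext`:

{{{
val rdd = sc.parallelize(1 to 100, 4)

// zeroValue is the neutral (sum, count) pair for both the per-partition and the merge step.
val (sum, count) = rdd.aggregate((0, 0))(
  seqOp = (acc, x) => (acc._1 + x, acc._2 + 1),    // accumulate elements within a partition
  combOp = (a, b) => (a._1 + b._1, a._2 + b._2))   // merge accumulators across partitions

val mean = sum.toDouble / count  // 50.5
}}}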
@@ -1121,10 +1168,21 @@ abstract class RDD[T: ClassTag](
/**
* Approximate version of count() that returns a potentially incomplete result
* within a timeout, even if not all tasks have finished.
+ *
+ * The confidence is the probability that the error bounds of the result will
+ * contain the true value. That is, if countApprox were called repeatedly
+ * with confidence 0.9, we would expect 90% of the results to contain the
+ * true count. The confidence must be in the range [0,1] or an exception will
+ * be thrown.
+ *
+ * @param timeout maximum time to wait for the job, in milliseconds
+ * @param confidence the desired statistical confidence in the result
+ * @return a potentially incomplete result, with error bounds
*/
def countApprox(
timeout: Long,
confidence: Double = 0.95): PartialResult[BoundedDouble] = withScope {
+ require(0.0 <= confidence && confidence <= 1.0, s"confidence ($confidence) must be in [0,1]")
val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) =>
var result = 0L
while (iter.hasNext) {
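A usage sketch, assuming `sc: SparkContext`; the call returns a `PartialResult` whose bounds tighten as more tasks finish:

{{{
val rdd = sc.parallelize(1 to 1000000, 100)

// Wait at most 1000 ms; with confidence 0.95 the reported bounds should contain
// the true count about 95% of the time.
val approx = rdd.countApprox(timeout = 1000L, confidence = 0.95)

approx.initialValue     // BoundedDouble with a mean and [low, high] error bounds
approx.getFinalValue()  // blocks until the exact count is known
}}}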
@@ -1140,10 +1198,15 @@ abstract class RDD[T: ClassTag](
/**
* Return the count of each unique value in this RDD as a local map of (value, count) pairs.
*
- * Note that this method should only be used if the resulting map is expected to be small, as
+ * @note This method should only be used if the resulting map is expected to be small, as
* the whole thing is loaded into the driver's memory.
- * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which
- * returns an RDD[T, Long] instead of a map.
+ * To handle very large results, consider using
+ *
+ * {{{
+ * rdd.map(x => (x, 1L)).reduceByKey(_ + _)
+ * }}}
+ *
+ * which returns an RDD[(T, Long)] instead of a map.
*/
def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = withScope {
map(value => (value, null)).countByKey()
@@ -1151,10 +1214,15 @@ abstract class RDD[T: ClassTag](
/**
* Approximate version of countByValue().
+ *
+ * @param timeout maximum time to wait for the job, in milliseconds
+ * @param confidence the desired statistical confidence in the result
+ * @return a potentially incomplete result, with error bounds
*/
def countByValueApprox(timeout: Long, confidence: Double = 0.95)
(implicit ord: Ordering[T] = null)
: PartialResult[Map[T, BoundedDouble]] = withScope {
+ require(0.0 <= confidence && confidence <= 1.0, s"confidence ($confidence) must be in [0,1]")
if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValueApprox() does not support arrays")
}
@@ -1176,9 +1244,9 @@ abstract class RDD[T: ClassTag](
* Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available
* here.
*
- * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p`
- * would trigger sparse representation of registers, which may reduce the memory consumption
- * and increase accuracy when the cardinality is small.
+ * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp` greater
+ * than `p` triggers a sparse representation of the registers, which may reduce memory
+ * consumption and increase accuracy when the cardinality is small.
*
* @param p The precision value for the normal set.
* `p` must be a value between 4 and `sp` if `sp` is not zero (32 max).
@@ -1225,7 +1293,7 @@ abstract class RDD[T: ClassTag](
* This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type.
* This method needs to trigger a spark job when this RDD contains more than one partitions.
*
- * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of
+ * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of
* elements in a partition. The index assigned to each element is therefore not guaranteed,
* and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee
* the same index assignments, you should sort the RDD with sortByKey() or save it to a file.
@@ -1239,7 +1307,7 @@ abstract class RDD[T: ClassTag](
* 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method
* won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]].
*
- * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of
+ * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of
* elements in a partition. The unique ID assigned to each element is therefore not guaranteed,
* and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee
* the same index assignments, you should sort the RDD with sortByKey() or save it to a file.
@@ -1247,7 +1315,7 @@ abstract class RDD[T: ClassTag](
def zipWithUniqueId(): RDD[(T, Long)] = withScope {
val n = this.partitions.length.toLong
this.mapPartitionsWithIndex { case (k, iter) =>
- iter.zipWithIndex.map { case (item, i) =>
+ Utils.getIteratorZipWithIndex(iter, 0L).map { case (item, i) =>
(item, i * n + k)
}
}
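The id formula can be checked by hand: with n partitions, the i-th element of partition k receives id i * n + k. A sketch, assuming `sc: SparkContext`:

{{{
// 3 partitions (n = 3): partition 0 gets ids 0, 3, ...; partition 1 gets 1, 4, ...;
// partition 2 gets 2, 5, ...
val rdd = sc.parallelize(Seq("a", "b", "c", "d", "e", "f"), 3)
rdd.zipWithUniqueId().collect()
// Array((a,0), (b,3), (c,1), (d,4), (e,2), (f,5))
}}}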
@@ -1258,10 +1326,14 @@ abstract class RDD[T: ClassTag](
* results from that partition to estimate the number of additional partitions needed to satisfy
* the limit.
*
- * @note due to complications in the internal implementation, this method will raise
+ * @note This method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
+ *
+ * @note Due to complications in the internal implementation, this method will raise
* an exception if called on an RDD of `Nothing` or `Null`.
*/
def take(num: Int): Array[T] = withScope {
+ val scaleUpFactor = Math.max(conf.getInt("spark.rdd.limit.scaleUpFactor", 4), 2)
if (num == 0) {
new Array[T](0)
} else {
@@ -1271,26 +1343,26 @@ abstract class RDD[T: ClassTag](
while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
// it by 50%. We also cap the estimation in the end.
- if (buf.size == 0) {
- numPartsToTry = partsScanned * 4
+ if (buf.isEmpty) {
+ numPartsToTry = partsScanned * scaleUpFactor
} else {
// the left side of max is >=1 whenever partsScanned >= 2
numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
- numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
+ numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
}
}
val left = num - buf.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)
res.foreach(buf ++= _.take(num - buf.size))
- partsScanned += numPartsToTry
+ partsScanned += p.size
}
buf.toArray
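The partition-count estimation in the loop above can be written as a standalone function (a sketch that mirrors, but is not, the Spark code), showing how `spark.rdd.limit.scaleUpFactor` caps the growth when earlier scans return too few rows:

{{{
// Given `bufSize` rows found after scanning `partsScanned` partitions, estimate how many
// partitions to scan in the next round when `num` rows are wanted in total.
def nextPartsToTry(num: Int, bufSize: Int, partsScanned: Long, scaleUpFactor: Int): Long = {
  if (partsScanned == 0) {
    1L
  } else if (bufSize == 0) {
    partsScanned * scaleUpFactor  // nothing found yet: scale up aggressively
  } else {
    // Interpolate, overestimate by 50%, and cap at partsScanned * scaleUpFactor.
    val interpolated = (1.5 * num * partsScanned / bufSize).toInt - partsScanned
    math.min(math.max(interpolated, 1L), partsScanned * scaleUpFactor)
  }
}

nextPartsToTry(num = 100, bufSize = 10, partsScanned = 1, scaleUpFactor = 4)  // 4 (capped)
}}}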
@@ -1309,7 +1381,8 @@ abstract class RDD[T: ClassTag](
/**
* Returns the top k (largest) elements from this RDD as defined by the specified
- * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example:
+ * implicit Ordering[T] and maintains the ordering. This does the opposite of
+ * [[takeOrdered]]. For example:
* {{{
* sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1)
* // returns Array(12)
@@ -1318,6 +1391,9 @@ abstract class RDD[T: ClassTag](
* // returns Array(6, 5)
* }}}
*
+ * @note This method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
+ *
* @param num k, the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
@@ -1338,6 +1414,9 @@ abstract class RDD[T: ClassTag](
* // returns Array(2, 3)
* }}}
*
+ * @note This method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
+ *
* @param num k, the number of elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
@@ -1349,7 +1428,7 @@ abstract class RDD[T: ClassTag](
val mapRDDs = mapPartitions { items =>
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
- queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
+ queue ++= collectionUtils.takeOrdered(items, num)(ord)
Iterator.single(queue)
}
if (mapRDDs.partitions.length == 0) {
@@ -1380,7 +1459,7 @@ abstract class RDD[T: ClassTag](
}
/**
- * @note due to complications in the internal implementation, this method will raise an
+ * @note Due to complications in the internal implementation, this method will raise an
* exception if called on an RDD of `Nothing` or `Null`. This may come up in practice
* because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`.
* (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.)
@@ -1540,14 +1619,15 @@ abstract class RDD[T: ClassTag](
/**
* Return whether this RDD is checkpointed and materialized, either reliably or locally.
*/
- def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
+ def isCheckpointed: Boolean = isCheckpointedAndMaterialized
/**
* Return whether this RDD is checkpointed and materialized, either reliably or locally.
* This is introduced as an alias for `isCheckpointed` to clarify the semantics of the
* return value. Exposed for testing.
*/
- private[spark] def isCheckpointedAndMaterialized: Boolean = isCheckpointed
+ private[spark] def isCheckpointedAndMaterialized: Boolean =
+ checkpointData.exists(_.isCheckpointed)
/**
* Return whether this RDD is marked for local checkpointing.
@@ -1560,6 +1640,16 @@ abstract class RDD[T: ClassTag](
}
}
+ /**
+ * Return whether this RDD is reliably checkpointed and materialized.
+ */
+ private[rdd] def isReliablyCheckpointed: Boolean = {
+ checkpointData match {
+ case Some(reliable: ReliableRDDCheckpointData[_]) if reliable.isCheckpointed => true
+ case _ => false
+ }
+ }
+
/**
* Gets the name of the directory to which this RDD was checkpointed.
* This is not defined if the RDD is checkpointed locally.
@@ -1597,6 +1687,15 @@ abstract class RDD[T: ClassTag](
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
+ // Whether to checkpoint all ancestor RDDs that are marked for checkpointing. By default,
+ // we stop as soon as we find the first such RDD, an optimization that allows us to write
+ // less data but is not safe for all workloads. E.g. in streaming we may checkpoint both
+ // an RDD and its parent in every batch, in which case the parent may never be checkpointed
+ // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847).
+ private val checkpointAllMarkedAncestors =
+ Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS))
+ .map(_.toBoolean).getOrElse(false)
+
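The flag is read from a job-local property, so it can be enabled per job. A sketch using the property name defined in the companion object below, assuming `sc: SparkContext` and an example checkpoint directory:

{{{
// Checkpoint every ancestor RDD that was marked for checkpointing, not just the nearest
// one; this keeps long streaming lineages from growing without bound (SPARK-6847).
sc.setLocalProperty("spark.checkpoint.checkpointAllMarkedAncestors", "true")

sc.setCheckpointDir("/tmp/checkpoints")  // example path
val parent = sc.parallelize(1 to 1000).map(_ + 1)
parent.checkpoint()
val child = parent.map(_ * 2)
child.checkpoint()
child.count()  // with the property set, both parent and child are checkpointed
}}}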
/** Returns the first parent RDD */
protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
dependencies.head.rdd.asInstanceOf[RDD[U]]
@@ -1640,6 +1739,13 @@ abstract class RDD[T: ClassTag](
if (!doCheckpointCalled) {
doCheckpointCalled = true
if (checkpointData.isDefined) {
+ if (checkpointAllMarkedAncestors) {
+ // TODO We can collect all the RDDs that needs to be checkpointed, and then checkpoint
+ // them in parallel.
+ // Checkpoint parents first because our lineage will be truncated after we
+ // checkpoint ourselves
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
checkpointData.get.checkpoint()
} else {
dependencies.foreach(_.rdd.doCheckpoint())
@@ -1660,7 +1766,7 @@ abstract class RDD[T: ClassTag](
/**
* Clears the dependencies of this RDD. This method must ensure that all references
- * to the original parent RDDs is removed to enable the parent RDDs to be garbage
+ * to the original parent RDDs are removed to enable the parent RDDs to be garbage
* collected. Subclasses of RDD may override this method for implementing their own cleaning
* logic. See [[org.apache.spark.rdd.UnionRDD]] for an example.
*/
@@ -1748,6 +1854,63 @@ abstract class RDD[T: ClassTag](
def toJavaRDD() : JavaRDD[T] = {
new JavaRDD(this)(elementClassTag)
}
+
+ /**
+ * Returns the deterministic level of this RDD's output. Please refer to [[DeterministicLevel]]
+ * for the definition.
+ *
+ * By default, a reliably checkpointed RDD, or an RDD without parents (a root RDD), is
+ * DETERMINATE. For RDDs with parents, we generate a deterministic level candidate per parent
+ * according to the dependency. The deterministic level of the current RDD is the least
+ * deterministic of these candidates. Please override [[getOutputDeterministicLevel]] to
+ * provide custom logic for calculating the output deterministic level.
+ */
+ // TODO: make it public so users can set deterministic level to their custom RDDs.
+ // TODO: this can be per-partition. e.g. UnionRDD can have different deterministic level for
+ // different partitions.
+ private[spark] final lazy val outputDeterministicLevel: DeterministicLevel.Value = {
+ if (isReliablyCheckpointed) {
+ DeterministicLevel.DETERMINATE
+ } else {
+ getOutputDeterministicLevel
+ }
+ }
+
+ @DeveloperApi
+ protected def getOutputDeterministicLevel: DeterministicLevel.Value = {
+ val deterministicLevelCandidates = dependencies.map {
+ // The shuffle is not really happening, so treat it like a narrow dependency and assume the
+ // output deterministic level of the current RDD is the same as the parent's.
+ case dep: ShuffleDependency[_, _, _] if dep.rdd.partitioner.exists(_ == dep.partitioner) =>
+ dep.rdd.outputDeterministicLevel
+
+ case dep: ShuffleDependency[_, _, _] =>
+ if (dep.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) {
+ // If map output was indeterminate, shuffle output will be indeterminate as well
+ DeterministicLevel.INDETERMINATE
+ } else if (dep.keyOrdering.isDefined && dep.aggregator.isDefined) {
+ // if aggregator specified (and so unique keys) and key ordering specified - then
+ // consistent ordering.
+ DeterministicLevel.DETERMINATE
+ } else {
+ // In Spark, the reducer fetches multiple remote shuffle blocks at the same time, and
+ // the arrival order of these shuffle blocks are totally random. Even if the parent map
+ // RDD is DETERMINATE, the reduce RDD is always UNORDERED.
+ DeterministicLevel.UNORDERED
+ }
+
+ // For a narrow dependency, assume the output deterministic level of the current RDD is the
+ // same as the parent's.
+ case dep => dep.rdd.outputDeterministicLevel
+ }
+
+ if (deterministicLevelCandidates.isEmpty) {
+ // By default we assume the root RDD is determinate.
+ DeterministicLevel.DETERMINATE
+ } else {
+ deterministicLevelCandidates.maxBy(_.id)
+ }
+ }
}
@@ -1755,10 +1918,13 @@ abstract class RDD[T: ClassTag](
* Defines implicit functions that provide extra functionalities on RDDs of specific types.
*
* For example, [[RDD.rddToPairRDDFunctions]] converts an RDD into a [[PairRDDFunctions]] for
- * key-value-pair RDDs, and enabling extra functionalities such as [[PairRDDFunctions.reduceByKey]].
+ * key-value-pair RDDs, and enabling extra functionalities such as `PairRDDFunctions.reduceByKey`.
*/
object RDD {
+ private[spark] val CHECKPOINT_ALL_MARKED_ANCESTORS =
+ "spark.checkpoint.checkpointAllMarkedAncestors"
+
// The following implicit functions were in SparkContext before 1.3 and users had to
// `import SparkContext._` to enable them. Now we move them here to make the compiler find
// them automatically. However, we still keep the old functions in SparkContext for backward
@@ -1798,3 +1964,18 @@ object RDD {
new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
}
}
+
+/**
+ * The deterministic level of RDD's output (i.e. what `RDD#compute` returns). This explains how
+ * the output will diff when Spark reruns the tasks for the RDD. There are 3 deterministic levels:
+ * 1. DETERMINATE: The RDD output is always the same data set in the same order after a rerun.
+ * 2. UNORDERED: The RDD output is always the same data set but the order can be different
+ * after a rerun.
+ * 3. INDETERMINATE: The RDD output can be different after a rerun.
+ *
+ * Note that the output of an RDD usually depends on its parent RDDs. When a parent RDD's output
+ * is INDETERMINATE, it is very likely that this RDD's output is also INDETERMINATE.
+ */
+private[spark] object DeterministicLevel extends Enumeration {
+ val DETERMINATE, UNORDERED, INDETERMINATE = Value
+}
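Because `DeterministicLevel` is a plain `Enumeration`, declaration order gives DETERMINATE the smallest id and INDETERMINATE the largest, which is what lets `getOutputDeterministicLevel` pick the weakest guarantee with `maxBy(_.id)`. A small standalone illustration (a sketch, not Spark code):

{{{
object Level extends Enumeration {
  val DETERMINATE, UNORDERED, INDETERMINATE = Value  // ids 0, 1, 2
}

val candidates = Seq(Level.DETERMINATE, Level.UNORDERED)
candidates.maxBy(_.id)  // UNORDERED: the least deterministic candidate wins
}}}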
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 429514b4f6be..6c552d4d1251 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -23,7 +23,8 @@ import org.apache.spark.Partition
/**
* Enumeration to manage state transitions of an RDD through checkpointing
- * [ Initialized --> checkpointing in progress --> checkpointed ].
+ *
+ * [ Initialized --{@literal >} checkpointing in progress --{@literal >} checkpointed ]
*/
private[spark] object CheckpointState extends Enumeration {
type CheckpointState = Value
@@ -32,7 +33,7 @@ private[spark] object CheckpointState extends Enumeration {
/**
* This class contains all the information related to RDD checkpointing. Each instance of this
- * class is associated with a RDD. It manages process of checkpointing of the associated RDD,
+ * class is associated with an RDD. It manages the process of checkpointing of the associated RDD,
* as well as, manages the post-checkpoint state by providing the updated partitions,
* iterator and preferred locations of the checkpointed RDD.
*/
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala
index 540cbd688b63..53d69ba26811 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala
@@ -25,7 +25,8 @@ import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import com.google.common.base.Objects
-import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.SparkContext
+import org.apache.spark.internal.Logging
/**
* A general, named code block representing an operation that instantiates RDDs.
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
index a69be6a068bb..37c67cee55f9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -17,15 +17,19 @@
package org.apache.spark.rdd
-import java.io.IOException
+import java.io.{FileNotFoundException, IOException}
+import java.util.concurrent.TimeUnit
import scala.reflect.ClassTag
+import scala.util.control.NonFatal
import org.apache.hadoop.fs.Path
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CHECKPOINT_COMPRESS
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.{SerializableConfiguration, Utils}
/**
@@ -33,8 +37,9 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
*/
private[spark] class ReliableCheckpointRDD[T: ClassTag](
sc: SparkContext,
- val checkpointPath: String)
- extends CheckpointRDD[T](sc) {
+ val checkpointPath: String,
+ _partitioner: Option[Partitioner] = None
+ ) extends CheckpointRDD[T](sc) {
@transient private val hadoopConf = sc.hadoopConfiguration
@transient private val cpath = new Path(checkpointPath)
@@ -47,7 +52,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
/**
* Return the path of the checkpoint directory this RDD reads data from.
*/
- override def getCheckpointFile: Option[String] = Some(checkpointPath)
+ override val getCheckpointFile: Option[String] = Some(checkpointPath)
+
+ override val partitioner: Option[Partitioner] = {
+ _partitioner.orElse {
+ ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath)
+ }
+ }
/**
* Return partitions described by the files in the checkpoint directory.
@@ -61,10 +72,10 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
val inputFiles = fs.listStatus(cpath)
.map(_.getPath)
.filter(_.getName.startsWith("part-"))
- .sortBy(_.toString)
+ .sortBy(_.getName.stripPrefix("part-").toInt)
// Fail fast if input files are invalid
inputFiles.zipWithIndex.foreach { case (path, i) =>
- if (!path.toString.endsWith(ReliableCheckpointRDD.checkpointFileName(i))) {
+ if (path.getName != ReliableCheckpointRDD.checkpointFileName(i)) {
throw new SparkException(s"Invalid checkpoint file: $path")
}
}
@@ -100,10 +111,57 @@ private[spark] object ReliableCheckpointRDD extends Logging {
"part-%05d".format(partitionIndex)
}
+ private def checkpointPartitionerFileName(): String = {
+ "_partitioner"
+ }
+
+ /**
+ * Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD.
+ */
+ def writeRDDToCheckpointDirectory[T: ClassTag](
+ originalRDD: RDD[T],
+ checkpointDir: String,
+ blockSize: Int = -1): ReliableCheckpointRDD[T] = {
+ val checkpointStartTimeNs = System.nanoTime()
+
+ val sc = originalRDD.sparkContext
+
+ // Create the output path for the checkpoint
+ val checkpointDirPath = new Path(checkpointDir)
+ val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration)
+ if (!fs.mkdirs(checkpointDirPath)) {
+ throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath")
+ }
+
+ // Save to file, and reload it as an RDD
+ val broadcastedConf = sc.broadcast(
+ new SerializableConfiguration(sc.hadoopConfiguration))
+ // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
+ sc.runJob(originalRDD,
+ writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)
+
+ if (originalRDD.partitioner.nonEmpty) {
+ writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
+ }
+
+ val checkpointDurationMs =
+ TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs)
+ logInfo(s"Checkpointing took $checkpointDurationMs ms.")
+
+ val newRDD = new ReliableCheckpointRDD[T](
+ sc, checkpointDirPath.toString, originalRDD.partitioner)
+ if (newRDD.partitions.length != originalRDD.partitions.length) {
+ throw new SparkException(
+ s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
+ s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})")
+ }
+ newRDD
+ }
+
/**
- * Write this partition's values to a checkpoint file.
+ * Write an RDD partition's data to a checkpoint file.
*/
- def writeCheckpointFile[T: ClassTag](
+ def writePartitionToCheckpointFile[T: ClassTag](
path: String,
broadcastedConf: Broadcast[SerializableConfiguration],
blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
@@ -116,16 +174,19 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val tempOutputPath =
new Path(outputDir, s".$finalOutputName-attempt-${ctx.attemptNumber()}")
- if (fs.exists(tempOutputPath)) {
- throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists")
- }
val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
val fileOutputStream = if (blockSize < 0) {
- fs.create(tempOutputPath, false, bufferSize)
+ val fileStream = fs.create(tempOutputPath, false, bufferSize)
+ if (env.conf.get(CHECKPOINT_COMPRESS)) {
+ CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream)
+ } else {
+ fileStream
+ }
} else {
// This is mainly for testing purpose
- fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+ fs.create(tempOutputPath, false, bufferSize,
+ fs.getDefaultReplication(fs.getWorkingDirectory), blockSize)
}
val serializer = env.serializer.newInstance()
val serializeStream = serializer.serializeStream(fileOutputStream)
@@ -151,6 +212,70 @@ private[spark] object ReliableCheckpointRDD extends Logging {
}
}
+ /**
+ * Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort
+ * basis; any exception while writing the partitioner is caught, logged and ignored.
+ */
+ private def writePartitionerToCheckpointDir(
+ sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = {
+ try {
+ val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
+ val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
+ val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
+ val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize)
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val serializeStream = serializer.serializeStream(fileOutputStream)
+ Utils.tryWithSafeFinally {
+ serializeStream.writeObject(partitioner)
+ } {
+ serializeStream.close()
+ }
+ logDebug(s"Written partitioner to $partitionerFilePath")
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath")
+ }
+ }
+
+
+ /**
+ * Read a partitioner from the given RDD checkpoint directory, if it exists.
+ * This is done on a best-effort basis; any exception while reading the partitioner is
+ * caught, logged and ignored.
+ */
+ private def readCheckpointedPartitionerFile(
+ sc: SparkContext,
+ checkpointDirPath: String): Option[Partitioner] = {
+ try {
+ val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
+ val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
+ val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
+ val fileInputStream = fs.open(partitionerFilePath, bufferSize)
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val partitioner = Utils.tryWithSafeFinally {
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+ Utils.tryWithSafeFinally {
+ deserializeStream.readObject[Partitioner]
+ } {
+ deserializeStream.close()
+ }
+ } {
+ fileInputStream.close()
+ }
+
+ logDebug(s"Read partitioner from $partitionerFilePath")
+ Some(partitioner)
+ } catch {
+ case e: FileNotFoundException =>
+ logDebug("No partitioner file", e)
+ None
+ case NonFatal(e) =>
+ logWarning(s"Error reading partitioner from $checkpointDirPath, " +
+ s"partitioner will not be recovered which may lead to performance loss", e)
+ None
+ }
+ }
+
/**
* Read the content of the specified checkpoint file.
*/
@@ -161,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val env = SparkEnv.get
val fs = path.getFileSystem(broadcastedConf.value.value)
val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
- val fileInputStream = fs.open(path, bufferSize)
+ val fileInputStream = {
+ val fileStream = fs.open(path, bufferSize)
+ if (env.conf.get(CHECKPOINT_COMPRESS)) {
+ CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream)
+ } else {
+ fileStream
+ }
+ }
val serializer = env.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
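
The hunks above add two user-visible behaviors: checkpoint streams may be compressed (gated by CHECKPOINT_COMPRESS) and the RDD's partitioner is written next to the part files so it can be recovered. A minimal, illustrative sketch (not part of the patch) exercising both; it assumes CHECKPOINT_COMPRESS is backed by the key "spark.checkpoint.compress" and uses a local master purely for demonstration:

import org.apache.spark.{SparkConf, SparkContext}

object CheckpointCompressExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("checkpoint-compress-example")
      .setMaster("local[2]")
      // Assumed key behind CHECKPOINT_COMPRESS; compresses checkpoint streams.
      .set("spark.checkpoint.compress", "true")
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/spark-checkpoints")

    // A shuffled pair RDD carries a partitioner; the patch also writes a
    // "_partitioner" file so the partitioner survives checkpoint recovery.
    val pairs = sc.parallelize(1 to 1000).map(i => (i % 10, i)).reduceByKey(_ + _, 10)
    pairs.checkpoint()
    pairs.count()                         // materializes the checkpoint files

    require(pairs.partitioner.isDefined)  // still known after lineage truncation
    sc.stop()
  }
}
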
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
index 91cad6662e4d..b6d723c68279 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
import org.apache.hadoop.fs.Path
import org.apache.spark._
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.internal.Logging
/**
* An implementation of checkpointing that writes the RDD data to reliable storage.
@@ -55,25 +55,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
* This is called immediately after the first action invoked on this RDD has completed.
*/
protected override def doCheckpoint(): CheckpointRDD[T] = {
-
- // Create the output path for the checkpoint
- val path = new Path(cpDir)
- val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
- if (!fs.mkdirs(path)) {
- throw new SparkException(s"Failed to create checkpoint path $cpDir")
- }
-
- // Save to file, and reload it as an RDD
- val broadcastedConf = rdd.context.broadcast(
- new SerializableConfiguration(rdd.context.hadoopConfiguration))
- // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
- rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _)
- val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir)
- if (newRDD.partitions.length != rdd.partitions.length) {
- throw new SparkException(
- s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
- s"number of partitions from original RDD $rdd(${rdd.partitions.length})")
- }
+ val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)
// Optionally clean our checkpoint files if the reference is out of scope
if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
@@ -83,7 +65,6 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
}
logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}")
-
newRDD
}
@@ -99,12 +80,7 @@ private[spark] object ReliableRDDCheckpointData extends Logging {
/** Clean up the files associated with the checkpoint data for this RDD. */
def cleanCheckpoint(sc: SparkContext, rddId: Int): Unit = {
checkpointPath(sc, rddId).foreach { path =>
- val fs = path.getFileSystem(sc.hadoopConfiguration)
- if (fs.exists(path)) {
- if (!fs.delete(path, true)) {
- logWarning(s"Error deleting ${path.toString()}")
- }
- }
+ path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
}
}
}
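
doCheckpoint now delegates the write to ReliableCheckpointRDD.writeRDDToCheckpointDirectory, and cleanCheckpoint relies on FileSystem.delete being a no-op for missing paths instead of checking exists first. The cleanup path it feeds is opt-in via spark.cleaner.referenceTracking.cleanCheckpoints; an illustrative sketch (not part of the patch) of enabling it:

import org.apache.spark.{SparkConf, SparkContext}

object CheckpointCleanupExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("checkpoint-cleanup-example")
      .setMaster("local[2]")
      // Opt in to deleting checkpoint files once the RDD is garbage collected.
      .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/spark-checkpoints")

    var data = sc.parallelize(1 to 100).map(_ * 2)
    data.checkpoint()
    data.count()   // writes the checkpoint files
    data = null    // once unreferenced, the ContextCleaner may clean the files up
    sc.stop()
  }
}
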
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
deleted file mode 100644
index 9e8cee5331cf..000000000000
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import java.util.Random
-
-import scala.reflect.ClassTag
-
-import org.apache.commons.math3.distribution.PoissonDistribution
-
-import org.apache.spark.{Partition, TaskContext}
-
-@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0")
-private[spark]
-class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
- override val index: Int = prev.index
-}
-
-@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0")
-private[spark] class SampledRDD[T: ClassTag](
- prev: RDD[T],
- withReplacement: Boolean,
- frac: Double,
- seed: Int)
- extends RDD[T](prev) {
-
- override def getPartitions: Array[Partition] = {
- val rg = new Random(seed)
- firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt))
- }
-
- override def getPreferredLocations(split: Partition): Seq[String] =
- firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev)
-
- override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
- val split = splitIn.asInstanceOf[SampledRDDPartition]
- if (withReplacement) {
- // For large datasets, the expected number of occurrences of each element in a sample with
- // replacement is Poisson(frac). We use that to get a count for each element.
- val poisson = new PoissonDistribution(frac)
- poisson.reseedRandomGenerator(split.seed)
-
- firstParent[T].iterator(split.prev, context).flatMap { element =>
- val count = poisson.sample()
- if (count == 0) {
- Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
- } else {
- Iterator.fill(count)(element)
- }
- }
- } else { // Sampling without replacement
- val rand = new Random(split.seed)
- firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
- }
- }
-}
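
SampledRDD had been deprecated since 1.0.0 in favor of PartitionwiseSampledRDD, normally reached via RDD.sample. An equivalent spark-shell sketch (illustrative only, reusing the shell's sc):

val data = sc.parallelize(1 to 100000)

// withReplacement = true uses Poisson sampling internally, mirroring the
// deleted compute(); withReplacement = false uses Bernoulli sampling.
val poissonSample   = data.sample(withReplacement = true,  fraction = 0.1, seed = 42L)
val bernoulliSample = data.sample(withReplacement = false, fraction = 0.1, seed = 42L)

println(s"${poissonSample.count()} / ${bernoulliSample.count()} elements sampled")
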
diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
index 4b5f15dd06b8..86a332790fb0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
@@ -16,20 +16,21 @@
*/
package org.apache.spark.rdd
-import scala.reflect.{ClassTag, classTag}
+import scala.reflect.{classTag, ClassTag}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.SequenceFileOutputFormat
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
/**
* Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile,
- * through an implicit conversion. Note that this can't be part of PairRDDFunctions because
- * we need more implicit parameters to convert our keys and values to Writable.
+ * through an implicit conversion.
*
+ * @note This can't be part of PairRDDFunctions because we need more implicit parameters to
+ * convert our keys and values to Writable.
*/
class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag](
self: RDD[(K, V)],
@@ -38,11 +39,6 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
extends Logging
with Serializable {
- @deprecated("It's used to provide backward compatibility for pre 1.3.0.", "1.3.0")
- def this(self: RDD[(K, V)]) {
- this(self, null, null)
- }
-
private val keyWritableClass =
if (_keyWritableClass == null) {
// pre 1.3.0, we need to use Reflection to get the Writable class
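
SequenceFileRDDFunctions is reached through an implicit conversion on pair RDDs whose key and value types have Writable converters; only the pre-1.3.0 compatibility constructor is dropped here. A spark-shell sketch of the conversion in use (illustrative, reusing the shell's sc):

// Int and String map to IntWritable and Text via implicit converters,
// so no explicit Writable classes are needed.
val pairs = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c")))
pairs.saveAsSequenceFile("/tmp/example-seqfile")

// Reading it back goes through SparkContext.sequenceFile.
val restored = sc.sequenceFile[Int, String]("/tmp/example-seqfile")
println(restored.collect().mkString(", "))
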
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index a013c3f66a3a..26eaa9aa3d03 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -25,7 +25,6 @@ import org.apache.spark.serializer.Serializer
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index: Int = idx
- override def hashCode(): Int = idx
}
/**
@@ -44,7 +43,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
part: Partitioner)
extends RDD[(K, C)](prev.context, Nil) {
- private var serializer: Option[Serializer] = None
+ private var userSpecifiedSerializer: Option[Serializer] = None
private var keyOrdering: Option[Ordering[K]] = None
@@ -54,7 +53,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = {
- this.serializer = Option(serializer)
+ this.userSpecifiedSerializer = Option(serializer)
this
}
@@ -77,6 +76,14 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
}
override def getDependencies: Seq[Dependency[_]] = {
+ val serializer = userSpecifiedSerializer.getOrElse {
+ val serializerManager = SparkEnv.get.serializerManager
+ if (mapSideCombine) {
+ serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[C]])
+ } else {
+ serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[V]])
+ }
+ }
List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
}
@@ -86,7 +93,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
}
- override def getPreferredLocations(partition: Partition): Seq[String] = {
+ override protected def getPreferredLocations(partition: Partition): Seq[String] = {
val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
tracker.getPreferredLocationsForShuffle(dep, partition.index)
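
The shuffle serializer is now resolved inside getDependencies: an explicit setSerializer call still wins, otherwise the SerializerManager chooses one from the key and value (or combiner) ClassTags. A hedged spark-shell sketch of both paths (ShuffledRDD is a DeveloperApi, so constructing it directly is unusual outside Spark itself):

import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.KryoSerializer

val pairs = sc.parallelize(1 to 1000).map(i => (i % 8, i.toLong))

// Default path: no serializer set, so SerializerManager.getSerializer picks
// one based on the (Int, Long) ClassTags when the dependency is created.
val shuffled = new ShuffledRDD[Int, Long, Long](pairs, new HashPartitioner(8))

// Explicit path: a user-specified serializer always takes precedence.
val kryoShuffled = new ShuffledRDD[Int, Long, Long](pairs, new HashPartitioner(8))
  .setSerializer(new KryoSerializer(sc.getConf))

println(shuffled.count() == kryoShuffled.count())
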
diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
deleted file mode 100644
index 264dae7f3908..000000000000
--- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
+++ /dev/null
@@ -1,290 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import java.text.SimpleDateFormat
-import java.util.Date
-
-import scala.reflect.ClassTag
-
-import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit}
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.executor.DataReadMethod
-import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.{Partition => SparkPartition, _}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils}
-
-
-private[spark] class SqlNewHadoopPartition(
- rddId: Int,
- val index: Int,
- rawSplit: InputSplit with Writable)
- extends SparkPartition {
-
- val serializableHadoopSplit = new SerializableWritable(rawSplit)
-
- override def hashCode(): Int = 41 * (41 + rddId) + index
-}
-
-/**
- * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
- * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`).
- * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions.
- * 1. A shared broadcast Hadoop Configuration.
- * 2. An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side
- * to the shared Hadoop Configuration.
- * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side
- * and the executor side to the shared Hadoop Configuration.
- *
- * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with
- * changes based on [[org.apache.spark.rdd.HadoopRDD]].
- */
-private[spark] class SqlNewHadoopRDD[V: ClassTag](
- sc : SparkContext,
- broadcastedConf: Broadcast[SerializableConfiguration],
- @transient private val initDriverSideJobFuncOpt: Option[Job => Unit],
- initLocalJobFuncOpt: Option[Job => Unit],
- inputFormatClass: Class[_ <: InputFormat[Void, V]],
- valueClass: Class[V])
- extends RDD[V](sc, Nil)
- with SparkHadoopMapReduceUtil
- with Logging {
-
- protected def getJob(): Job = {
- val conf: Configuration = broadcastedConf.value.value
- // "new Job" will make a copy of the conf. Then, it is
- // safe to mutate conf properties with initLocalJobFuncOpt
- // and initDriverSideJobFuncOpt.
- val newJob = new Job(conf)
- initLocalJobFuncOpt.map(f => f(newJob))
- newJob
- }
-
- def getConf(isDriverSide: Boolean): Configuration = {
- val job = getJob()
- if (isDriverSide) {
- initDriverSideJobFuncOpt.map(f => f(job))
- }
- SparkHadoopUtil.get.getConfigurationFromJobContext(job)
- }
-
- private val jobTrackerId: String = {
- val formatter = new SimpleDateFormat("yyyyMMddHHmm")
- formatter.format(new Date())
- }
-
- @transient protected val jobId = new JobID(jobTrackerId, id)
-
- override def getPartitions: Array[SparkPartition] = {
- val conf = getConf(isDriverSide = true)
- val inputFormat = inputFormatClass.newInstance
- inputFormat match {
- case configurable: Configurable =>
- configurable.setConf(conf)
- case _ =>
- }
- val jobContext = newJobContext(conf, jobId)
- val rawSplits = inputFormat.getSplits(jobContext).toArray
- val result = new Array[SparkPartition](rawSplits.size)
- for (i <- 0 until rawSplits.size) {
- result(i) =
- new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
- }
- result
- }
-
- override def compute(
- theSplit: SparkPartition,
- context: TaskContext): Iterator[V] = {
- val iter = new Iterator[V] {
- val split = theSplit.asInstanceOf[SqlNewHadoopPartition]
- logInfo("Input split: " + split.serializableHadoopSplit)
- val conf = getConf(isDriverSide = false)
-
- val inputMetrics = context.taskMetrics
- .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
-
- // Sets the thread local variable for the file's name
- split.serializableHadoopSplit.value match {
- case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
- case _ => SqlNewHadoopRDD.unsetInputFileName()
- }
-
- // Find a function that will return the FileSystem bytes read by this thread. Do this before
- // creating RecordReader, because RecordReader's constructor might read some bytes
- val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
- split.serializableHadoopSplit.value match {
- case _: FileSplit | _: CombineFileSplit =>
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
- case _ => None
- }
- }
- inputMetrics.setBytesReadCallback(bytesReadCallback)
-
- val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
- val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
- val format = inputFormatClass.newInstance
- format match {
- case configurable: Configurable =>
- configurable.setConf(conf)
- case _ =>
- }
- private[this] var reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
-
- // Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener(context => close())
-
- private[this] var havePair = false
- private[this] var finished = false
-
- override def hasNext: Boolean = {
- if (context.isInterrupted) {
- throw new TaskKilledException
- }
- if (!finished && !havePair) {
- finished = !reader.nextKeyValue
- if (finished) {
- // Close and release the reader here; close() will also be called when the task
- // completes, but for tasks that read from many files, it helps to release the
- // resources early.
- close()
- }
- havePair = !finished
- }
- !finished
- }
-
- override def next(): V = {
- if (!hasNext) {
- throw new java.util.NoSuchElementException("End of stream")
- }
- havePair = false
- if (!finished) {
- inputMetrics.incRecordsRead(1)
- }
- reader.getCurrentValue
- }
-
- private def close() {
- if (reader != null) {
- SqlNewHadoopRDD.unsetInputFileName()
- // Close the reader and release it. Note: it's very important that we don't close the
- // reader more than once, since that exposes us to MAPREDUCE-5918 when running against
- // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
- // corruption issues when reading compressed input.
- try {
- reader.close()
- } catch {
- case e: Exception =>
- if (!ShutdownHookManager.inShutdown()) {
- logWarning("Exception in RecordReader.close()", e)
- }
- } finally {
- reader = null
- }
- if (bytesReadCallback.isDefined) {
- inputMetrics.updateBytesRead()
- } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
- split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
- // If we can't get the bytes read from the FS stats, fall back to the split size,
- // which may be inaccurate.
- try {
- inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength)
- } catch {
- case e: java.io.IOException =>
- logWarning("Unable to get input size to set InputMetrics for task", e)
- }
- }
- }
- }
- }
- iter
- }
-
- override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = {
- val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value
- val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
- case Some(c) =>
- try {
- val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]]
- Some(HadoopRDD.convertSplitLocationInfo(infos))
- } catch {
- case e : Exception =>
- logDebug("Failed to use InputSplit#getLocationInfo.", e)
- None
- }
- case None => None
- }
- locs.getOrElse(split.getLocations.filter(_ != "localhost"))
- }
-
- override def persist(storageLevel: StorageLevel): this.type = {
- if (storageLevel.deserialized) {
- logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" +
- " behavior because Hadoop's RecordReader reuses the same Writable object for all records." +
- " Use a map transformation to make copies of the records.")
- }
- super.persist(storageLevel)
- }
-}
-
-private[spark] object SqlNewHadoopRDD {
-
- /**
- * The thread variable for the name of the current file being read. This is used by
- * the InputFileName function in Spark SQL.
- */
- private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
- override protected def initialValue(): UTF8String = UTF8String.fromString("")
- }
-
- def getInputFileName(): UTF8String = inputFileName.get()
-
- private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file))
-
- private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
-
- /**
- * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to
- * the given function rather than the index of the partition.
- */
- private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag](
- prev: RDD[T],
- f: (InputSplit, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false)
- extends RDD[U](prev) {
-
- override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
-
- override def getPartitions: Array[SparkPartition] = firstParent[T].partitions
-
- override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = {
- val partition = split.asInstanceOf[SqlNewHadoopPartition]
- val inputSplit = partition.serializableHadoopSplit.value
- f(inputSplit, firstParent[T].iterator(split, context))
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 25ec685eff5a..a733eaa5d7e5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -30,7 +30,6 @@ import org.apache.spark.Partitioner
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
-import org.apache.spark.serializer.Serializer
/**
* An optimized version of cogroup for set difference/subtraction.
@@ -54,13 +53,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
- private var serializer: Option[Serializer] = None
-
- /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
- def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
- this.serializer = Option(serializer)
- this
- }
override def getDependencies: Seq[Dependency[_]] = {
def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]])
@@ -70,7 +62,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency[T1, T2, Any](rdd, part, serializer)
+ new ShuffleDependency[T1, T2, Any](rdd, part)
}
}
Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2))
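
SubtractedRDD drops its per-RDD serializer hook, so its shuffle dependencies now use the ShuffleDependency default. The RDD itself is reached through subtractByKey, as in this illustrative spark-shell sketch:

val left  = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
val right = sc.parallelize(Seq(("b", 99)))

// subtractByKey builds a SubtractedRDD: keep the pairs of `left` whose key
// does not appear in `right`, so (a,1) and (c,3) remain.
val remaining = left.subtractByKey(right)
println(remaining.collect().sorted.toSeq)
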
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 66cf4369da2e..60e383afadf1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -20,6 +20,8 @@ package org.apache.spark.rdd
import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
+import scala.collection.parallel.ForkJoinTaskSupport
+import scala.concurrent.forkjoin.ForkJoinPool
import scala.reflect.ClassTag
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
@@ -56,14 +58,30 @@ private[spark] class UnionPartition[T: ClassTag](
}
}
+object UnionRDD {
+ private[spark] lazy val partitionEvalTaskSupport =
+ new ForkJoinTaskSupport(new ForkJoinPool(8))
+}
+
@DeveloperApi
class UnionRDD[T: ClassTag](
sc: SparkContext,
var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
+ // visible for testing
+ private[spark] val isPartitionListingParallel: Boolean =
+ rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)
+
override def getPartitions: Array[Partition] = {
- val array = new Array[Partition](rdds.map(_.partitions.length).sum)
+ val parRDDs = if (isPartitionListingParallel) {
+ val parArray = rdds.par
+ parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
+ parArray
+ } else {
+ rdds
+ }
+ val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
var pos = 0
for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
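
UnionRDD now lists child partitions in parallel, via a shared ForkJoinPool of 8 threads, once the number of unioned RDDs exceeds spark.rdd.parallelListingThreshold (default 10). A spark-shell sketch that crosses the threshold (illustrative only):

// One RDD per logical input; 50 children > the default threshold of 10,
// so getPartitions evaluates the children's partition arrays in parallel.
val manyRdds = (1 to 50).map(i => sc.parallelize(Seq(i)))
val unioned = sc.union(manyRdds)

println(unioned.partitions.length)   // sum of the children's partition counts
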
diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala
new file mode 100644
index 000000000000..8e1baae796fc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.io.{Text, Writable}
+import org.apache.hadoop.mapreduce.InputSplit
+import org.apache.hadoop.mapreduce.task.JobContextImpl
+
+import org.apache.spark.{Partition, SparkContext}
+import org.apache.spark.input.WholeTextFileInputFormat
+
+/**
+ * An RDD that reads a bunch of text files in, and each text file becomes one record.
+ */
+private[spark] class WholeTextFileRDD(
+ sc : SparkContext,
+ inputFormatClass: Class[_ <: WholeTextFileInputFormat],
+ keyClass: Class[Text],
+ valueClass: Class[Text],
+ conf: Configuration,
+ minPartitions: Int)
+ extends NewHadoopRDD[Text, Text](sc, inputFormatClass, keyClass, valueClass, conf) {
+
+ override def getPartitions: Array[Partition] = {
+ val inputFormat = inputFormatClass.newInstance
+ val conf = getConf
+ inputFormat match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
+ }
+ val jobContext = new JobContextImpl(conf, jobId)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
+ val rawSplits = inputFormat.getSplits(jobContext).toArray
+ val result = new Array[Partition](rawSplits.size)
+ for (i <- 0 until rawSplits.size) {
+ result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ }
+ result
+ }
+}
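
WholeTextFileRDD moves into its own file and builds its JobContextImpl directly rather than through the reflection-based MapReduce shims. It is the RDD behind SparkContext.wholeTextFiles, sketched here (illustrative, reusing the shell's sc and a hypothetical input directory):

// Each file under the directory becomes one (path, fullContent) record;
// minPartitions only influences how files are grouped into splits.
val files = sc.wholeTextFiles("/tmp/input-dir", minPartitions = 4)

files.collect().foreach { case (path, content) =>
  println(s"$path -> ${content.length} chars")
}
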
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index 4333a679c8aa..3cb1231bd347 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -54,7 +54,8 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
override def getPartitions: Array[Partition] = {
val numParts = rdds.head.partitions.length
if (!rdds.forall(rdd => rdd.partitions.length == numParts)) {
- throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+ throw new IllegalArgumentException(
+ s"Can't zip RDDs with unequal numbers of partitions: ${rdds.map(_.partitions.length)}")
}
Array.tabulate[Partition](numParts) { i =>
val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i)))
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
index 32931d59acb1..8425b211d6ec 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
@@ -29,7 +29,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
}
/**
- * Represents a RDD zipped with its element indices. The ordering is first based on the partition
+ * Represents an RDD zipped with its element indices. The ordering is first based on the partition
* index and then the ordering of items within each partition. So the first item in the first
* partition gets index 0, and the last item in the last partition receives the largest index.
*
@@ -43,7 +43,7 @@ class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev)
@transient private val startIndices: Array[Long] = {
val n = prev.partitions.length
if (n == 0) {
- Array[Long]()
+ Array.empty
} else if (n == 1) {
Array(0L)
} else {
@@ -64,8 +64,7 @@ class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev)
override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
- firstParent[T].iterator(split.prev, context).zipWithIndex.map { x =>
- (x._1, split.startIndex + x._2)
- }
+ val parentIter = firstParent[T].iterator(split.prev, context)
+ Utils.getIteratorZipWithIndex(parentIter, split.startIndex)
}
}
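
ZippedWithIndexRDD.compute now goes through Utils.getIteratorZipWithIndex, which takes the partition's Long start offset directly instead of adding it to Iterator.zipWithIndex's Int counter. The public entry point is unchanged; an illustrative spark-shell sketch:

// Indices are global Longs: partition 0 gets 0..n0-1, partition 1 continues
// at n0, and so on; computing the offsets runs a job over the earlier partitions.
val indexed = sc.parallelize(Seq("a", "b", "c", "d"), numSlices = 2).zipWithIndex()

println(indexed.collect().toSeq)   // ((a,0), (b,1), (c,2), (d,3))
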
diff --git a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala
new file mode 100644
index 000000000000..e00bc22aba44
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.Partition
+
+/**
+ * ::DeveloperApi::
+ * A PartitionCoalescer defines how to coalesce the partitions of a given RDD.
+ */
+@DeveloperApi
+trait PartitionCoalescer {
+
+ /**
+ * Coalesce the partitions of the given RDD.
+ *
+ * @param maxPartitions the maximum number of partitions to have after coalescing
+ * @param parent the parent RDD whose partitions to coalesce
+ * @return an array of [[PartitionGroup]]s, where each element is itself an array of
+ * `Partition`s and represents a partition after coalescing is performed.
+ */
+ def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup]
+}
+
+/**
+ * ::DeveloperApi::
+ * A group of `Partition`s
+ * @param prefLoc preferred location for the partition group
+ */
+@DeveloperApi
+class PartitionGroup(val prefLoc: Option[String] = None) {
+ val partitions = mutable.ArrayBuffer[Partition]()
+ def numPartitions: Int = partitions.size
+}
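
The new coalesce-public.scala makes PartitionCoalescer and PartitionGroup a DeveloperApi so callers can plug a custom grouping strategy into coalesce. Below is a deliberately naive round-robin coalescer, offered as a sketch only: it ignores locality, which the built-in default coalescer takes into account, and it assumes the coalesce overload that accepts an Option[PartitionCoalescer]:

import scala.collection.mutable

import org.apache.spark.Partition
import org.apache.spark.rdd.{PartitionCoalescer, PartitionGroup, RDD}

// Deal parent partitions round-robin into at most maxPartitions groups.
class RoundRobinCoalescer extends PartitionCoalescer with Serializable {
  override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = {
    val numGroups = math.min(maxPartitions, parent.partitions.length).max(1)
    val groups = Array.fill(numGroups)(new PartitionGroup())
    parent.partitions.zipWithIndex.foreach { case (p, i) =>
      groups(i % numGroups).partitions += p
    }
    groups
  }
}

// Usage sketch (assuming the PartitionCoalescer-aware coalesce overload):
// val coalesced = sc.parallelize(1 to 1000, 100)
//   .coalesce(10, shuffle = false, partitionCoalescer = Some(new RoundRobinCoalescer))
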
diff --git a/core/src/main/scala/org/apache/spark/rdd/package-info.java b/core/src/main/scala/org/apache/spark/rdd/package-info.java
index 176cc58179fb..d9aa9bebe56d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/package-info.java
+++ b/core/src/main/scala/org/apache/spark/rdd/package-info.java
@@ -18,4 +18,4 @@
/**
* Provides implementation's of various RDDs.
*/
-package org.apache.spark.rdd;
\ No newline at end of file
+package org.apache.spark.rdd;
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
similarity index 96%
rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
index f31ed2aa90a6..ab72addb2466 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.rdd.util
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer
/**
@@ -74,7 +75,7 @@ import org.apache.spark.storage.StorageLevel
*
* TODO: Move this out of MLlib?
*/
-private[mllib] class PeriodicRDDCheckpointer[T](
+private[spark] class PeriodicRDDCheckpointer[T](
checkpointInterval: Int,
sc: SparkContext)
extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
index f527ec86ab7b..117f51c5b8f2 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
@@ -18,7 +18,7 @@
package org.apache.spark.rpc
/**
- * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe
+ * A callback that [[RpcEndpoint]] can use to send back a message or failure. It's thread-safe
* and can be called in any thread.
*/
private[spark] trait RpcCallContext {
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index 0ba95169529e..97eed540b8f5 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -35,7 +35,7 @@ private[spark] trait RpcEnvFactory {
*
* The life-cycle of an endpoint is:
*
- * constructor -> onStart -> receive* -> onStop
+ * {@code constructor -> onStart -> receive* -> onStop}
*
* Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use
* [[ThreadSafeRpcEndpoint]]
@@ -63,16 +63,16 @@ private[spark] trait RpcEndpoint {
}
/**
- * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a
- * unmatched message, [[SparkException]] will be thrown and sent to `onError`.
+ * Process messages from `RpcEndpointRef.send` or `RpcCallContext.reply`. If receiving a
+ * unmatched message, `SparkException` will be thrown and sent to `onError`.
*/
def receive: PartialFunction[Any, Unit] = {
case _ => throw new SparkException(self + " does not implement 'receive'")
}
/**
- * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message,
- * [[SparkException]] will be thrown and sent to `onError`.
+ * Process messages from `RpcEndpointRef.ask`. If receiving a unmatched message,
+ * `SparkException` will be thrown and sent to `onError`.
*/
def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
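
The doc fixes above spell out the endpoint life-cycle (constructor -> onStart -> receive* -> onStop) and that receive and receiveAndReply handle send and ask traffic respectively. Since these classes are private[spark], the sketch below is illustrative only and assumes it lives in a hypothetical org.apache.spark subpackage:

package org.apache.spark.example   // hypothetical; RpcEndpoint is private[spark]

import org.apache.spark.SparkException
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

// Messages handled by the endpoint.
case class Ping(id: Int)
case class Pong(id: Int)

class PingEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {

  override def onStart(): Unit = {
    // Called before any message is delivered.
  }

  // Fire-and-forget messages sent via RpcEndpointRef.send.
  override def receive: PartialFunction[Any, Unit] = {
    case Ping(id) => println(s"got ping $id (no reply channel on this path)")
  }

  // Request/response messages sent via RpcEndpointRef.ask / askSync.
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Ping(id) => context.reply(Pong(id))
    case other    => context.sendFailure(new SparkException(s"Unexpected message: $other"))
  }

  override def onStop(): Unit = {
    // Called once the endpoint is unregistered.
  }
}
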
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
similarity index 83%
rename from core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
rename to core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
index d2e94f943aba..fdbccc9e74c3 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
@@ -15,10 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.rpc.netty
+package org.apache.spark.rpc
import org.apache.spark.SparkException
-import org.apache.spark.rpc.RpcAddress
/**
* An address identifier for an RPC endpoint.
@@ -26,10 +25,11 @@ import org.apache.spark.rpc.RpcAddress
* The `rpcAddress` may be null, in which case the endpoint is registered via a client-only
* connection and can only be reached via the client that sent the endpoint reference.
*
- * @param rpcAddress The socket address of the endpint.
+ * @param rpcAddress The socket address of the endpoint. It's `null` when this address pointing to
+ * an endpoint in a client `NettyRpcEnv`.
* @param name Name of the endpoint.
*/
-private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) {
+private[spark] case class RpcEndpointAddress(rpcAddress: RpcAddress, name: String) {
require(name != null, "RpcEndpoint name must be provided.")
@@ -44,7 +44,11 @@ private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val nam
}
}
-private[netty] object RpcEndpointAddress {
+private[spark] object RpcEndpointAddress {
+
+ def apply(host: String, port: Int, name: String): RpcEndpointAddress = {
+ new RpcEndpointAddress(host, port, name)
+ }
def apply(sparkUrl: String): RpcEndpointAddress = {
try {
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
index 623da3e9c11b..4d39f144dd19 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
@@ -20,8 +20,9 @@ package org.apache.spark.rpc
import scala.concurrent.Future
import scala.reflect.ClassTag
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.internal.Logging
import org.apache.spark.util.RpcUtils
-import org.apache.spark.{SparkException, Logging, SparkConf}
/**
* A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe.
@@ -62,25 +63,21 @@ private[spark] abstract class RpcEndpointRef(conf: SparkConf)
def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout)
/**
- * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default
- * timeout, or throw a SparkException if this fails even after the default number of retries.
- * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this
- * method retries, the message handling in the receiver side should be idempotent.
+ * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
+ * default timeout, throw an exception if this fails.
*
* Note: this is a blocking action which may cost a lot of time, so don't call it in a message
* loop of [[RpcEndpoint]].
- *
+ *
* @param message the message to send
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
- def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout)
+ def askSync[T: ClassTag](message: Any): T = askSync(message, defaultAskTimeout)
/**
- * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a
- * specified timeout, throw a SparkException if this fails even after the specified number of
- * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method
- * retries, the message handling in the receiver side should be idempotent.
+ * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
+ * specified timeout, throw an exception if this fails.
*
* Note: this is a blocking action which may cost a lot of time, so don't call it in a message
* loop of [[RpcEndpoint]].
@@ -90,33 +87,9 @@ private[spark] abstract class RpcEndpointRef(conf: SparkConf)
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
- def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
- // TODO: Consider removing multiple attempts
- var attempts = 0
- var lastException: Exception = null
- while (attempts < maxRetries) {
- attempts += 1
- try {
- val future = ask[T](message, timeout)
- val result = timeout.awaitResult(future)
- if (result == null) {
- throw new SparkException("RpcEndpoint returned null")
- }
- return result
- } catch {
- case ie: InterruptedException => throw ie
- case e: Exception =>
- lastException = e
- logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
- }
-
- if (attempts < maxRetries) {
- Thread.sleep(retryWaitMs)
- }
- }
-
- throw new SparkException(
- s"Error sending message [message = $message]", lastException)
+ def askSync[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
+ val future = ask[T](message, timeout)
+ timeout.awaitResult(future)
}
}
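
askWithRetry and its retry loop are gone; askSync is a single attempt that blocks via timeout.awaitResult, so callers that want retries must implement them explicitly (and should only do so against idempotent handlers). A hedged sketch, again assumed to live under org.apache.spark and reusing the Ping/Pong messages from the endpoint sketch above:

import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}

import org.apache.spark.rpc.RpcEndpointRef

def pingBothWays(ref: RpcEndpointRef): Unit = {
  // Asynchronous path: returns a Future and never blocks the caller.
  ref.ask[Pong](Ping(1)).onComplete {
    case Success(pong) => println(s"async reply: $pong")
    case Failure(e)    => println(s"ask failed: $e")
  }

  // Blocking path: one attempt, throws if no reply arrives within the
  // default ask timeout (no built-in retries as with the old askWithRetry).
  val pong = ref.askSync[Pong](Ping(2))
  println(s"sync reply: $pong")
}
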
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index a560fd10cdf7..530743c03640 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -17,10 +17,14 @@
package org.apache.spark.rpc
+import java.io.File
+import java.nio.channels.ReadableByteChannel
+
import scala.concurrent.Future
import org.apache.spark.{SecurityManager, SparkConf}
-import org.apache.spark.util.{RpcUtils, Utils}
+import org.apache.spark.rpc.netty.NettyRpcEnvFactory
+import org.apache.spark.util.RpcUtils
/**
@@ -29,15 +33,6 @@ import org.apache.spark.util.{RpcUtils, Utils}
*/
private[spark] object RpcEnv {
- private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = {
- val rpcEnvNames = Map(
- "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
- "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory")
- val rpcEnvName = conf.get("spark.rpc", "netty")
- val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
- Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory]
- }
-
def create(
name: String,
host: String,
@@ -45,9 +40,20 @@ private[spark] object RpcEnv {
conf: SparkConf,
securityManager: SecurityManager,
clientMode: Boolean = false): RpcEnv = {
- // Using Reflection to create the RpcEnv to avoid to depend on Akka directly
- val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode)
- getRpcEnvFactory(conf).create(config)
+ create(name, host, host, port, conf, securityManager, clientMode)
+ }
+
+ def create(
+ name: String,
+ bindAddress: String,
+ advertiseAddress: String,
+ port: Int,
+ conf: SparkConf,
+ securityManager: SecurityManager,
+ clientMode: Boolean): RpcEnv = {
+ val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
+ clientMode)
+ new NettyRpcEnvFactory().create(config)
}
}
@@ -95,12 +101,11 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
}
/**
- * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`.
+ * Retrieve the [[RpcEndpointRef]] represented by `address` and `endpointName`.
* This is a blocking action.
*/
- def setupEndpointRef(
- systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = {
- setupEndpointRefByURI(uriOf(systemName, address, endpointName))
+ def setupEndpointRef(address: RpcAddress, endpointName: String): RpcEndpointRef = {
+ setupEndpointRefByURI(RpcEndpointAddress(address, endpointName).toString)
}
/**
@@ -121,24 +126,79 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
*/
def awaitTermination(): Unit
- /**
- * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of
- * creating it manually because different [[RpcEnv]] may have different formats.
- */
- def uriOf(systemName: String, address: RpcAddress, endpointName: String): String
-
/**
* [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object
* that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method.
*/
def deserialize[T](deserializationAction: () => T): T
+
+ /**
+ * Return the instance of the file server used to serve files. This may be `null` if the
+ * RpcEnv is not operating in server mode.
+ */
+ def fileServer: RpcEnvFileServer
+
+ /**
+ * Open a channel to download a file from the given URI. If the URIs returned by the
+ * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to
+ * retrieve the files.
+ *
+ * @param uri URI with location of the file.
+ */
+ def openChannel(uri: String): ReadableByteChannel
}
+/**
+ * A server used by the RpcEnv to server files to other processes owned by the application.
+ *
+ * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or
+ * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`.
+ */
+private[spark] trait RpcEnvFileServer {
+
+ /**
+ * Adds a file to be served by this RpcEnv. This is used to serve files from the driver
+ * to executors when they're stored on the driver's local file system.
+ *
+ * @param file Local file to serve.
+ * @return A URI for the location of the file.
+ */
+ def addFile(file: File): String
+
+ /**
+ * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using
+ * `SparkContext.addJar`.
+ *
+ * @param file Local file to serve.
+ * @return A URI for the location of the file.
+ */
+ def addJar(file: File): String
+
+ /**
+ * Adds a local directory to be served via this file server.
+ *
+ * @param baseUri Leading URI path (files can be retrieved by appending their relative
+ * path to this base URI). This cannot be "files" nor "jars".
+ * @param path Path to the local directory.
+ * @return URI for the root of the directory in the file server.
+ */
+ def addDirectory(baseUri: String, path: File): String
+
+ /** Validates and normalizes the base URI for directories. */
+ protected def validateDirectoryUri(baseUri: String): String = {
+ val fixedBaseUri = "/" + baseUri.stripPrefix("/").stripSuffix("/")
+ require(fixedBaseUri != "/files" && fixedBaseUri != "/jars",
+ "Directory URI cannot be /files nor /jars.")
+ fixedBaseUri
+ }
+
+}
private[spark] case class RpcEnvConfig(
conf: SparkConf,
name: String,
- host: String,
+ bindAddress: String,
+ advertiseAddress: String,
port: Int,
securityManager: SecurityManager,
clientMode: Boolean)
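
RpcEnvConfig now separates the bind address from the advertised address, and RpcEnv.create instantiates NettyRpcEnvFactory directly (the Akka backend and the "spark.rpc" switch are gone). An illustrative sketch of the new overload, again assuming code under org.apache.spark because these APIs are private[spark]; the host name is hypothetical:

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.RpcEnv

val conf = new SparkConf().setAppName("rpc-env-example")
val securityMgr = new SecurityManager(conf)

// Bind to all interfaces inside a container, but advertise the routable
// host name to peers; the old create() only took a single host.
val rpcEnv = RpcEnv.create(
  name = "exampleDriver",
  bindAddress = "0.0.0.0",
  advertiseAddress = "driver.example.com",   // hypothetical host name
  port = 7077,
  conf = conf,
  securityManager = securityMgr,
  clientMode = false)

// ... register endpoints, then:
rpcEnv.shutdown()
rpcEnv.awaitTermination()
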
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala
new file mode 100644
index 000000000000..c296cc23f12b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc
+
+private[rpc] class RpcEnvStoppedException()
+ extends IllegalStateException("RpcEnv already stopped.")
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
index 285786ebf9f1..0557b7a3cc0b 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
@@ -19,15 +19,14 @@ package org.apache.spark.rpc
import java.util.concurrent.TimeoutException
-import scala.concurrent.{Awaitable, Await}
+import scala.concurrent.Future
import scala.concurrent.duration._
import org.apache.spark.SparkConf
-import org.apache.spark.util.Utils
-
+import org.apache.spark.util.{ThreadUtils, Utils}
/**
- * An exception thrown if RpcTimeout modifies a [[TimeoutException]].
+ * An exception thrown if RpcTimeout modifies a `TimeoutException`.
*/
private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException)
extends TimeoutException(message) { initCause(cause) }
@@ -66,13 +65,14 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S
/**
* Wait for the completed result and return it. If the result is not available within this
* timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout.
- * @param awaitable the `Awaitable` to be awaited
- * @throws RpcTimeoutException if after waiting for the specified time `awaitable`
+ *
+ * @param future the `Future` to be awaited
+ * @throws RpcTimeoutException if after waiting for the specified time `future`
* is still not ready
*/
- def awaitResult[T](awaitable: Awaitable[T]): T = {
+ def awaitResult[T](future: Future[T]): T = {
try {
- Await.result(awaitable, duration)
+ ThreadUtils.awaitResult(future, duration)
} catch addMessageIfTimeout
}
}
@@ -83,6 +83,7 @@ private[spark] object RpcTimeout {
/**
* Lookup the timeout property in the configuration and create
* a RpcTimeout with the property key in the description.
+ *
* @param conf configuration properties containing the timeout
* @param timeoutProp property key for the timeout in seconds
* @throws NoSuchElementException if property is not set
@@ -96,6 +97,7 @@ private[spark] object RpcTimeout {
* Lookup the timeout property in the configuration and create
* a RpcTimeout with the property key in the description.
* Uses the given default value if property is not set
+ *
* @param conf configuration properties containing the timeout
* @param timeoutProp property key for the timeout in seconds
* @param defaultValue default timeout value in seconds if property not found
@@ -110,6 +112,7 @@ private[spark] object RpcTimeout {
* and create a RpcTimeout with the first set property key in the
* description.
* Uses the given default value if property is not set
+ *
* @param conf configuration properties containing the timeout
* @param timeoutPropList prioritized list of property keys for the timeout in seconds
* @param defaultValue default timeout value in seconds if no properties found
@@ -120,7 +123,7 @@ private[spark] object RpcTimeout {
// Find the first set property or use the default value with the first property
val itr = timeoutPropList.iterator
var foundProp: Option[(String, String)] = None
- while (itr.hasNext && foundProp.isEmpty){
+ while (itr.hasNext && foundProp.isEmpty) {
val propKey = itr.next()
conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) }
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
deleted file mode 100644
index 3fad595a0d0b..000000000000
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ /dev/null
@@ -1,345 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rpc.akka
-
-import java.util.concurrent.ConcurrentHashMap
-
-import scala.concurrent.Future
-import scala.language.postfixOps
-import scala.reflect.ClassTag
-import scala.util.control.NonFatal
-
-import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address}
-import akka.event.Logging.Error
-import akka.pattern.{ask => akkaAsk}
-import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
-import akka.serialization.JavaSerializer
-
-import org.apache.spark.{SparkException, Logging, SparkConf}
-import org.apache.spark.rpc._
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
-
-/**
- * A RpcEnv implementation based on Akka.
- *
- * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and
- * remove Akka from the dependencies.
- */
-private[spark] class AkkaRpcEnv private[akka] (
- val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int)
- extends RpcEnv(conf) with Logging {
-
- private val defaultAddress: RpcAddress = {
- val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress
- // In some test case, ActorSystem doesn't bind to any address.
- // So just use some default value since they are only some unit tests
- RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort))
- }
-
- override val address: RpcAddress = defaultAddress
-
- /**
- * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make
- * [[RpcEndpoint.self]] work.
- */
- private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]()
-
- /**
- * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef`
- */
- private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]()
-
- private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = {
- endpointToRef.put(endpoint, endpointRef)
- refToEndpoint.put(endpointRef, endpoint)
- }
-
- private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = {
- val endpoint = refToEndpoint.remove(endpointRef)
- if (endpoint != null) {
- endpointToRef.remove(endpoint)
- }
- }
-
- /**
- * Retrieve the [[RpcEndpointRef]] of `endpoint`.
- */
- override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint)
-
- override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
- @volatile var endpointRef: AkkaRpcEndpointRef = null
- // Use defered function because the Actor needs to use `endpointRef`.
- // So `actorRef` should be created after assigning `endpointRef`.
- val actorRef = () => actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging {
-
- assert(endpointRef != null)
-
- override def preStart(): Unit = {
- // Listen for remote client network events
- context.system.eventStream.subscribe(self, classOf[AssociationEvent])
- safelyCall(endpoint) {
- endpoint.onStart()
- }
- }
-
- override def receiveWithLogging: Receive = {
- case AssociatedEvent(_, remoteAddress, _) =>
- safelyCall(endpoint) {
- endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress))
- }
-
- case DisassociatedEvent(_, remoteAddress, _) =>
- safelyCall(endpoint) {
- endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress))
- }
-
- case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) =>
- safelyCall(endpoint) {
- endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress))
- }
-
- case e: AssociationEvent =>
- // TODO ignore?
-
- case m: AkkaMessage =>
- logDebug(s"Received RPC message: $m")
- safelyCall(endpoint) {
- processMessage(endpoint, m, sender)
- }
-
- case AkkaFailure(e) =>
- safelyCall(endpoint) {
- throw e
- }
-
- case message: Any => {
- logWarning(s"Unknown message: $message")
- }
-
- }
-
- override def postStop(): Unit = {
- unregisterEndpoint(endpoint.self)
- safelyCall(endpoint) {
- endpoint.onStop()
- }
- }
-
- }), name = name)
- endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false)
- registerEndpoint(endpoint, endpointRef)
- // Now actorRef can be created safely
- endpointRef.init()
- endpointRef
- }
-
- private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = {
- val message = m.message
- val needReply = m.needReply
- val pf: PartialFunction[Any, Unit] =
- if (needReply) {
- endpoint.receiveAndReply(new RpcCallContext {
- override def sendFailure(e: Throwable): Unit = {
- _sender ! AkkaFailure(e)
- }
-
- override def reply(response: Any): Unit = {
- _sender ! AkkaMessage(response, false)
- }
-
- // Use "lazy" because most of RpcEndpoints don't need "senderAddress"
- override lazy val senderAddress: RpcAddress =
- new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address
- })
- } else {
- endpoint.receive
- }
- try {
- pf.applyOrElse[Any, Unit](message, { message =>
- throw new SparkException(s"Unmatched message $message from ${_sender}")
- })
- } catch {
- case NonFatal(e) =>
- _sender ! AkkaFailure(e)
- if (!needReply) {
- // If the sender does not require a reply, it may not handle the exception. So we rethrow
- // "e" to make sure it will be processed.
- throw e
- }
- }
- }
-
- /**
- * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will
- * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it.
- */
- private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
- try {
- action
- } catch {
- case NonFatal(e) => {
- try {
- endpoint.onError(e)
- } catch {
- case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e)
- }
- }
- }
- }
-
- private def akkaAddressToRpcAddress(address: Address): RpcAddress = {
- RpcAddress(address.host.getOrElse(defaultAddress.host),
- address.port.getOrElse(defaultAddress.port))
- }
-
- override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
- import actorSystem.dispatcher
- actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
- map(new AkkaRpcEndpointRef(defaultAddress, _, conf)).
- // this is just in case there is a timeout from creating the future in resolveOne, we want the
- // exception to indicate the conf that determines the timeout
- recover(defaultLookupTimeout.addMessageIfTimeout)
- }
-
- override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = {
- AkkaUtils.address(
- AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName)
- }
-
- override def shutdown(): Unit = {
- actorSystem.shutdown()
- }
-
- override def stop(endpoint: RpcEndpointRef): Unit = {
- require(endpoint.isInstanceOf[AkkaRpcEndpointRef])
- actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef)
- }
-
- override def awaitTermination(): Unit = {
- actorSystem.awaitTermination()
- }
-
- override def toString: String = s"${getClass.getSimpleName}($actorSystem)"
-
- override def deserialize[T](deserializationAction: () => T): T = {
- JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) {
- deserializationAction()
- }
- }
-}
-
-private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
-
- def create(config: RpcEnvConfig): RpcEnv = {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- config.name, config.host, config.port, config.conf, config.securityManager)
- actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor")
- new AkkaRpcEnv(actorSystem, config.conf, boundPort)
- }
-}
-
-/**
- * Monitor errors reported by Akka and log them.
- */
-private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging {
-
- override def preStart(): Unit = {
- context.system.eventStream.subscribe(self, classOf[Error])
- }
-
- override def receiveWithLogging: Actor.Receive = {
- case Error(cause: Throwable, _, _, message: String) => logError(message, cause)
- }
-}
-
-private[akka] class AkkaRpcEndpointRef(
- @transient private val defaultAddress: RpcAddress,
- @transient private val _actorRef: () => ActorRef,
- conf: SparkConf,
- initInConstructor: Boolean)
- extends RpcEndpointRef(conf) with Logging {
-
- def this(
- defaultAddress: RpcAddress,
- _actorRef: ActorRef,
- conf: SparkConf) = {
- this(defaultAddress, () => _actorRef, conf, true)
- }
-
- lazy val actorRef = _actorRef()
-
- override lazy val address: RpcAddress = {
- val akkaAddress = actorRef.path.address
- RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host),
- akkaAddress.port.getOrElse(defaultAddress.port))
- }
-
- override lazy val name: String = actorRef.path.name
-
- private[akka] def init(): Unit = {
- // Initialize the lazy vals
- actorRef
- address
- name
- }
-
- if (initInConstructor) {
- init()
- }
-
- override def send(message: Any): Unit = {
- actorRef ! AkkaMessage(message, false)
- }
-
- override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
- actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
- // The function will run in the calling thread, so it should be short and never block.
- case msg @ AkkaMessage(message, reply) =>
- if (reply) {
- logError(s"Receive $msg but the sender cannot reply")
- Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
- } else {
- Future.successful(message)
- }
- case AkkaFailure(e) =>
- Future.failed(e)
- }(ThreadUtils.sameThread).mapTo[T].
- recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
- }
-
- override def toString: String = s"${getClass.getSimpleName}($actorRef)"
-
- final override def equals(that: Any): Boolean = that match {
- case other: AkkaRpcEndpointRef => actorRef == other.actorRef
- case _ => false
- }
-
- final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode()
-}
-
-/**
- * A wrapper to `message` so that the receiver knows if the sender expects a reply.
- * @param message
- * @param needReply if the sender expects a reply message
- */
-private[akka] case class AkkaMessage(message: Any, needReply: Boolean)
-
-/**
- * A reply with the failure error from the receiver to the sender
- */
-private[akka] case class AkkaFailure(e: Throwable)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index eb25d6c7b721..e94babb84612 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -17,14 +17,15 @@
package org.apache.spark.rpc.netty
-import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
import scala.concurrent.Promise
import scala.util.control.NonFatal
-import org.apache.spark.{SparkException, Logging}
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc._
import org.apache.spark.util.ThreadUtils
@@ -41,8 +42,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
val inbox = new Inbox(ref, endpoint)
}
- private val endpoints = new ConcurrentHashMap[String, EndpointData]
- private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
+ private val endpoints: ConcurrentMap[String, EndpointData] =
+ new ConcurrentHashMap[String, EndpointData]
+ private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
+ new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
// Track the receivers whose inboxes may contain messages.
private val receivers = new LinkedBlockingQueue[EndpointData]
@@ -106,71 +109,60 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
val iter = endpoints.keySet().iterator()
while (iter.hasNext) {
val name = iter.next
- postMessage(
- name,
- _ => message,
- () => { logWarning(s"Drop $message because $name has been stopped") })
- }
+    postMessage(name, message, (e) => e match {
+      case e: RpcEnvStoppedException => logDebug(s"Message $message dropped. ${e.getMessage}")
+      case e: Throwable => logWarning(s"Message $message dropped. ${e.getMessage}")
+    })
+  }
}
/** Posts a message sent by a remote endpoint. */
def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
- def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
- val rpcCallContext =
- new RemoteNettyRpcCallContext(
- nettyEnv, sender, callback, message.senderAddress, message.needReply)
- ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
- }
-
- def onEndpointStopped(): Unit = {
- callback.onFailure(
- new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
- }
-
- postMessage(message.receiver.name, createMessage, onEndpointStopped)
+ val rpcCallContext =
+ new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
+ val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
+ postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
}
/** Posts a message sent by a local endpoint. */
def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
- def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
- val rpcCallContext =
- new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p)
- ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
- }
-
- def onEndpointStopped(): Unit = {
- p.tryFailure(
- new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
- }
+ val rpcCallContext =
+ new LocalNettyRpcCallContext(message.senderAddress, p)
+ val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
+ postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
+ }
- postMessage(message.receiver.name, createMessage, onEndpointStopped)
+ /** Posts a one-way message. */
+ def postOneWayMessage(message: RequestMessage): Unit = {
+ postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content),
+ (e) => throw e)
}
/**
* Posts a message to a specific endpoint.
*
* @param endpointName name of the endpoint.
- * @param createMessageFn function to create the message.
+ * @param message the message to post
* @param callbackIfStopped callback function if the endpoint is stopped.
*/
private def postMessage(
endpointName: String,
- createMessageFn: NettyRpcEndpointRef => InboxMessage,
- callbackIfStopped: () => Unit): Unit = {
- val shouldCallOnStop = synchronized {
+ message: InboxMessage,
+ callbackIfStopped: (Exception) => Unit): Unit = {
+ val error = synchronized {
val data = endpoints.get(endpointName)
- if (stopped || data == null) {
- true
+ if (stopped) {
+ Some(new RpcEnvStoppedException())
+ } else if (data == null) {
+ Some(new SparkException(s"Could not find $endpointName."))
} else {
- data.inbox.post(createMessageFn(data.ref))
+ data.inbox.post(message)
receivers.offer(data)
- false
+ None
}
}
- if (shouldCallOnStop) {
- // We don't need to call `onStop` in the `synchronized` block
- callbackIfStopped()
- }
+ // We don't need to call `onStop` in the `synchronized` block
+ error.foreach(callbackIfStopped)
}
def stop(): Unit = {
@@ -201,7 +193,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
/** Thread pool used for dispatching messages. */
private val threadpool: ThreadPoolExecutor = {
val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
- Runtime.getRuntime.availableProcessors())
+ math.max(2, Runtime.getRuntime.availableProcessors()))
val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
for (i <- 0 until numThreads) {
pool.execute(new MessageLoop)
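The reworked `postMessage` above decides on the failure while holding the lock but invokes the caller-supplied callback only after releasing it. A minimal standalone sketch of that shape (names are illustrative and not part of this patch):

    class PostExample {
      private var stopped = false

      def post(work: () => Unit, onError: Exception => Unit): Unit = {
        val error = synchronized {
          if (stopped) Some(new IllegalStateException("stopped")) else { work(); None }
        }
        // Running the callback outside the lock keeps arbitrary callback code from
        // blocking other threads that need the same monitor.
        error.foreach(onError)
      }
    }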
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
index c72b588db57f..d32eba64e13e 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -21,18 +21,20 @@ import javax.annotation.concurrent.GuardedBy
import scala.util.control.NonFatal
-import com.google.common.annotations.VisibleForTesting
-
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}
private[netty] sealed trait InboxMessage
-private[netty] case class ContentMessage(
+private[netty] case class OneWayMessage(
+ senderAddress: RpcAddress,
+ content: Any) extends InboxMessage
+
+private[netty] case class RpcMessage(
senderAddress: RpcAddress,
content: Any,
- needReply: Boolean,
context: NettyRpcCallContext) extends InboxMessage
private[netty] case object OnStart extends InboxMessage
@@ -50,7 +52,7 @@ private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteA
extends InboxMessage
/**
- * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
+ * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
*/
private[netty] class Inbox(
val endpointRef: NettyRpcEndpointRef,
@@ -98,29 +100,24 @@ private[netty] class Inbox(
while (true) {
safelyCall(endpoint) {
message match {
- case ContentMessage(_sender, content, needReply, context) =>
- // The partial function to call
- val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive
+ case RpcMessage(_sender, content, context) =>
try {
- pf.applyOrElse[Any, Unit](content, { msg =>
+ endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
throw new SparkException(s"Unsupported message $message from ${_sender}")
})
- if (!needReply) {
- context.finish()
- }
} catch {
case NonFatal(e) =>
- if (needReply) {
- // If the sender asks a reply, we should send the error back to the sender
- context.sendFailure(e)
- } else {
- context.finish()
- }
+ context.sendFailure(e)
// Throw the exception -- this exception will be caught by the safelyCall function.
// The endpoint's onError function will be called.
throw e
}
+ case OneWayMessage(_sender, content) =>
+ endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
+ throw new SparkException(s"Unsupported message $message from ${_sender}")
+ })
+
case OnStart =>
endpoint.onStart()
if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
@@ -193,8 +190,10 @@ private[netty] class Inbox(
def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }
- /** Called when we are dropping a message. Test cases override this to test message dropping. */
- @VisibleForTesting
+ /**
+ * Called when we are dropping a message. Test cases override this to test message dropping.
+ * Exposed for testing.
+ */
protected def onDrop(message: InboxMessage): Unit = {
logWarning(s"Drop $message because $endpointRef is stopped")
}
@@ -206,7 +205,12 @@ private[netty] class Inbox(
try action catch {
case NonFatal(e) =>
try endpoint.onError(e) catch {
- case NonFatal(ee) => logError(s"Ignoring error", ee)
+ case NonFatal(ee) =>
+ if (stopped) {
+ logDebug("Ignoring error", ee)
+ } else {
+ logError("Ignoring error", ee)
+ }
}
}
}
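To connect the two message types above to endpoint code: `OneWayMessage` (produced by `send`) is dispatched to `receive`, while `RpcMessage` (produced by `ask`) is dispatched to `receiveAndReply`, with failures routed back through `sendFailure`. A hedged sketch of an endpoint exercising both paths (the endpoint itself is hypothetical):

    import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

    class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
      // Fire-and-forget messages (OneWayMessage) land here; no reply is possible.
      override def receive: PartialFunction[Any, Unit] = {
        case msg: String => // handle and return
      }
      // ask() messages (RpcMessage) land here; the context carries the reply channel.
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String => context.reply(msg)
      }
    }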
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
index 21d5bb4923d1..7dd7e610a28e 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
@@ -19,53 +19,32 @@ package org.apache.spark.rpc.netty
import scala.concurrent.Promise
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc.{RpcAddress, RpcCallContext}
-private[netty] abstract class NettyRpcCallContext(
- endpointRef: NettyRpcEndpointRef,
- override val senderAddress: RpcAddress,
- needReply: Boolean)
+private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress)
extends RpcCallContext with Logging {
protected def send(message: Any): Unit
override def reply(response: Any): Unit = {
- if (needReply) {
- send(AskResponse(endpointRef, response))
- } else {
- throw new IllegalStateException(
- s"Cannot send $response to the sender because the sender does not expect a reply")
- }
+ send(response)
}
override def sendFailure(e: Throwable): Unit = {
- if (needReply) {
- send(AskResponse(endpointRef, RpcFailure(e)))
- } else {
- logError(e.getMessage, e)
- throw new IllegalStateException(
- "Cannot send reply to the sender because the sender won't handle it")
- }
+ send(RpcFailure(e))
}
- def finish(): Unit = {
- if (!needReply) {
- send(Ack(endpointRef))
- }
- }
}
/**
* If the sender and the receiver are in the same process, the reply can be sent back via `Promise`.
*/
private[netty] class LocalNettyRpcCallContext(
- endpointRef: NettyRpcEndpointRef,
senderAddress: RpcAddress,
- needReply: Boolean,
p: Promise[Any])
- extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+ extends NettyRpcCallContext(senderAddress) {
override protected def send(message: Any): Unit = {
p.success(message)
@@ -77,11 +56,9 @@ private[netty] class LocalNettyRpcCallContext(
*/
private[netty] class RemoteNettyRpcCallContext(
nettyEnv: NettyRpcEnv,
- endpointRef: NettyRpcEndpointRef,
callback: RpcResponseCallback,
- senderAddress: RpcAddress,
- needReply: Boolean)
- extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+ senderAddress: RpcAddress)
+ extends NettyRpcCallContext(senderAddress) {
override protected def send(message: Any): Unit = {
val reply = nettyEnv.serialize(message)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 09093819bb22..7af63728652f 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -17,30 +17,28 @@
package org.apache.spark.rpc.netty
import java.io._
-import java.lang.{Boolean => JBoolean}
import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
+import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
-import javax.annotation.Nullable;
-import javax.annotation.concurrent.GuardedBy
+import javax.annotation.Nullable
-import scala.collection.mutable
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag
-import scala.util.{DynamicVariable, Failure, Success}
+import scala.util.{DynamicVariable, Failure, Success, Try}
import scala.util.control.NonFatal
-import com.google.common.base.Preconditions
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.internal.Logging
import org.apache.spark.network.TransportContext
import org.apache.spark.network.client._
+import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
import org.apache.spark.network.netty.SparkTransportConf
-import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.rpc._
-import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance}
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance, SerializationStream}
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, ThreadUtils, Utils}
private[netty] class NettyRpcEnv(
val conf: SparkConf,
@@ -48,26 +46,39 @@ private[netty] class NettyRpcEnv(
host: String,
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
- private val transportConf = SparkTransportConf.fromSparkConf(
- conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
+ private[netty] val transportConf = SparkTransportConf.fromSparkConf(
+ conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
+ "rpc",
conf.getInt("spark.rpc.io.threads", 0))
private val dispatcher: Dispatcher = new Dispatcher(this)
+ private val streamManager = new NettyStreamManager(this)
+
private val transportContext = new TransportContext(transportConf,
- new NettyRpcHandler(dispatcher, this))
+ new NettyRpcHandler(dispatcher, this, streamManager))
- private val clientFactory = {
- val bootstraps: java.util.List[TransportClientBootstrap] =
- if (securityManager.isAuthenticationEnabled()) {
- java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
- securityManager.isSaslEncryptionEnabled()))
- } else {
- java.util.Collections.emptyList[TransportClientBootstrap]
- }
- transportContext.createClientFactory(bootstraps)
+ private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
+ if (securityManager.isAuthenticationEnabled()) {
+ java.util.Arrays.asList(new AuthClientBootstrap(transportConf,
+ securityManager.getSaslUser(), securityManager))
+ } else {
+ java.util.Collections.emptyList[TransportClientBootstrap]
+ }
}
+ private val clientFactory = transportContext.createClientFactory(createClientBootstraps())
+
+ /**
+ * A separate client factory for file downloads. This avoids using the same RPC handler as
+ * the main RPC context, so that events caused by these clients are kept isolated from the
+ * main RPC traffic.
+ *
+ * It also allows for different configuration of certain properties, such as the number of
+ * connections per peer.
+ */
+ @volatile private var fileDownloadFactory: TransportClientFactory = _
+
val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
// Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
@@ -97,14 +108,14 @@ private[netty] class NettyRpcEnv(
}
}
- def startServer(port: Int): Unit = {
+ def startServer(bindAddress: String, port: Int): Unit = {
val bootstraps: java.util.List[TransportServerBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
- java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
+ java.util.Arrays.asList(new AuthServerBootstrap(transportConf, securityManager))
} else {
java.util.Collections.emptyList()
}
- server = transportContext.createServer(port, bootstraps)
+ server = transportContext.createServer(bindAddress, port, bootstraps)
dispatcher.registerRpcEndpoint(
RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}
@@ -139,7 +150,7 @@ private[netty] class NettyRpcEnv(
private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
if (receiver.client != null) {
- receiver.client.sendRpc(message.content, message.createCallback(receiver.client));
+ message.sendWith(receiver.client)
} else {
require(receiver.address != null,
"Cannot send message to client endpoint with no listen address.")
@@ -171,25 +182,14 @@ private[netty] class NettyRpcEnv(
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
// Message to a local RPC endpoint.
- val promise = Promise[Any]()
- dispatcher.postLocalMessage(message, promise)
- promise.future.onComplete {
- case Success(response) =>
- val ack = response.asInstanceOf[Ack]
- logTrace(s"Received ack from ${ack.sender}")
- case Failure(e) =>
- logWarning(s"Exception when sending $message", e)
- }(ThreadUtils.sameThread)
+ try {
+ dispatcher.postOneWayMessage(message)
+ } catch {
+ case e: RpcEnvStoppedException => logDebug(e.getMessage)
+ }
} else {
// Message to a remote RPC endpoint.
- postToOutbox(message.receiver, OutboxMessage(serialize(message),
- (e) => {
- logWarning(s"Exception when sending $message", e)
- },
- (client, response) => {
- val ack = deserialize[Ack](client, response)
- logDebug(s"Receive ack from ${ack.sender}")
- }))
+ postToOutbox(message.receiver, OneWayOutboxMessage(message.serialize(this)))
}
}
@@ -197,58 +197,77 @@ private[netty] class NettyRpcEnv(
clientFactory.createClient(address.host, address.port)
}
- private[netty] def ask(message: RequestMessage): Future[Any] = {
+ private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
val promise = Promise[Any]()
val remoteAddr = message.receiver.address
- if (remoteAddr == address) {
- val p = Promise[Any]()
- dispatcher.postLocalMessage(message, p)
- p.future.onComplete {
- case Success(response) =>
- val reply = response.asInstanceOf[AskResponse]
- if (reply.reply.isInstanceOf[RpcFailure]) {
- if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
- logWarning(s"Ignore failure: ${reply.reply}")
- }
- } else if (!promise.trySuccess(reply.reply)) {
- logWarning(s"Ignore message: ${reply}")
- }
- case Failure(e) =>
- if (!promise.tryFailure(e)) {
- logWarning("Ignore Exception", e)
- }
+
+ def onFailure(e: Throwable): Unit = {
+ if (!promise.tryFailure(e)) {
+ e match {
+        case e: RpcEnvStoppedException => logDebug(s"Ignored failure: $e")
+ case _ => logWarning(s"Ignored failure: $e")
+ }
+ }
+ }
+
+ def onSuccess(reply: Any): Unit = reply match {
+ case RpcFailure(e) => onFailure(e)
+ case rpcReply =>
+ if (!promise.trySuccess(rpcReply)) {
+ logWarning(s"Ignored message: $reply")
+ }
+ }
+
+ try {
+ if (remoteAddr == address) {
+ val p = Promise[Any]()
+ p.future.onComplete {
+ case Success(response) => onSuccess(response)
+ case Failure(e) => onFailure(e)
+ }(ThreadUtils.sameThread)
+ dispatcher.postLocalMessage(message, p)
+ } else {
+ val rpcMessage = RpcOutboxMessage(message.serialize(this),
+ onFailure,
+ (client, response) => onSuccess(deserialize[Any](client, response)))
+ postToOutbox(message.receiver, rpcMessage)
+ promise.future.onFailure {
+ case _: TimeoutException => rpcMessage.onTimeout()
+ case _ =>
+ }(ThreadUtils.sameThread)
+ }
+
+ val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
+ override def run(): Unit = {
+ onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " +
+ s"in ${timeout.duration}"))
+ }
+ }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
+ promise.future.onComplete { v =>
+ timeoutCancelable.cancel(true)
}(ThreadUtils.sameThread)
- } else {
- postToOutbox(message.receiver, OutboxMessage(serialize(message),
- (e) => {
- if (!promise.tryFailure(e)) {
- logWarning("Ignore Exception", e)
- }
- },
- (client, response) => {
- val reply = deserialize[AskResponse](client, response)
- if (reply.reply.isInstanceOf[RpcFailure]) {
- if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
- logWarning(s"Ignore failure: ${reply.reply}")
- }
- } else if (!promise.trySuccess(reply.reply)) {
- logWarning(s"Ignore message: ${reply}")
- }
- }))
+ } catch {
+ case NonFatal(e) =>
+ onFailure(e)
}
- promise.future
+ promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
+ }
+
+ private[netty] def serialize(content: Any): ByteBuffer = {
+ javaSerializerInstance.serialize(content)
}
- private[netty] def serialize(content: Any): Array[Byte] = {
- val buffer = javaSerializerInstance.serialize(content)
- java.util.Arrays.copyOfRange(
- buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
+ /**
+ * Returns [[SerializationStream]] that forwards the serialized bytes to `out`.
+ */
+ private[netty] def serializeStream(out: OutputStream): SerializationStream = {
+ javaSerializerInstance.serializeStream(out)
}
- private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = {
+ private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
NettyRpcEnv.currentClient.withValue(client) {
deserialize { () =>
- javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+ javaSerializerInstance.deserialize[T](bytes)
}
}
}
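The ask() path above pairs every outstanding promise with a scheduled timeout task that is cancelled once the promise settles. In isolation, that pattern looks roughly like this (a sketch with a placeholder scheduler and duration, not code from this patch):

    import java.util.concurrent.{Executors, TimeUnit, TimeoutException}
    import scala.concurrent.{ExecutionContext, Promise}

    val scheduler = Executors.newSingleThreadScheduledExecutor()
    val promise = Promise[String]()
    val timeoutTask = scheduler.schedule(new Runnable {
      override def run(): Unit = promise.tryFailure(new TimeoutException("no reply in 30 seconds"))
    }, 30, TimeUnit.SECONDS)
    // Whichever side settles the promise first wins; the other call becomes a no-op.
    promise.future.onComplete(_ => timeoutTask.cancel(true))(ExecutionContext.global)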
@@ -257,9 +276,6 @@ private[netty] class NettyRpcEnv(
dispatcher.getRpcEndpointRef(endpoint)
}
- override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
- new RpcEndpointAddress(address, endpointName).toString
-
override def shutdown(): Unit = {
cleanup()
}
@@ -282,18 +298,21 @@ private[netty] class NettyRpcEnv(
if (timeoutScheduler != null) {
timeoutScheduler.shutdownNow()
}
+ if (dispatcher != null) {
+ dispatcher.stop()
+ }
if (server != null) {
server.close()
}
if (clientFactory != null) {
clientFactory.close()
}
- if (dispatcher != null) {
- dispatcher.stop()
- }
if (clientConnectionExecutor != null) {
clientConnectionExecutor.shutdownNow()
}
+ if (fileDownloadFactory != null) {
+ fileDownloadFactory.close()
+ }
}
override def deserialize[T](deserializationAction: () => T): T = {
@@ -302,10 +321,113 @@ private[netty] class NettyRpcEnv(
}
}
+ override def fileServer: RpcEnvFileServer = streamManager
+
+ override def openChannel(uri: String): ReadableByteChannel = {
+ val parsedUri = new URI(uri)
+ require(parsedUri.getHost() != null, "Host name must be defined.")
+ require(parsedUri.getPort() > 0, "Port must be defined.")
+ require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.")
+
+ val pipe = Pipe.open()
+ val source = new FileDownloadChannel(pipe.source())
+ Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
+ val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
+ val callback = new FileDownloadCallback(pipe.sink(), source, client)
+ client.stream(parsedUri.getPath(), callback)
+ })(catchBlock = {
+ pipe.sink().close()
+ source.close()
+ })
+
+ source
+ }
+
+ private def downloadClient(host: String, port: Int): TransportClient = {
+ if (fileDownloadFactory == null) synchronized {
+ if (fileDownloadFactory == null) {
+ val module = "files"
+ val prefix = "spark.rpc.io."
+ val clone = conf.clone()
+
+ // Copy any RPC configuration that is not overridden in the spark.files namespace.
+ conf.getAll.foreach { case (key, value) =>
+ if (key.startsWith(prefix)) {
+ val opt = key.substring(prefix.length())
+ clone.setIfMissing(s"spark.$module.io.$opt", value)
+ }
+ }
+
+ val ioThreads = clone.getInt("spark.files.io.threads", 1)
+ val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads)
+ val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true)
+ fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps())
+ }
+ }
+ fileDownloadFactory.createClient(host, port)
+ }
+
+ private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel {
+
+ @volatile private var error: Throwable = _
+
+ def setError(e: Throwable): Unit = {
+ // This setError callback is invoked by internal RPC threads in order to propagate remote
+ // exceptions to application-level threads which are reading from this channel. When an
+ // RPC error occurs, the RPC system will call setError() and then will close the
+ // Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing of the pipe
+ // sink will cause `source.read()` operations to return EOF, unblocking the application-level
+ // reading thread. Thus there is no need to actually call `source.close()` here in the
+ // onError() callback and, in fact, calling it here would be dangerous because the close()
+    // onFailure() callback and, in fact, calling it here would be dangerous because the close()
+    // would be asynchronous with respect to the read() call and could trigger race conditions
+ error = e
+ }
+
+ override def read(dst: ByteBuffer): Int = {
+ Try(source.read(dst)) match {
+ // See the documentation above in setError(): if an RPC error has occurred then setError()
+ // will be called to propagate the RPC error and then `source`'s corresponding
+ // Pipe.SinkChannel will be closed, unblocking this read. In that case, we want to propagate
+ // the remote RPC exception (and not any exceptions triggered by the pipe close, such as
+      // ClosedChannelException), hence this `error != null` check:
+ case _ if error != null => throw error
+ case Success(bytesRead) => bytesRead
+ case Failure(readErr) => throw readErr
+ }
+ }
+
+ override def close(): Unit = source.close()
+
+ override def isOpen(): Boolean = source.isOpen()
+
+ }
+
+ private class FileDownloadCallback(
+ sink: WritableByteChannel,
+ source: FileDownloadChannel,
+ client: TransportClient) extends StreamCallback {
+
+ override def onData(streamId: String, buf: ByteBuffer): Unit = {
+ while (buf.remaining() > 0) {
+ sink.write(buf)
+ }
+ }
+
+ override def onComplete(streamId: String): Unit = {
+ sink.close()
+ }
+
+ override def onFailure(streamId: String, cause: Throwable): Unit = {
+ logDebug(s"Error downloading stream $streamId.", cause)
+ source.setError(cause)
+ sink.close()
+ }
+
+ }
}
private[netty] object NettyRpcEnv extends Logging {
-
/**
* When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
* Use `currentEnv` to wrap the deserialization codes. E.g.,
@@ -326,7 +448,7 @@ private[netty] object NettyRpcEnv extends Logging {
}
-private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
+private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
def create(config: RpcEnvConfig): RpcEnv = {
val sparkConf = config.conf
@@ -335,14 +457,15 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
val javaSerializerInstance =
new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
val nettyEnv =
- new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
+ new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
+ config.securityManager)
if (!config.clientMode) {
val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
- nettyEnv.startServer(actualPort)
- (nettyEnv, actualPort)
+ nettyEnv.startServer(config.bindAddress, actualPort)
+ (nettyEnv, nettyEnv.address.port)
}
try {
- Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
+ Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
} catch {
case NonFatal(e) =>
nettyEnv.shutdown()
@@ -372,20 +495,16 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
* @param conf Spark configuration.
* @param endpointAddress The address where the endpoint is listening.
* @param nettyEnv The RpcEnv associated with this ref.
- * @param local Whether the referenced endpoint lives in the same process.
*/
private[netty] class NettyRpcEndpointRef(
@transient private val conf: SparkConf,
- endpointAddress: RpcEndpointAddress,
- @transient @volatile private var nettyEnv: NettyRpcEnv)
- extends RpcEndpointRef(conf) with Serializable with Logging {
+ private val endpointAddress: RpcEndpointAddress,
+ @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {
@transient @volatile var client: TransportClient = _
- private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
- private val _name = endpointAddress.name
-
- override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
+ override def address: RpcAddress =
+ if (endpointAddress.rpcAddress != null) endpointAddress.rpcAddress else null
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
@@ -397,64 +516,103 @@ private[netty] class NettyRpcEndpointRef(
out.defaultWriteObject()
}
- override def name: String = _name
+ override def name: String = endpointAddress.name
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
- val promise = Promise[Any]()
- val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable {
- override def run(): Unit = {
- promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration))
- }
- }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
- val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true))
- f.onComplete { v =>
- timeoutCancelable.cancel(true)
- if (!promise.tryComplete(v)) {
- logWarning(s"Ignore message $v")
- }
- }(ThreadUtils.sameThread)
- promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
+ nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
}
override def send(message: Any): Unit = {
require(message != null, "Message is null")
- nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false))
+ nettyEnv.send(new RequestMessage(nettyEnv.address, this, message))
}
- override def toString: String = s"NettyRpcEndpointRef(${_address})"
-
- def toURI: URI = new URI(s"spark://${_address}")
+ override def toString: String = s"NettyRpcEndpointRef(${endpointAddress})"
final override def equals(that: Any): Boolean = that match {
- case other: NettyRpcEndpointRef => _address == other._address
+ case other: NettyRpcEndpointRef => endpointAddress == other.endpointAddress
case _ => false
}
- final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode()
+ final override def hashCode(): Int =
+ if (endpointAddress == null) 0 else endpointAddress.hashCode()
}
/**
* The message that is sent from the sender to the receiver.
+ *
+ * @param senderAddress the sender address. It's `null` if this message is from a client
+ * `NettyRpcEnv`.
+ * @param receiver the receiver of this message.
+ * @param content the message content.
*/
-private[netty] case class RequestMessage(
- senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean)
+private[netty] class RequestMessage(
+ val senderAddress: RpcAddress,
+ val receiver: NettyRpcEndpointRef,
+ val content: Any) {
+
+ /** Manually serialize [[RequestMessage]] to minimize the size. */
+ def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = {
+ val bos = new ByteBufferOutputStream()
+ val out = new DataOutputStream(bos)
+ try {
+ writeRpcAddress(out, senderAddress)
+ writeRpcAddress(out, receiver.address)
+ out.writeUTF(receiver.name)
+ val s = nettyEnv.serializeStream(out)
+ try {
+ s.writeObject(content)
+ } finally {
+ s.close()
+ }
+ } finally {
+ out.close()
+ }
+ bos.toByteBuffer
+ }
-/**
- * The base trait for all messages that are sent back from the receiver to the sender.
- */
-private[netty] trait ResponseMessage
+ private def writeRpcAddress(out: DataOutputStream, rpcAddress: RpcAddress): Unit = {
+ if (rpcAddress == null) {
+ out.writeBoolean(false)
+ } else {
+ out.writeBoolean(true)
+ out.writeUTF(rpcAddress.host)
+ out.writeInt(rpcAddress.port)
+ }
+ }
-/**
- * The reply for `ask` from the receiver side.
- */
-private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any)
- extends ResponseMessage
+ override def toString: String = s"RequestMessage($senderAddress, $receiver, $content)"
+}
-/**
- * A message to send back to the receiver side. It's necessary because [[TransportClient]] only
- * clean the resources when it receives a reply.
- */
-private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage
+private[netty] object RequestMessage {
+
+ private def readRpcAddress(in: DataInputStream): RpcAddress = {
+ val hasRpcAddress = in.readBoolean()
+ if (hasRpcAddress) {
+ RpcAddress(in.readUTF(), in.readInt())
+ } else {
+ null
+ }
+ }
+
+ def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = {
+ val bis = new ByteBufferInputStream(bytes)
+ val in = new DataInputStream(bis)
+ try {
+ val senderAddress = readRpcAddress(in)
+ val endpointAddress = RpcEndpointAddress(readRpcAddress(in), in.readUTF())
+ val ref = new NettyRpcEndpointRef(nettyEnv.conf, endpointAddress, nettyEnv)
+ ref.client = client
+ new RequestMessage(
+ senderAddress,
+ ref,
+ // The remaining bytes in `bytes` are the message content.
+ nettyEnv.deserialize(client, bytes))
+ } finally {
+ in.close()
+ }
+ }
+}
/**
* A response that indicates some failure happens in the receiver side.
@@ -474,40 +632,60 @@ private[netty] case class RpcFailure(e: Throwable)
* with different `RpcAddress` information).
*/
private[netty] class NettyRpcHandler(
- dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
+ dispatcher: Dispatcher,
+ nettyEnv: NettyRpcEnv,
+ streamManager: StreamManager) extends RpcHandler with Logging {
- // TODO: Can we add connection callback (channel registered) to the underlying framework?
- // A variable to track whether we should dispatch the RemoteProcessConnected message.
- private val clients = new ConcurrentHashMap[TransportClient, JBoolean]()
+ // A variable to track the remote RpcEnv addresses of all clients
+ private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()
override def receive(
client: TransportClient,
- message: Array[Byte],
+ message: ByteBuffer,
callback: RpcResponseCallback): Unit = {
+ val messageToDispatch = internalReceive(client, message)
+ dispatcher.postRemoteMessage(messageToDispatch, callback)
+ }
+
+ override def receive(
+ client: TransportClient,
+ message: ByteBuffer): Unit = {
+ val messageToDispatch = internalReceive(client, message)
+ dispatcher.postOneWayMessage(messageToDispatch)
+ }
+
+ private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = {
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
- val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- if (clients.putIfAbsent(client, JBoolean.TRUE) == null) {
- dispatcher.postToAll(RemoteProcessConnected(clientAddr))
- }
- val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
- val messageToDispatch = if (requestMessage.senderAddress == null) {
- // Create a new message with the socket address of the client as the sender.
- RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content,
- requestMessage.needReply)
- } else {
- requestMessage
+ val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
+ val requestMessage = RequestMessage(nettyEnv, client, message)
+ if (requestMessage.senderAddress == null) {
+ // Create a new message with the socket address of the client as the sender.
+ new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
+ } else {
+      // The remote RpcEnv listens on some port, so we should also fire a
+      // RemoteProcessConnected for its listening address.
+ val remoteEnvAddress = requestMessage.senderAddress
+ if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) {
+ dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
}
- dispatcher.postRemoteMessage(messageToDispatch, callback)
+ requestMessage
+ }
}
- override def getStreamManager: StreamManager = new OneForOneStreamManager
+ override def getStreamManager: StreamManager = streamManager
override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
- val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+ val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr))
+      // If the remote RpcEnv listens on some address, we should also fire a
+      // RemoteProcessConnectionError for the remote RpcEnv's listening address.
+ val remoteEnvAddress = remoteAddresses.get(clientAddr)
+ if (remoteEnvAddress != null) {
+ dispatcher.postToAll(RemoteProcessConnectionError(cause, remoteEnvAddress))
+ }
} else {
// If the channel is closed before connecting, its remoteAddress will be null.
// See java.net.Socket.getRemoteSocketAddress
@@ -516,13 +694,25 @@ private[netty] class NettyRpcHandler(
}
}
- override def connectionTerminated(client: TransportClient): Unit = {
+ override def channelActive(client: TransportClient): Unit = {
+ val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+ assert(addr != null)
+ val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
+ dispatcher.postToAll(RemoteProcessConnected(clientAddr))
+ }
+
+ override def channelInactive(client: TransportClient): Unit = {
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
- val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
- clients.remove(client)
+ val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
nettyEnv.removeOutbox(clientAddr)
dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
+ val remoteEnvAddress = remoteAddresses.remove(clientAddr)
+      // If the remote RpcEnv listens on some address, we should also fire a
+      // RemoteProcessDisconnected for the remote RpcEnv's listening address.
+ if (remoteEnvAddress != null) {
+ dispatcher.postToAll(RemoteProcessDisconnected(remoteEnvAddress))
+ }
} else {
// If the channel is closed before connecting, its remoteAddress will be null. In this case,
// we can ignore it since we don't fire "Associated".
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
new file mode 100644
index 000000000000..780fadd5bda8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc.netty
+
+import java.io.File
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.server.StreamManager
+import org.apache.spark.rpc.RpcEnvFileServer
+import org.apache.spark.util.Utils
+
+/**
+ * StreamManager implementation for serving files from a NettyRpcEnv.
+ *
+ * Three kinds of resources can be registered in this manager, all backed by actual files:
+ *
+ * - "/files": a flat list of files; used as the backend for [[SparkContext.addFile]].
+ * - "/jars": a flat list of files; used as the backend for [[SparkContext.addJar]].
+ * - arbitrary directories; all files under the directory become available through the manager,
+ * respecting the directory's hierarchy.
+ *
+ * Only streaming (openStream) is supported.
+ */
+private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
+ extends StreamManager with RpcEnvFileServer {
+
+ private val files = new ConcurrentHashMap[String, File]()
+ private val jars = new ConcurrentHashMap[String, File]()
+ private val dirs = new ConcurrentHashMap[String, File]()
+
+ override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = {
+ throw new UnsupportedOperationException()
+ }
+
+ override def openStream(streamId: String): ManagedBuffer = {
+ val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
+ val file = ftype match {
+ case "files" => files.get(fname)
+ case "jars" => jars.get(fname)
+ case other =>
+ val dir = dirs.get(ftype)
+ require(dir != null, s"Invalid stream URI: $ftype not found.")
+ new File(dir, fname)
+ }
+
+ if (file != null && file.isFile()) {
+ new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length())
+ } else {
+ null
+ }
+ }
+
+ override def addFile(file: File): String = {
+ val existingPath = files.putIfAbsent(file.getName, file)
+ require(existingPath == null || existingPath == file,
+ s"File ${file.getName} was already registered with a different path " +
+      s"(old path = $existingPath, new path = $file)")
+ s"${rpcEnv.address.toSparkURL}/files/${Utils.encodeFileNameToURIRawPath(file.getName())}"
+ }
+
+ override def addJar(file: File): String = {
+ val existingPath = jars.putIfAbsent(file.getName, file)
+ require(existingPath == null || existingPath == file,
+ s"File ${file.getName} was already registered with a different path " +
+      s"(old path = $existingPath, new path = $file)")
+ s"${rpcEnv.address.toSparkURL}/jars/${Utils.encodeFileNameToURIRawPath(file.getName())}"
+ }
+
+ override def addDirectory(baseUri: String, path: File): String = {
+ val fixedBaseUri = validateDirectoryUri(baseUri)
+ require(dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path) == null,
+ s"URI '$fixedBaseUri' already registered.")
+ s"${rpcEnv.address.toSparkURL}$fixedBaseUri"
+ }
+
+}
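Putting the file server and the download channel together, a hedged end-to-end sketch (assumes a started `NettyRpcEnv` named `env` and a local `java.io.File` named `jar`; not code from this patch):

    // Register the jar; the returned URI has the spark://host:port/jars/<encoded-name> form.
    val uri: String = env.fileServer.addJar(jar)

    // A consumer resolves the URI back into a byte stream over the RPC transport.
    val channel = env.openChannel(uri)
    val buf = java.nio.ByteBuffer.allocate(8192)
    while (channel.read(buf) != -1) {
      buf.flip()
      // ... consume buf here ...
      buf.clear()
    }
    channel.close()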
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
index 2f6817f2eb93..b7e068aa6835 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -17,29 +17,71 @@
package org.apache.spark.rpc.netty
+import java.nio.ByteBuffer
import java.util.concurrent.Callable
import javax.annotation.concurrent.GuardedBy
import scala.util.control.NonFatal
import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
-import org.apache.spark.rpc.RpcAddress
+import org.apache.spark.rpc.{RpcAddress, RpcEnvStoppedException}
-private[netty] case class OutboxMessage(content: Array[Byte],
- _onFailure: (Throwable) => Unit,
- _onSuccess: (TransportClient, Array[Byte]) => Unit) {
+private[netty] sealed trait OutboxMessage {
- def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() {
- override def onFailure(e: Throwable): Unit = {
- _onFailure(e)
+ def sendWith(client: TransportClient): Unit
+
+ def onFailure(e: Throwable): Unit
+
+}
+
+private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage
+ with Logging {
+
+ override def sendWith(client: TransportClient): Unit = {
+ client.send(content)
+ }
+
+ override def onFailure(e: Throwable): Unit = {
+ e match {
+ case e1: RpcEnvStoppedException => logDebug(e1.getMessage)
+ case e1: Throwable => logWarning(s"Failed to send one-way RPC.", e1)
}
+ }
+
+}
+
+private[netty] case class RpcOutboxMessage(
+ content: ByteBuffer,
+ _onFailure: (Throwable) => Unit,
+ _onSuccess: (TransportClient, ByteBuffer) => Unit)
+ extends OutboxMessage with RpcResponseCallback with Logging {
- override def onSuccess(response: Array[Byte]): Unit = {
- _onSuccess(client, response)
+ private var client: TransportClient = _
+ private var requestId: Long = _
+
+ override def sendWith(client: TransportClient): Unit = {
+ this.client = client
+ this.requestId = client.sendRpc(content, this)
+ }
+
+ def onTimeout(): Unit = {
+ if (client != null) {
+ client.removeRpcRequest(requestId)
+ } else {
+ logError("Ask timeout before connecting successfully")
}
}
+ override def onFailure(e: Throwable): Unit = {
+ _onFailure(e)
+ }
+
+ override def onSuccess(response: ByteBuffer): Unit = {
+ _onSuccess(client, response)
+ }
+
}
private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
@@ -82,7 +124,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
}
}
if (dropped) {
- message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
} else {
drainOutbox()
}
@@ -122,7 +164,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
try {
val _client = synchronized { client }
if (_client != null) {
- _client.sendRpc(message.content, message.createCallback(_client))
+ message.sendWith(_client)
} else {
assert(stopped == true)
}
@@ -195,17 +237,14 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
// update messages and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
- message._onFailure(e)
+ message.onFailure(e)
message = messages.poll()
}
assert(messages.isEmpty)
}
private def closeClient(): Unit = synchronized {
- // Not sure if `client.close` is idempotent. Just for safety.
- if (client != null) {
- client.close()
- }
+ // Just set client to null. Don't close it in order to reuse the connection.
client = null
}
@@ -229,7 +268,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
// update messages and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
- message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+ message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
message = messages.poll()
}
}
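
The hunk above replaces the old callback-factory style with an OutboxMessage trait whose implementations know how to send themselves. As a rough illustration of the two shapes introduced here, the sketch below is not part of the patch; it assumes code living in org.apache.spark.rpc.netty with an already-connected TransportClient in hand, and the payload strings are invented.

// Hedged usage sketch only -- `client` and the payloads are illustrative.
import java.nio.ByteBuffer

import org.apache.spark.network.client.TransportClient

def sendBoth(client: TransportClient): Unit = {
  // Fire-and-forget: only the serialized payload is needed; failures are just logged.
  OneWayOutboxMessage(ByteBuffer.wrap("ping".getBytes("UTF-8"))).sendWith(client)

  // Request/response: the message itself is the RpcResponseCallback, so sendWith
  // registers it with the client and records the requestId for a later onTimeout().
  val ask = RpcOutboxMessage(
    ByteBuffer.wrap("ask".getBytes("UTF-8")),
    (e: Throwable) => println(s"ask failed: $e"),
    (_: TransportClient, response: ByteBuffer) => println(s"got ${response.remaining()} bytes"))
  ask.sendWith(client)
}
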
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala
index 99f20da2d66a..430dcc50ba71 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala
@@ -20,7 +20,7 @@ package org.apache.spark.rpc.netty
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
/**
- * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists.
+ * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an `RpcEndpoint` exists.
*
* This is used when setting up a remote endpoint reference.
*/
@@ -35,6 +35,6 @@ private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher
private[netty] object RpcEndpointVerifier {
val NAME = "endpoint-verifier"
- /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */
+ /** A message used to ask the remote [[RpcEndpointVerifier]] if an `RpcEndpoint` exists. */
case class CheckExistence(name: String)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
index 146cfb9ba803..0a5fe5a1d3ee 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
@@ -19,47 +19,61 @@ package org.apache.spark.scheduler
import org.apache.spark.annotation.DeveloperApi
+
/**
* :: DeveloperApi ::
* Information about an [[org.apache.spark.Accumulable]] modified during a task or stage.
+ *
+ * @param id accumulator ID
+ * @param name accumulator name
+ * @param update partial value from a task, may be None if used on driver to describe a stage
+ * @param value total accumulated value so far, may be None if used on executors to describe a task
+ * @param internal whether this accumulator was internal
+ * @param countFailedValues whether to count this accumulator's partial value if the task failed
+ * @param metadata internal metadata associated with this accumulator, if any
+ *
+ * @note Once this is JSON serialized the types of `update` and `value` will be lost and be
+ * cast to strings. This is because the user can define an accumulator of any type and it will
+ * be difficult to preserve the type in consumers of the event log. This does not apply to
+ * internal accumulators that represent task level metrics.
*/
@DeveloperApi
-class AccumulableInfo private[spark] (
- val id: Long,
- val name: String,
- val update: Option[String], // represents a partial update within a task
- val value: String,
- val internal: Boolean) {
-
- override def equals(other: Any): Boolean = other match {
- case acc: AccumulableInfo =>
- this.id == acc.id && this.name == acc.name &&
- this.update == acc.update && this.value == acc.value &&
- this.internal == acc.internal
- case _ => false
- }
+case class AccumulableInfo private[spark] (
+ id: Long,
+ name: Option[String],
+ update: Option[Any], // represents a partial update within a task
+ value: Option[Any],
+ private[spark] val internal: Boolean,
+ private[spark] val countFailedValues: Boolean,
+ // TODO: use this to identify internal task metrics instead of encoding it in the name
+ private[spark] val metadata: Option[String] = None)
- override def hashCode(): Int = {
- val state = Seq(id, name, update, value, internal)
- state.map(_.hashCode).reduceLeft(31 * _ + _)
- }
-}
+/**
+ * A collection of deprecated constructors. This will be removed soon.
+ */
object AccumulableInfo {
+
+ @deprecated("do not create AccumulableInfo", "2.0.0")
def apply(
id: Long,
name: String,
update: Option[String],
value: String,
internal: Boolean): AccumulableInfo = {
- new AccumulableInfo(id, name, update, value, internal)
+ new AccumulableInfo(
+ id, Option(name), update, Option(value), internal, countFailedValues = false)
}
+ @deprecated("do not create AccumulableInfo", "2.0.0")
def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = {
- new AccumulableInfo(id, name, update, value, internal = false)
+ new AccumulableInfo(
+ id, Option(name), update, Option(value), internal = false, countFailedValues = false)
}
+ @deprecated("do not create AccumulableInfo", "2.0.0")
def apply(id: Long, name: String, value: String): AccumulableInfo = {
- new AccumulableInfo(id, name, None, value, internal = false)
+ new AccumulableInfo(
+ id, Option(name), None, Option(value), internal = false, countFailedValues = false)
}
}
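
Since AccumulableInfo is now a case class with Option-typed name, update, and value, internal call sites construct it directly instead of going through the deprecated apply helpers. A small illustrative sketch follows; the literal values are invented, and it assumes code compiled inside the org.apache.spark package because the constructor is private[spark].

// Illustration only -- equality and hashCode now come from the case class,
// replacing the hand-written implementations removed in this hunk.
val info = new AccumulableInfo(
  id = 1L,
  name = Some("internal.metrics.executorRunTime"),
  update = Some(42L),    // partial value reported by a single task
  value = Some(4200L),   // running total held on the driver
  internal = true,
  countFailedValues = true)

val sameInfo = new AccumulableInfo(
  1L, Some("internal.metrics.executorRunTime"), Some(42L), Some(4200L), true, true)
assert(info == sameInfo)
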
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
index a3d2db31301b..949e88f60627 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
@@ -19,7 +19,6 @@ package org.apache.spark.scheduler
import java.util.Properties
-import org.apache.spark.TaskContext
import org.apache.spark.util.CallSite
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala
index 9f218c64cac2..28c45d800ed0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala
@@ -32,6 +32,8 @@ private[spark] class ApplicationEventListener extends SparkListener {
var endTime: Option[Long] = None
var viewAcls: Option[String] = None
var adminAcls: Option[String] = None
+ var viewAclsGroups: Option[String] = None
+ var adminAclsGroups: Option[String] = None
override def onApplicationStart(applicationStart: SparkListenerApplicationStart) {
appName = Some(applicationStart.appName)
@@ -51,6 +53,8 @@ private[spark] class ApplicationEventListener extends SparkListener {
val allProperties = environmentDetails("Spark Properties").toMap
viewAcls = allProperties.get("spark.ui.view.acls")
adminAcls = allProperties.get("spark.admin.acls")
+ viewAclsGroups = allProperties.get("spark.ui.view.acls.groups")
+ adminAclsGroups = allProperties.get("spark.admin.acls.groups")
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala
new file mode 100644
index 000000000000..e130e609e4f6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala
@@ -0,0 +1,419 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.concurrent.atomic.AtomicReference
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import org.apache.spark.{ExecutorAllocationClient, SparkConf, SparkContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config
+import org.apache.spark.util.{Clock, SystemClock, Utils}
+
+/**
+ * BlacklistTracker is designed to track problematic executors and nodes. It supports blacklisting
+ * executors and nodes across an entire application (with a periodic expiry). TaskSetManagers add
+ * additional blacklisting of executors and nodes for individual tasks and stages which works in
+ * concert with the blacklisting here.
+ *
+ * The tracker needs to deal with a variety of workloads, e.g.:
+ *
+ * * bad user code -- this may lead to many task failures, but that should not count against
+ * individual executors
+ * * many small stages -- this may keep a bad executor from accumulating many failures within
+ * any single stage, even though it accumulates many failures over the entire application
+ * * "flaky" executors -- they don't fail every task, but are still faulty enough to merit
+ * blacklisting
+ *
+ * See the design doc on SPARK-8425 for a more in-depth discussion.
+ *
+ * THREADING: As with most helpers of TaskSchedulerImpl, this is not thread-safe. Though it is
+ * called by multiple threads, callers must already have a lock on the TaskSchedulerImpl. The
+ * one exception is [[nodeBlacklist()]], which can be called without holding a lock.
+ */
+private[scheduler] class BlacklistTracker (
+ private val listenerBus: LiveListenerBus,
+ conf: SparkConf,
+ allocationClient: Option[ExecutorAllocationClient],
+ clock: Clock = new SystemClock()) extends Logging {
+
+ def this(sc: SparkContext, allocationClient: Option[ExecutorAllocationClient]) = {
+ this(sc.listenerBus, sc.conf, allocationClient)
+ }
+
+ BlacklistTracker.validateBlacklistConfs(conf)
+ private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC)
+ private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE)
+ val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf)
+
+ /**
+ * A map from executorId to information on task failures. Tracks the time of each task failure,
+ * so that we can avoid blacklisting executors due to failures that are very far apart. We do not
+ * actively remove from this as soon as tasks hit their timeouts, to avoid the time it would take
+ * to do so. But it will not grow too large, because as soon as an executor gets too many
+ * failures, we blacklist the executor and remove its entry here.
+ */
+ private val executorIdToFailureList = new HashMap[String, ExecutorFailureList]()
+ val executorIdToBlacklistStatus = new HashMap[String, BlacklistedExecutor]()
+ val nodeIdToBlacklistExpiryTime = new HashMap[String, Long]()
+ /**
+ * An immutable copy of the set of nodes that are currently blacklisted. Kept in an
+ * AtomicReference to make [[nodeBlacklist()]] thread-safe.
+ */
+ private val _nodeBlacklist = new AtomicReference[Set[String]](Set())
+ /**
+ * Time when the next blacklist will expire. Used as a
+ * shortcut to avoid iterating over all entries in the blacklist when none will have expired.
+ */
+ var nextExpiryTime: Long = Long.MaxValue
+ /**
+ * Mapping from nodes to all of the executors that have been blacklisted on that node. We do *not*
+ * remove from this when executors are removed from Spark, so we can track when we get multiple
+ * successive blacklisted executors on one node. Nonetheless, it will not grow too large because
+ * there cannot be many blacklisted executors on one node before we stop requesting more
+ * executors on that node, and we clean up the list of blacklisted executors once an executor has
+ * been blacklisted for BLACKLIST_TIMEOUT_MILLIS.
+ */
+ val nodeToBlacklistedExecs = new HashMap[String, HashSet[String]]()
+
+ /**
+ * Un-blacklists executors and nodes that have been blacklisted for at least
+ * BLACKLIST_TIMEOUT_MILLIS
+ */
+ def applyBlacklistTimeout(): Unit = {
+ val now = clock.getTimeMillis()
+ // quickly check if we've got anything to expire from blacklist -- if not, avoid doing any work
+ if (now > nextExpiryTime) {
+ // Apply the timeout to blacklisted nodes and executors
+ val execsToUnblacklist = executorIdToBlacklistStatus.filter(_._2.expiryTime < now).keys
+ if (execsToUnblacklist.nonEmpty) {
+ // Un-blacklist any executors that have been blacklisted longer than the blacklist timeout.
+ logInfo(s"Removing executors $execsToUnblacklist from blacklist because the blacklist " +
+ s"for those executors has timed out")
+ execsToUnblacklist.foreach { exec =>
+ val status = executorIdToBlacklistStatus.remove(exec).get
+ val failedExecsOnNode = nodeToBlacklistedExecs(status.node)
+ listenerBus.post(SparkListenerExecutorUnblacklisted(now, exec))
+ failedExecsOnNode.remove(exec)
+ if (failedExecsOnNode.isEmpty) {
+ nodeToBlacklistedExecs.remove(status.node)
+ }
+ }
+ }
+ val nodesToUnblacklist = nodeIdToBlacklistExpiryTime.filter(_._2 < now).keys
+ if (nodesToUnblacklist.nonEmpty) {
+ // Un-blacklist any nodes that have been blacklisted longer than the blacklist timeout.
+ logInfo(s"Removing nodes $nodesToUnblacklist from blacklist because the blacklist " +
+ s"has timed out")
+ nodesToUnblacklist.foreach { node =>
+ nodeIdToBlacklistExpiryTime.remove(node)
+ listenerBus.post(SparkListenerNodeUnblacklisted(now, node))
+ }
+ _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet)
+ }
+ updateNextExpiryTime()
+ }
+ }
+
+ private def updateNextExpiryTime(): Unit = {
+ val execMinExpiry = if (executorIdToBlacklistStatus.nonEmpty) {
+ executorIdToBlacklistStatus.map{_._2.expiryTime}.min
+ } else {
+ Long.MaxValue
+ }
+ val nodeMinExpiry = if (nodeIdToBlacklistExpiryTime.nonEmpty) {
+ nodeIdToBlacklistExpiryTime.values.min
+ } else {
+ Long.MaxValue
+ }
+ nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry)
+ }
+
+
+ def updateBlacklistForSuccessfulTaskSet(
+ stageId: Int,
+ stageAttemptId: Int,
+ failuresByExec: HashMap[String, ExecutorFailuresInTaskSet]): Unit = {
+ // if any tasks failed, we count them towards the overall failure count for the executor at
+ // this point.
+ val now = clock.getTimeMillis()
+ failuresByExec.foreach { case (exec, failuresInTaskSet) =>
+ val appFailuresOnExecutor =
+ executorIdToFailureList.getOrElseUpdate(exec, new ExecutorFailureList)
+ appFailuresOnExecutor.addFailures(stageId, stageAttemptId, failuresInTaskSet)
+ appFailuresOnExecutor.dropFailuresWithTimeoutBefore(now)
+ val newTotal = appFailuresOnExecutor.numUniqueTaskFailures
+
+ val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS
+ // If this pushes the total number of failures over the threshold, blacklist the executor.
+ // If it's already blacklisted, we avoid "re-blacklisting" (which can happen if there were
+ // other tasks already running in another taskset when it got blacklisted), because it makes
+ // some of the logic around expiry times a little more confusing. But it also wouldn't be a
+ // problem to re-blacklist, with a later expiry time.
+ if (newTotal >= MAX_FAILURES_PER_EXEC && !executorIdToBlacklistStatus.contains(exec)) {
+ logInfo(s"Blacklisting executor id: $exec because it has $newTotal" +
+ s" task failures in successful task sets")
+ val node = failuresInTaskSet.node
+ executorIdToBlacklistStatus.put(exec, BlacklistedExecutor(node, expiryTimeForNewBlacklists))
+ listenerBus.post(SparkListenerExecutorBlacklisted(now, exec, newTotal))
+ executorIdToFailureList.remove(exec)
+ updateNextExpiryTime()
+ if (conf.get(config.BLACKLIST_KILL_ENABLED)) {
+ allocationClient match {
+ case Some(allocationClient) =>
+ logInfo(s"Killing blacklisted executor id $exec " +
+ s"since spark.blacklist.killBlacklistedExecutors is set.")
+ allocationClient.killExecutors(Seq(exec), true, true)
+ case None =>
+ logWarning(s"Not attempting to kill blacklisted executor id $exec " +
+ s"since allocation client is not defined.")
+ }
+ }
+
+ // In addition to blacklisting the executor, we also update the data for failures on the
+ // node, and potentially put the entire node into a blacklist as well.
+ val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(node, HashSet[String]())
+ blacklistedExecsOnNode += exec
+ // If the node is already in the blacklist, we avoid adding it again with a later expiry
+ // time.
+ if (blacklistedExecsOnNode.size >= MAX_FAILED_EXEC_PER_NODE &&
+ !nodeIdToBlacklistExpiryTime.contains(node)) {
+ logInfo(s"Blacklisting node $node because it has ${blacklistedExecsOnNode.size} " +
+ s"executors blacklisted: ${blacklistedExecsOnNode}")
+ nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists)
+ listenerBus.post(SparkListenerNodeBlacklisted(now, node, blacklistedExecsOnNode.size))
+ _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet)
+ if (conf.get(config.BLACKLIST_KILL_ENABLED)) {
+ allocationClient match {
+ case Some(allocationClient) =>
+ logInfo(s"Killing all executors on blacklisted host $node " +
+ s"since spark.blacklist.killBlacklistedExecutors is set.")
+ if (!allocationClient.killExecutorsOnHost(node)) {
+ logError(s"Killing executors on node $node failed.")
+ }
+ case None =>
+ logWarning(s"Not attempting to kill executors on blacklisted host $node " +
+ s"since allocation client is not defined.")
+ }
+ }
+ }
+ }
+ }
+ }
+
+ def isExecutorBlacklisted(executorId: String): Boolean = {
+ executorIdToBlacklistStatus.contains(executorId)
+ }
+
+ /**
+ * Get the full set of nodes that are blacklisted. Unlike other methods in this class, this *IS*
+ * thread-safe -- no lock required on a taskScheduler.
+ */
+ def nodeBlacklist(): Set[String] = {
+ _nodeBlacklist.get()
+ }
+
+ def isNodeBlacklisted(node: String): Boolean = {
+ nodeIdToBlacklistExpiryTime.contains(node)
+ }
+
+ def handleRemovedExecutor(executorId: String): Unit = {
+ // We intentionally do not clean up executors that are already blacklisted in
+ // nodeToBlacklistedExecs, so that if another executor on the same node gets blacklisted, we can
+ // blacklist the entire node. We also don't clean up executorIdToBlacklistStatus, so that we can
+ // eventually remove the executor after the timeout. Despite not clearing those structures
+ // here, we don't expect they will grow too big since you won't get too many executors on one
+ // node, and the timeout will clear it up periodically in any case.
+ executorIdToFailureList -= executorId
+ }
+
+
+ /**
+ * Tracks all failures for one executor (that have not passed the timeout).
+ *
+ * In general we actually expect this to be extremely small, since it won't contain more than the
+ * maximum number of task failures before an executor is blacklisted (default 2).
+ */
+ private[scheduler] final class ExecutorFailureList extends Logging {
+
+ private case class TaskId(stage: Int, stageAttempt: Int, taskIndex: Int)
+
+ /**
+ * All failures on this executor in successful task sets.
+ */
+ private var failuresAndExpiryTimes = ArrayBuffer[(TaskId, Long)]()
+ /**
+ * As an optimization, we track the min expiry time over all entries in failuresAndExpiryTimes
+ * so it's quick to tell if there are any failures with expiry before the current time.
+ */
+ private var minExpiryTime = Long.MaxValue
+
+ def addFailures(
+ stage: Int,
+ stageAttempt: Int,
+ failuresInTaskSet: ExecutorFailuresInTaskSet): Unit = {
+ failuresInTaskSet.taskToFailureCountAndFailureTime.foreach {
+ case (taskIdx, (_, failureTime)) =>
+ val expiryTime = failureTime + BLACKLIST_TIMEOUT_MILLIS
+ failuresAndExpiryTimes += ((TaskId(stage, stageAttempt, taskIdx), expiryTime))
+ if (expiryTime < minExpiryTime) {
+ minExpiryTime = expiryTime
+ }
+ }
+ }
+
+ /**
+ * The number of unique tasks that failed on this executor. Only counts failures within the
+ * timeout, and in successful tasksets.
+ */
+ def numUniqueTaskFailures: Int = failuresAndExpiryTimes.size
+
+ def isEmpty: Boolean = failuresAndExpiryTimes.isEmpty
+
+ /**
+ * Apply the timeout to individual tasks. This is to prevent one-off failures that are very
+ * spread out in time (and likely have nothing to do with problems on the executor) from
+ * triggering blacklisting. However, note that we do *not* remove executors and nodes from
+ * the blacklist as we expire individual task failures -- each has its own timeout. E.g.,
+ * suppose:
+ * * timeout = 10, maxFailuresPerExec = 2
+ * * Task 1 fails on exec 1 at time 0
+ * * Task 2 fails on exec 1 at time 5
+ * --> exec 1 is blacklisted from time 5 - 15.
+ * This is to simplify the implementation, as well as keep the behavior easier to understand
+ * for the end user.
+ */
+ def dropFailuresWithTimeoutBefore(dropBefore: Long): Unit = {
+ if (minExpiryTime < dropBefore) {
+ var newMinExpiry = Long.MaxValue
+ val newFailures = new ArrayBuffer[(TaskId, Long)]
+ failuresAndExpiryTimes.foreach { case (task, expiryTime) =>
+ if (expiryTime >= dropBefore) {
+ newFailures += ((task, expiryTime))
+ if (expiryTime < newMinExpiry) {
+ newMinExpiry = expiryTime
+ }
+ }
+ }
+ failuresAndExpiryTimes = newFailures
+ minExpiryTime = newMinExpiry
+ }
+ }
+
+ override def toString(): String = {
+ s"failures = $failuresAndExpiryTimes"
+ }
+ }
+
+}
+
+private[scheduler] object BlacklistTracker extends Logging {
+
+ private val DEFAULT_TIMEOUT = "1h"
+
+ /**
+ * Returns true if the blacklist is enabled, based on checking the configuration in the following
+ * order:
+ * 1. Is it specifically enabled or disabled?
+ * 2. Is it enabled via the legacy timeout conf?
+ * 3. Default is off
+ */
+ def isBlacklistEnabled(conf: SparkConf): Boolean = {
+ conf.get(config.BLACKLIST_ENABLED) match {
+ case Some(enabled) =>
+ enabled
+ case None =>
+ // if they've got a non-zero setting for the legacy conf, always enable the blacklist,
+ // otherwise, use the default.
+ val legacyKey = config.BLACKLIST_LEGACY_TIMEOUT_CONF.key
+ conf.get(config.BLACKLIST_LEGACY_TIMEOUT_CONF).exists { legacyTimeout =>
+ if (legacyTimeout == 0) {
+ logWarning(s"Turning off blacklisting due to legacy configuration: $legacyKey == 0")
+ false
+ } else {
+ logWarning(s"Turning on blacklisting due to legacy configuration: $legacyKey > 0")
+ true
+ }
+ }
+ }
+ }
+
+ def getBlacklistTimeout(conf: SparkConf): Long = {
+ conf.get(config.BLACKLIST_TIMEOUT_CONF).getOrElse {
+ conf.get(config.BLACKLIST_LEGACY_TIMEOUT_CONF).getOrElse {
+ Utils.timeStringAsMs(DEFAULT_TIMEOUT)
+ }
+ }
+ }
+
+ /**
+ * Verify that blacklist configurations are consistent; if not, throw an exception. Should only
+ * be called if blacklisting is enabled.
+ *
+ * The configuration for the blacklist is expected to adhere to a few invariants. Default
+ * values follow these rules of course, but users may unwittingly change one configuration
+ * without making the corresponding adjustment elsewhere. This ensures we fail-fast when
+ * there are such misconfigurations.
+ */
+ def validateBlacklistConfs(conf: SparkConf): Unit = {
+
+ def mustBePos(k: String, v: String): Unit = {
+ throw new IllegalArgumentException(s"$k was $v, but must be > 0.")
+ }
+
+ Seq(
+ config.MAX_TASK_ATTEMPTS_PER_EXECUTOR,
+ config.MAX_TASK_ATTEMPTS_PER_NODE,
+ config.MAX_FAILURES_PER_EXEC_STAGE,
+ config.MAX_FAILED_EXEC_PER_NODE_STAGE,
+ config.MAX_FAILURES_PER_EXEC,
+ config.MAX_FAILED_EXEC_PER_NODE
+ ).foreach { config =>
+ val v = conf.get(config)
+ if (v <= 0) {
+ mustBePos(config.key, v.toString)
+ }
+ }
+
+ val timeout = getBlacklistTimeout(conf)
+ if (timeout <= 0) {
+ // first, figure out where the timeout came from, to include the right conf in the message.
+ conf.get(config.BLACKLIST_TIMEOUT_CONF) match {
+ case Some(t) =>
+ mustBePos(config.BLACKLIST_TIMEOUT_CONF.key, timeout.toString)
+ case None =>
+ mustBePos(config.BLACKLIST_LEGACY_TIMEOUT_CONF.key, timeout.toString)
+ }
+ }
+
+ val maxTaskFailures = conf.get(config.MAX_TASK_FAILURES)
+ val maxNodeAttempts = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE)
+
+ if (maxNodeAttempts >= maxTaskFailures) {
+ throw new IllegalArgumentException(s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key} " +
+ s"( = ${maxNodeAttempts}) was >= ${config.MAX_TASK_FAILURES.key} " +
+ s"( = ${maxTaskFailures} ). Though blacklisting is enabled, with this configuration, " +
+ s"Spark will not be robust to one bad node. Decrease " +
+ s"${config.MAX_TASK_ATTEMPTS_PER_NODE.key}, increase ${config.MAX_TASK_FAILURES.key}, " +
+ s"or disable blacklisting with ${config.BLACKLIST_ENABLED.key}")
+ }
+ }
+}
+
+private final case class BlacklistedExecutor(node: String, expiryTime: Long)
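
The expiry bookkeeping described in the ExecutorFailureList scaladoc above is easy to misread, so here is a toy, self-contained rendering of the worked example it gives (timeout = 10, two failures at times 0 and 5). It is not part of the patch and uses invented local values.

// Toy model of the scaladoc example: each task failure keeps its own expiry,
// while the executor-level blacklist runs a full timeout from the moment it trips.
val timeout = 10L
val failureTimes = Seq(0L, 5L)                       // task 1 and task 2 fail on exec 1
val failureExpiries = failureTimes.map(_ + timeout)  // Seq(10, 15): per-failure expiries

val maxFailuresPerExec = 2
val blacklistedAt = failureTimes(maxFailuresPerExec - 1)  // tripped at time 5
val unblacklistedAt = blacklistedAt + timeout             // expires at time 15

// Matches "exec 1 is blacklisted from time 5 - 15" in the comment above.
assert(failureExpiries == Seq(10L, 15L) && unblacklistedAt == 15L)
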
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a1f0fd05f661..f2fcb146b85e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -22,6 +22,7 @@ import java.util.Properties
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
+import scala.annotation.tailrec
import scala.collection.Map
import scala.collection.mutable.{HashMap, HashSet, Stack}
import scala.concurrent.duration._
@@ -34,12 +35,14 @@ import org.apache.commons.lang3.SerializationUtils
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData}
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.storage._
-import org.apache.spark.util._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
+import org.apache.spark.util._
/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -130,7 +133,7 @@ class DAGScheduler(
def this(sc: SparkContext) = this(sc, sc.taskScheduler)
- private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this)
+ private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this)
private[scheduler] val nextJobId = new AtomicInteger(0)
private[scheduler] def numTotalJobs: Int = nextJobId.get()
@@ -138,7 +141,13 @@ class DAGScheduler(
private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
- private[scheduler] val shuffleToMapStage = new HashMap[Int, ShuffleMapStage]
+ /**
+ * Mapping from shuffle dependency ID to the ShuffleMapStage that will generate the data for
+ * that dependency. Only includes stages that are part of a currently running job (when the job(s)
+ * that require the shuffle stage complete, the mapping will be removed, and the only record of
+ * the shuffle data will be in the MapOutputTracker).
+ */
+ private[scheduler] val shuffleIdToMapStage = new HashMap[Int, ShuffleMapStage]
private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
// Stages we need to run whose parents aren't done
@@ -178,6 +187,13 @@ class DAGScheduler(
/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
+ /**
+ * Number of consecutive stage attempts allowed before a stage is aborted.
+ */
+ private[scheduler] val maxConsecutiveStageAttempts =
+ sc.getConf.getInt("spark.stage.maxConsecutiveAttempts",
+ DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS)
+
private val messageScheduler =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message")
@@ -206,11 +222,10 @@ class DAGScheduler(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Map[Long, Any],
- taskInfo: TaskInfo,
- taskMetrics: TaskMetrics): Unit = {
+ accumUpdates: Seq[AccumulatorV2[_, _]],
+ taskInfo: TaskInfo): Unit = {
eventProcessLoop.post(
- CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
+ CompletionEvent(task, reason, result, accumUpdates, taskInfo))
}
/**
@@ -220,18 +235,19 @@ class DAGScheduler(
*/
def executorHeartbeatReceived(
execId: String,
- taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics)
+ // (taskId, stageId, stageAttemptId, accumUpdates)
+ accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])],
blockManagerId: BlockManagerId): Boolean = {
- listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
- blockManagerMaster.driverEndpoint.askWithRetry[Boolean](
+ listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates))
+ blockManagerMaster.driverEndpoint.askSync[Boolean](
BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat"))
}
/**
* Called by TaskScheduler implementation when an executor fails.
*/
- def executorLost(execId: String): Unit = {
- eventProcessLoop.post(ExecutorLost(execId))
+ def executorLost(execId: String, reason: ExecutorLossReason): Unit = {
+ eventProcessLoop.post(ExecutorLost(execId, reason))
}
/**
@@ -273,159 +289,141 @@ class DAGScheduler(
}
/**
- * Get or create a shuffle map stage for the given shuffle dependency's map side.
+ * Gets a shuffle map stage if one already exists in shuffleIdToMapStage. Otherwise, this
+ * method will create the shuffle map stage, in addition to any missing ancestor shuffle map
+ * stages.
*/
- private def getShuffleMapStage(
+ private def getOrCreateShuffleMapStage(
shuffleDep: ShuffleDependency[_, _, _],
firstJobId: Int): ShuffleMapStage = {
- shuffleToMapStage.get(shuffleDep.shuffleId) match {
- case Some(stage) => stage
+ shuffleIdToMapStage.get(shuffleDep.shuffleId) match {
+ case Some(stage) =>
+ stage
+
case None =>
- // We are going to register ancestor shuffle dependencies
- getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
- shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId)
+ // Create stages for all missing ancestor shuffle dependencies.
+ getMissingAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
+ // Even though getMissingAncestorShuffleDependencies only returns shuffle dependencies
+ // that were not already in shuffleIdToMapStage, it's possible that by the time we
+ // get to a particular dependency in the foreach loop, it's been added to
+ // shuffleIdToMapStage by the stage creation process for an earlier dependency. See
+ // SPARK-13902 for more information.
+ if (!shuffleIdToMapStage.contains(dep.shuffleId)) {
+ createShuffleMapStage(dep, firstJobId)
+ }
}
- // Then register current shuffleDep
- val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)
- shuffleToMapStage(shuffleDep.shuffleId) = stage
- stage
+ // Finally, create a stage for the given shuffle dependency.
+ createShuffleMapStage(shuffleDep, firstJobId)
}
}
/**
- * Helper function to eliminate some code re-use when creating new stages.
+ * Creates a ShuffleMapStage that generates the given shuffle dependency's partitions. If a
+ * previously run stage generated the same shuffle data, this function will copy the output
+ * locations that are still available from the previous shuffle to avoid unnecessarily
+ * regenerating data.
*/
- private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = {
- val parentStages = getParentStages(rdd, firstJobId)
+ def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = {
+ val rdd = shuffleDep.rdd
+ val numTasks = rdd.partitions.length
+ val parents = getOrCreateParentStages(rdd, jobId)
val id = nextStageId.getAndIncrement()
- (parentStages, id)
- }
-
- /**
- * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in
- * newOrUsedShuffleStage. The stage will be associated with the provided firstJobId.
- * Production of shuffle map stages should always use newOrUsedShuffleStage, not
- * newShuffleMapStage directly.
- */
- private def newShuffleMapStage(
- rdd: RDD[_],
- numTasks: Int,
- shuffleDep: ShuffleDependency[_, _, _],
- firstJobId: Int,
- callSite: CallSite): ShuffleMapStage = {
- val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, firstJobId)
- val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages,
- firstJobId, callSite, shuffleDep)
+ val stage = new ShuffleMapStage(
+ id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker)
stageIdToStage(id) = stage
- updateJobIdStageIdMaps(firstJobId, stage)
+ shuffleIdToMapStage(shuffleDep.shuffleId) = stage
+ updateJobIdStageIdMaps(jobId, stage)
+
+ if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of partitions is unknown
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
+ mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
+ }
stage
}
/**
* Create a ResultStage associated with the provided jobId.
*/
- private def newResultStage(
+ private def createResultStage(
rdd: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
jobId: Int,
callSite: CallSite): ResultStage = {
- val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId)
- val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite)
+ val parents = getOrCreateParentStages(rdd, jobId)
+ val id = nextStageId.getAndIncrement()
+ val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
stageIdToStage(id) = stage
updateJobIdStageIdMaps(jobId, stage)
stage
}
- /**
- * Create a shuffle map Stage for the given RDD. The stage will also be associated with the
- * provided firstJobId. If a stage for the shuffleId existed previously so that the shuffleId is
- * present in the MapOutputTracker, then the number and location of available outputs are
- * recovered from the MapOutputTracker
- */
- private def newOrUsedShuffleStage(
- shuffleDep: ShuffleDependency[_, _, _],
- firstJobId: Int): ShuffleMapStage = {
- val rdd = shuffleDep.rdd
- val numTasks = rdd.partitions.length
- val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite)
- if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
- val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
- val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
- (0 until locs.length).foreach { i =>
- if (locs(i) ne null) {
- // locs(i) will be null if missing
- stage.addOutputLoc(i, locs(i))
- }
- }
- } else {
- // Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of partitions is unknown
- logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
- mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
- }
- stage
- }
-
/**
* Get or create the list of parent stages for a given RDD. The new Stages will be created with
* the provided firstJobId.
*/
- private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
- val parents = new HashSet[Stage]
- val visited = new HashSet[RDD[_]]
- // We are manually maintaining a stack here to prevent StackOverflowError
- // caused by recursively visiting
- val waitingForVisit = new Stack[RDD[_]]
- def visit(r: RDD[_]) {
- if (!visited(r)) {
- visited += r
- // Kind of ugly: need to register RDDs with the cache here since
- // we can't do it in its constructor because # of partitions is unknown
- for (dep <- r.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_, _, _] =>
- parents += getShuffleMapStage(shufDep, firstJobId)
- case _ =>
- waitingForVisit.push(dep.rdd)
- }
- }
- }
- }
- waitingForVisit.push(rdd)
- while (waitingForVisit.nonEmpty) {
- visit(waitingForVisit.pop())
- }
- parents.toList
+ private def getOrCreateParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
+ getShuffleDependencies(rdd).map { shuffleDep =>
+ getOrCreateShuffleMapStage(shuffleDep, firstJobId)
+ }.toList
}
/** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */
- private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
- val parents = new Stack[ShuffleDependency[_, _, _]]
+ private def getMissingAncestorShuffleDependencies(
+ rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
+ val ancestors = new Stack[ShuffleDependency[_, _, _]]
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
val waitingForVisit = new Stack[RDD[_]]
- def visit(r: RDD[_]) {
- if (!visited(r)) {
- visited += r
- for (dep <- r.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_, _, _] =>
- if (!shuffleToMapStage.contains(shufDep.shuffleId)) {
- parents.push(shufDep)
- }
- case _ =>
- }
- waitingForVisit.push(dep.rdd)
+ waitingForVisit.push(rdd)
+ while (waitingForVisit.nonEmpty) {
+ val toVisit = waitingForVisit.pop()
+ if (!visited(toVisit)) {
+ visited += toVisit
+ getShuffleDependencies(toVisit).foreach { shuffleDep =>
+ if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) {
+ ancestors.push(shuffleDep)
+ waitingForVisit.push(shuffleDep.rdd)
+ } // Otherwise, the dependency and its ancestors have already been registered.
}
}
}
+ ancestors
+ }
+ /**
+ * Returns shuffle dependencies that are immediate parents of the given RDD.
+ *
+ * This function will not return more distant ancestors. For example, if C has a shuffle
+ * dependency on B which has a shuffle dependency on A:
+ *
+ * A <-- B <-- C
+ *
+ * calling this function with rdd C will only return the B <-- C dependency.
+ *
+ * This function is scheduler-visible for the purpose of unit testing.
+ */
+ private[scheduler] def getShuffleDependencies(
+ rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
+ val parents = new HashSet[ShuffleDependency[_, _, _]]
+ val visited = new HashSet[RDD[_]]
+ val waitingForVisit = new Stack[RDD[_]]
waitingForVisit.push(rdd)
while (waitingForVisit.nonEmpty) {
- visit(waitingForVisit.pop())
+ val toVisit = waitingForVisit.pop()
+ if (!visited(toVisit)) {
+ visited += toVisit
+ toVisit.dependencies.foreach {
+ case shuffleDep: ShuffleDependency[_, _, _] =>
+ parents += shuffleDep
+ case dependency =>
+ waitingForVisit.push(dependency.rdd)
+ }
+ }
}
parents
}
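
The new getShuffleDependencies stops at the first shuffle boundary on every path, which is what makes getOrCreateParentStages correct. To make that "immediate parents only" contract concrete, here is a standalone toy traversal over an invented dependency model; Node, NarrowDep, and ShuffleDep are illustrative types, not Spark's.

import scala.collection.mutable.{HashSet, Stack}

// Invented miniature dependency graph, just to mirror the traversal shape above.
sealed trait Dep { def parent: Node }
case class NarrowDep(parent: Node) extends Dep
case class ShuffleDep(parent: Node) extends Dep
case class Node(name: String, deps: Seq[Dep])

// Same structure as getShuffleDependencies: iterative DFS with an explicit stack,
// recording shuffle edges but never walking past them.
def immediateShuffleDeps(root: Node): Set[ShuffleDep] = {
  val found = HashSet[ShuffleDep]()
  val visited = HashSet[Node]()
  val waitingForVisit = Stack[Node](root)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      toVisit.deps.foreach {
        case s: ShuffleDep => found += s           // stop here; do not cross the shuffle
        case d             => waitingForVisit.push(d.parent)
      }
    }
  }
  found.toSet
}

// A <-- B <-- C, both edges shuffles: asking about C returns only the B <-- C dependency.
val a = Node("A", Nil)
val b = Node("B", Seq(ShuffleDep(a)))
val c = Node("C", Seq(ShuffleDep(b)))
assert(immediateShuffleDeps(c).map(_.parent.name) == Set("B"))
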
@@ -444,7 +442,7 @@ class DAGScheduler(
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
- val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
+ val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
if (!mapStage.isAvailable) {
missing += mapStage
}
@@ -467,13 +465,13 @@ class DAGScheduler(
* all of that stage's ancestors.
*/
private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = {
+ @tailrec
def updateJobIdStageIdMapsList(stages: List[Stage]) {
if (stages.nonEmpty) {
val s = stages.head
s.jobIds += jobId
jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
- val parents: List[Stage] = getParentStages(s.rdd, jobId)
- val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) }
+ val parentsWithoutThisJobId = s.parents.filter { ! _.jobIds.contains(jobId) }
updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
}
}
@@ -506,8 +504,8 @@ class DAGScheduler(
logDebug("Removing running stage %d".format(stageId))
runningStages -= stage
}
- for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
- shuffleToMapStage.remove(k)
+ for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) {
+ shuffleIdToMapStage.remove(k)
}
if (waitingStages.contains(stage)) {
logDebug("Removing stage %d from waiting set.".format(stageId))
@@ -541,8 +539,7 @@ class DAGScheduler(
}
/**
- * Submit an action job to the scheduler and get a JobWaiter object back. The JobWaiter object
- * can be used to block until the the job finishes executing or can be used to cancel the job.
+ * Submit an action job to the scheduler.
*
* @param rdd target RDD to run tasks on
* @param func a function to run on each partition of the RDD
@@ -551,6 +548,11 @@ class DAGScheduler(
* @param callSite where in the user program this job was called
* @param resultHandler callback to pass each result to
* @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
+ *
+ * @return a JobWaiter object that can be used to block until the job finishes executing
+ * or can be used to cancel the job.
+ *
+ * @throws IllegalArgumentException when partition ids are illegal
*/
def submitJob[T, U](
rdd: RDD[T],
@@ -584,7 +586,7 @@ class DAGScheduler(
/**
* Run an action job on the given RDD and pass all the results to the resultHandler function as
- * they arrive. Throws an exception if the job fials, or returns normally if successful.
+ * they arrive.
*
* @param rdd target RDD to run tasks on
* @param func a function to run on each partition of the RDD
@@ -593,6 +595,8 @@ class DAGScheduler(
* @param callSite where in the user program this job was called
* @param resultHandler callback to pass each result to
* @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
+ *
+ * @note Throws `Exception` when the job fails
*/
def runJob[T, U](
rdd: RDD[T],
@@ -603,11 +607,12 @@ class DAGScheduler(
properties: Properties): Unit = {
val start = System.nanoTime
val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
- waiter.awaitResult() match {
- case JobSucceeded =>
+ ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf)
+ waiter.completionFuture.value.get match {
+ case scala.util.Success(_) =>
logInfo("Job %d finished: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
- case JobFailed(exception: Exception) =>
+ case scala.util.Failure(exception) =>
logInfo("Job %d failed: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
// SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
@@ -623,7 +628,7 @@ class DAGScheduler(
*
* @param rdd target RDD to run tasks on
* @param func a function to run on each partition of the RDD
- * @param evaluator [[ApproximateEvaluator]] to receive the partial results
+ * @param evaluator `ApproximateEvaluator` to receive the partial results
* @param callSite where in the user program this job was called
* @param timeout maximum time to wait for the job, in milliseconds
* @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
@@ -646,7 +651,7 @@ class DAGScheduler(
/**
* Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter
- * can be used to block until the the job finishes executing or can be used to cancel the job.
+ * can be used to block until the job finishes executing or can be used to cancel the job.
* This method is used for adaptive query planning, to run map stages and look at statistics
* about their outputs before submitting downstream stages.
*
@@ -682,9 +687,9 @@ class DAGScheduler(
/**
* Cancel a job that is running or waiting in the queue.
*/
- def cancelJob(jobId: Int): Unit = {
+ def cancelJob(jobId: Int, reason: Option[String]): Unit = {
logInfo("Asked to cancel job " + jobId)
- eventProcessLoop.post(JobCancelled(jobId))
+ eventProcessLoop.post(JobCancelled(jobId, reason))
}
/**
@@ -705,17 +710,25 @@ class DAGScheduler(
private[scheduler] def doCancelAllJobs() {
// Cancel all running jobs.
runningStages.map(_.firstJobId).foreach(handleJobCancellation(_,
- reason = "as part of cancellation of all jobs"))
+ Option("as part of cancellation of all jobs")))
activeJobs.clear() // These should already be empty by this point,
jobIdToActiveJob.clear() // but just in case we lost track of some jobs...
- submitWaitingStages()
}
/**
* Cancel all jobs associated with a running or scheduled stage.
*/
- def cancelStage(stageId: Int) {
- eventProcessLoop.post(StageCancelled(stageId))
+ def cancelStage(stageId: Int, reason: Option[String]) {
+ eventProcessLoop.post(StageCancelled(stageId, reason))
+ }
+
+ /**
+ * Kill a given task. It will be retried.
+ *
+ * @return Whether the task was successfully killed.
+ */
+ def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
+ taskScheduler.killTaskAttempt(taskId, interruptThread, reason)
}
/**
@@ -734,23 +747,21 @@ class DAGScheduler(
submitStage(stage)
}
}
- submitWaitingStages()
}
/**
- * Check for waiting or failed stages which are now eligible for resubmission.
- * Ordinarily run on every iteration of the event loop.
+ * Check for waiting stages which are now eligible for resubmission.
+ * Submits stages that depend on the given parent stage. Called when the parent stage completes
+ * successfully.
*/
- private def submitWaitingStages() {
- // TODO: We might want to run this less often, when we are sure that something has become
- // runnable that wasn't before.
- logTrace("Checking for newly runnable parent stages")
+ private def submitWaitingChildStages(parent: Stage) {
+ logTrace(s"Checking if any dependencies of $parent are now runnable")
logTrace("running: " + runningStages)
logTrace("waiting: " + waitingStages)
logTrace("failed: " + failedStages)
- val waitingStagesCopy = waitingStages.toArray
- waitingStages.clear()
- for (stage <- waitingStagesCopy.sortBy(_.firstJobId)) {
+ val childStages = waitingStages.filter(_.parents.contains(parent)).toArray
+ waitingStages --= childStages
+ for (stage <- childStages.sortBy(_.firstJobId)) {
submitStage(stage)
}
}
@@ -774,8 +785,8 @@ class DAGScheduler(
}
}
val jobIds = activeInGroup.map(_.jobId)
- jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId)))
- submitWaitingStages()
+ jobIds.foreach(handleJobCancellation(_,
+ Option("part of cancelled job group %s".format(groupId))))
}
private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) {
@@ -783,7 +794,6 @@ class DAGScheduler(
// In that case, we wouldn't have the stage anymore in stageIdToStage.
val stageAttemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1)
listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo))
- submitWaitingStages()
}
private[scheduler] def handleTaskSetFailed(
@@ -791,12 +801,12 @@ class DAGScheduler(
reason: String,
exception: Option[Throwable]): Unit = {
stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) }
- submitWaitingStages()
}
private[scheduler] def cleanUpAfterSchedulerStop() {
for (job <- activeJobs) {
- val error = new SparkException("Job cancelled because SparkContext was shut down")
+ val error =
+ new SparkException(s"Job ${job.jobId} cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
// Tell the listeners that all of the running stages have ended. Don't bother
// cancelling the stages because if the DAG scheduler is stopped, the entire application
@@ -813,7 +823,6 @@ class DAGScheduler(
private[scheduler] def handleGetTaskResult(taskInfo: TaskInfo) {
listenerBus.post(SparkListenerTaskGettingResult(taskInfo))
- submitWaitingStages()
}
private[scheduler] def handleJobSubmitted(jobId: Int,
@@ -827,7 +836,7 @@ class DAGScheduler(
try {
// New stage creation may throw an exception if, for example, jobs are run on a
// HadoopRDD whose underlying HDFS files have been deleted.
- finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite)
+ finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
} catch {
case e: Exception =>
logWarning("Creating new stage failed due to exception - job: " + jobId, e)
@@ -852,8 +861,6 @@ class DAGScheduler(
listenerBus.post(
SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
submitStage(finalStage)
-
- submitWaitingStages()
}
private[scheduler] def handleMapStageSubmitted(jobId: Int,
@@ -867,7 +874,7 @@ class DAGScheduler(
try {
// New stage creation may throw an exception if, for example, jobs are run on a
// HadoopRDD whose underlying HDFS files have been deleted.
- finalStage = getShuffleMapStage(dependency, jobId)
+ finalStage = getOrCreateShuffleMapStage(dependency, jobId)
} catch {
case e: Exception =>
logWarning("Creating new stage failed due to exception - job: " + jobId, e)
@@ -897,8 +904,6 @@ class DAGScheduler(
if (finalStage.isAvailable) {
markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency))
}
-
- submitWaitingStages()
}
/** Submits stage, but first recursively submits any missing parents. */
@@ -927,20 +932,13 @@ class DAGScheduler(
/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
- // Get our pending tasks and remember them in our pendingTasks entry
- stage.pendingPartitions.clear()
// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
- // Create internal accumulators if the stage has no accumulators initialized.
- // Reset internal accumulators only if this stage is not partially submitted
- // Otherwise, we may override existing accumulator values from some tasks
- if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) {
- stage.resetInternalAccumulators()
- }
-
- val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull
+ // Use the scheduling pool, job group, description, etc. from an ActiveJob associated
+ // with this Stage
+ val properties = jobIdToActiveJob(jobId).properties
runningStages += stage
// SparkListenerStageSubmitted should be posted before testing whether tasks are
@@ -959,7 +957,6 @@ class DAGScheduler(
case s: ShuffleMapStage =>
partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
case s: ResultStage =>
- val job = s.activeJob.get
partitionsToCompute.map { id =>
val p = s.partitions(id)
(id, getPreferredLocs(stage.rdd, p))
@@ -969,7 +966,7 @@ class DAGScheduler(
case NonFatal(e) =>
stage.makeNewStageAttempt(partitionsToCompute.size)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
- abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
+ abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
return
}
@@ -984,14 +981,24 @@ class DAGScheduler(
// might modify state of objects referenced in their closures. This is necessary in Hadoop
// where the JobConf/Configuration object is not thread-safe.
var taskBinary: Broadcast[Array[Byte]] = null
+ var partitions: Array[Partition] = null
try {
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
- val taskBinaryBytes: Array[Byte] = stage match {
- case stage: ShuffleMapStage =>
- closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array()
- case stage: ResultStage =>
- closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array()
+ var taskBinaryBytes: Array[Byte] = null
+ // taskBinaryBytes and partitions are both affected by the checkpoint status. We need
+ // this synchronization in case another concurrent job is checkpointing this RDD, so we get a
+ // consistent view of both variables.
+ RDDCheckpointData.synchronized {
+ taskBinaryBytes = stage match {
+ case stage: ShuffleMapStage =>
+ JavaUtils.bufferToArray(
+ closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
+ case stage: ResultStage =>
+ JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
+ }
+
+ partitions = stage.rdd.partitions
}
taskBinary = sc.broadcast(taskBinaryBytes)
@@ -1004,44 +1011,47 @@ class DAGScheduler(
// Abort execution
return
case NonFatal(e) =>
- abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e))
+ abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
return
}
val tasks: Seq[Task[_]] = try {
+ val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
stage match {
case stage: ShuffleMapStage =>
+ stage.pendingPartitions.clear()
partitionsToCompute.map { id =>
val locs = taskIdToLocations(id)
- val part = stage.rdd.partitions(id)
+ val part = partitions(id)
+ stage.pendingPartitions += id
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
- taskBinary, part, locs, stage.internalAccumulators)
+ taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
+ Option(sc.applicationId), sc.applicationAttemptId)
}
case stage: ResultStage =>
- val job = stage.activeJob.get
partitionsToCompute.map { id =>
val p: Int = stage.partitions(id)
- val part = stage.rdd.partitions(p)
+ val part = partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptId,
- taskBinary, part, locs, id, stage.internalAccumulators)
+ taskBinary, part, locs, id, properties, serializedTaskMetrics,
+ Option(jobId), Option(sc.applicationId), sc.applicationAttemptId)
}
}
} catch {
case NonFatal(e) =>
- abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
+ abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
runningStages -= stage
return
}
if (tasks.size > 0) {
- logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
- stage.pendingPartitions ++= tasks.map(_.partitionId)
- logDebug("New pending partitions: " + stage.pendingPartitions)
+ logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
+ s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
taskScheduler.submitTasks(new TaskSet(
- tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties))
+ tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should mark
@@ -1058,79 +1068,123 @@ class DAGScheduler(
s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
}
logDebug(debugString)
+
+ submitWaitingChildStages(stage)
}
}
- /** Merge updates from a task to our local accumulator values */
+ /**
+ * Merge local values from a task into the corresponding accumulators previously registered
+ * here on the driver.
+ *
+ * Although accumulators themselves are not thread-safe, this method is called only from one
+ * thread, the one that runs the scheduling loop. This means we only handle one task
+ * completion event at a time so we don't need to worry about locking the accumulators.
+ * This still doesn't stop the caller from updating the accumulator outside the scheduler,
+ * but that's not our problem since there's nothing we can do about that.
+ */
private def updateAccumulators(event: CompletionEvent): Unit = {
val task = event.task
val stage = stageIdToStage(task.stageId)
- if (event.accumUpdates != null) {
- try {
- Accumulators.add(event.accumUpdates)
- event.accumUpdates.foreach { case (id, partialValue) =>
- // In this instance, although the reference in Accumulators.originals is a WeakRef,
- // it's guaranteed to exist since the event.accumUpdates Map exists
-
- val acc = Accumulators.originals(id).get match {
- case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
- case None => throw new NullPointerException("Non-existent reference to Accumulator")
- }
-
- // To avoid UI cruft, ignore cases where value wasn't updated
- if (acc.name.isDefined && partialValue != acc.zero) {
- val name = acc.name.get
- val value = s"${acc.value}"
- stage.latestInfo.accumulables(id) =
- new AccumulableInfo(id, name, None, value, acc.isInternal)
- event.taskInfo.accumulables +=
- new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal)
- }
+ event.accumUpdates.foreach { updates =>
+ val id = updates.id
+ try {
+ // Find the corresponding accumulator on the driver and update it
+ val acc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id) match {
+ case Some(accum) => accum.asInstanceOf[AccumulatorV2[Any, Any]]
+ case None =>
+ throw new SparkException(s"attempted to access non-existent accumulator $id")
+ }
+ acc.merge(updates.asInstanceOf[AccumulatorV2[Any, Any]])
+ // To avoid UI cruft, ignore cases where value wasn't updated
+ if (acc.name.isDefined && !updates.isZero) {
+ stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value))
+ event.taskInfo.setAccumulables(
+ acc.toInfo(Some(updates.value), Some(acc.value)) +: event.taskInfo.accumulables)
}
} catch {
- // If we see an exception during accumulator update, just log the
- // error and move on.
- case e: Exception =>
- logError(s"Failed to update accumulators for $task", e)
+ case NonFatal(e) =>
+ // Log the class name to make it easy to find the bad implementation
+ val accumClassName = AccumulatorContext.get(id) match {
+ case Some(accum) => accum.getClass.getName
+ case None => "Unknown class"
+ }
+ logError(
+ s"Failed to update accumulator $id ($accumClassName) for task ${task.partitionId}",
+ e)
}
}
}
+ private def postTaskEnd(event: CompletionEvent): Unit = {
+ val taskMetrics: TaskMetrics =
+ if (event.accumUpdates.nonEmpty) {
+ try {
+ TaskMetrics.fromAccumulators(event.accumUpdates)
+ } catch {
+ case NonFatal(e) =>
+ val taskId = event.taskInfo.taskId
+ logError(s"Error when attempting to reconstruct metrics for task $taskId", e)
+ null
+ }
+ } else {
+ null
+ }
+
+ listenerBus.post(SparkListenerTaskEnd(event.task.stageId, event.task.stageAttemptId,
+ Utils.getFormattedClassName(event.task), event.reason, event.taskInfo, taskMetrics))
+ }
+
/**
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
+ val taskId = event.taskInfo.id
val stageId = task.stageId
val taskType = Utils.getFormattedClassName(task)
outputCommitCoordinator.taskCompleted(
stageId,
+ task.stageAttemptId,
task.partitionId,
event.taskInfo.attemptNumber, // this is a task attempt number
event.reason)
- // The success case is dealt with separately below, since we need to compute accumulator
- // updates before posting.
- if (event.reason != Success) {
- val attemptId = task.stageAttemptId
- listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason,
- event.taskInfo, event.taskMetrics))
- }
-
if (!stageIdToStage.contains(task.stageId)) {
+ // The stage may have already finished when we get this event -- e.g. maybe it was a
+ // speculative task. It is important that we send the TaskEnd event in any case, so listeners
+ // are properly notified and can choose to handle it. For instance, some listeners are
+ // doing their own accounting and if they don't get the task end event they think
+ // tasks are still running when they really aren't.
+ postTaskEnd(event)
+
// Skip all the actions if the stage has been cancelled.
return
}
val stage = stageIdToStage(task.stageId)
+
+ // Make sure the task's accumulators are updated before any other processing happens, so that
+ // we can post a task end event before any jobs or stages are updated. The accumulators are
+ // only updated in certain cases.
+ event.reason match {
+ case Success =>
+ stage match {
+ case rs: ResultStage if rs.activeJob.isEmpty =>
+ // Ignore update if task's job has finished.
+ case _ =>
+ updateAccumulators(event)
+ }
+ case _: ExceptionFailure => updateAccumulators(event)
+ case _ =>
+ }
+ postTaskEnd(event)
+
event.reason match {
case Success =>
- listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
- event.reason, event.taskInfo, event.taskMetrics))
- stage.pendingPartitions -= task.partitionId
task match {
case rt: ResultTask[_, _] =>
// Cast to ResultStage here because it's part of the ResultTask
@@ -1139,7 +1193,6 @@ class DAGScheduler(
resultStage.activeJob match {
case Some(job) =>
if (!job.finished(rt.outputId)) {
- updateAccumulators(event)
job.finished(rt.outputId) = true
job.numFinished += 1
// If the whole job has finished, remove it
@@ -1166,14 +1219,33 @@ class DAGScheduler(
case smt: ShuffleMapTask =>
val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
- updateAccumulators(event)
val status = event.result.asInstanceOf[MapStatus]
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
+ if (stageIdToStage(task.stageId).latestInfo.attemptId == task.stageAttemptId) {
+ // This task was for the currently running attempt of the stage. Since the task
+ // completed successfully from the perspective of the TaskSetManager, mark it as
+ // no longer pending (the TaskSetManager may consider the task complete even
+ // when its output is ignored below because the task's epoch is too small).
+ // In that case, even once pending partitions is empty there will still be missing
+ // output locations, which will cause the DAGScheduler to resubmit the stage below.
+ shuffleStage.pendingPartitions -= task.partitionId
+ }
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
} else {
- shuffleStage.addOutputLoc(smt.partitionId, status)
+ // The epoch of the task is acceptable (i.e., the task was launched after the most
+ // recent failure we're aware of for the executor), so mark the task's output as
+ // available.
+ mapOutputTracker.registerMapOutput(
+ shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
+ // Remove the task's partition from pending partitions. This may have already been
+ // done above, but will not have been done yet in cases where the task attempt was
+ // from an earlier attempt of the stage (i.e., not the attempt that's currently
+ // running). This allows the DAGScheduler to mark the stage as complete when one
+ // copy of each task has finished successfully, even if the currently active stage
+ // still has tasks running.
+ shuffleStage.pendingPartitions -= task.partitionId
}
if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {
@@ -1183,21 +1255,19 @@ class DAGScheduler(
logInfo("waiting: " + waitingStages)
logInfo("failed: " + failedStages)
- // We supply true to increment the epoch number here in case this is a
- // recomputation of the map outputs. In that case, some nodes may have cached
- // locations with holes (from when we detected the error) and will need the
- // epoch incremented to refetch them.
- // TODO: Only increment the epoch number if this is not the first time
- // we registered these map outputs.
- mapOutputTracker.registerMapOutputs(
- shuffleStage.shuffleDep.shuffleId,
- shuffleStage.outputLocInMapOutputTrackerFormat(),
- changeEpoch = true)
+ // This call to increment the epoch may not be strictly necessary, but it is retained
+ // for now in order to minimize the changes in behavior from an earlier version of the
+ // code. This existing behavior of always incrementing the epoch following any
+ // successful shuffle map stage completion may have benefits by causing unneeded
+ // cached map outputs to be cleaned up earlier on executors. In the future we can
+ // consider removing this call, but this will require some extra investigation.
+ // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
+ mapOutputTracker.incrementEpoch()
clearCacheLocs()
if (!shuffleStage.isAvailable) {
- // Some tasks had failed; let's resubmit this shuffleStage
+ // Some tasks had failed; let's resubmit this shuffleStage.
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
") because some of its tasks had failed: " +
@@ -1211,66 +1281,151 @@ class DAGScheduler(
markMapStageJobAsFinished(job, stats)
}
}
+ submitWaitingChildStages(shuffleStage)
}
-
- // Note: newly runnable stages will be submitted below when we submit waiting stages
}
}
case Resubmitted =>
logInfo("Resubmitted " + task + ", so marking it as still running")
- stage.pendingPartitions += task.partitionId
+ stage match {
+ case sms: ShuffleMapStage =>
+ sms.pendingPartitions += task.partitionId
+
+ case _ =>
+ assert(false, "TaskSetManagers should only send Resubmitted task statuses for " +
+ "tasks in ShuffleMapStages.")
+ }
case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
- val mapStage = shuffleToMapStage(shuffleId)
+ val mapStage = shuffleIdToMapStage(shuffleId)
if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
} else {
+ failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
+ val shouldAbortStage =
+ failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
+ disallowStageRetryForTest
+
// It is likely that we receive multiple FetchFailed for a single stage (because we have
// multiple tasks running concurrently on different executors). In that case, it is
// possible the fetch failure has already been handled by the scheduler.
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
- markStageAsFinished(failedStage, Some(failureMessage))
+ markStageAsFinished(failedStage, errorMessage = Some(failureMessage),
+ willRetry = !shouldAbortStage)
} else {
logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
s"longer running")
}
- if (disallowStageRetryForTest) {
- abortStage(failedStage, "Fetch failure will not retry stage due to testing config",
- None)
- } else if (failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId)) {
- abortStage(failedStage, s"$failedStage (${failedStage.name}) " +
- s"has failed the maximum allowable number of " +
- s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " +
- s"Most recent failure reason: ${failureMessage}", None)
- } else if (failedStages.isEmpty) {
- // Don't schedule an event to resubmit failed stages if failed isn't empty, because
- // in that case the event will already have been scheduled.
- // TODO: Cancel running tasks in the stage
- logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
- s"$failedStage (${failedStage.name}) due to fetch failure")
- messageScheduler.schedule(new Runnable {
- override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
- }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+ if (shouldAbortStage) {
+ val abortMessage = if (disallowStageRetryForTest) {
+ "Fetch failure will not retry stage due to testing config"
+ } else {
+ s"""$failedStage (${failedStage.name})
+ |has failed the maximum allowable number of
+ |times: $maxConsecutiveStageAttempts.
+ |Most recent failure reason: $failureMessage""".stripMargin.replaceAll("\n", " ")
+ }
+ abortStage(failedStage, abortMessage, None)
+ } else { // update failedStages and make sure a ResubmitFailedStages event is enqueued
+ // TODO: Cancel running tasks in the failed stage -- cf. SPARK-17064
+ val noResubmitEnqueued = !failedStages.contains(failedStage)
+ failedStages += failedStage
+ failedStages += mapStage
+ if (noResubmitEnqueued) {
+ // If the map stage is INDETERMINATE, which means the map tasks may return
+ // different results when re-tried, we need to re-run all the tasks of the failed
+ // stage and its succeeding stages, because the input data will change after the
+ // map tasks are re-tried.
+ // Note that, if the map stage is UNORDERED, we are fine. The shuffle partitioner is
+ // guaranteed to be determinate, so the input data of the reducers will not change
+ // even if the map tasks are re-tried.
+ if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) {
+ // It's a little tricky to find all the succeeding stages of `failedStage`, because
+ // each stage only knows its parents, not its children. Here we traverse the stages from
+ // the leaf nodes (the result stages of active jobs), and roll back all the stages
+ // in the stage chains that connect to the `failedStage`. To speed up the stage
+ // traversal, we collect the stages to roll back first. If a stage needs to
+ // roll back, all its succeeding stages need to roll back too.
+ val stagesToRollback = scala.collection.mutable.HashSet(failedStage)
+
+ def collectStagesToRollback(stageChain: List[Stage]): Unit = {
+ if (stagesToRollback.contains(stageChain.head)) {
+ stageChain.drop(1).foreach(s => stagesToRollback += s)
+ } else {
+ stageChain.head.parents.foreach { s =>
+ collectStagesToRollback(s :: stageChain)
+ }
+ }
+ }
+
+ def generateErrorMessage(stage: Stage): String = {
+ "A shuffle map stage with indeterminate output was failed and retried. " +
+ s"However, Spark cannot rollback the $stage to re-process the input data, " +
+ "and has to fail this job. Please eliminate the indeterminacy by " +
+ "checkpointing the RDD before repartition and try again."
+ }
+
+ activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil))
+
+ stagesToRollback.foreach {
+ case mapStage: ShuffleMapStage =>
+ val numMissingPartitions = mapStage.findMissingPartitions().length
+ if (numMissingPartitions < mapStage.numTasks) {
+ // TODO: support rolling back shuffle files.
+ // Currently the shuffle writing is "first write wins", so we can't re-run a
+ // shuffle map stage and overwrite existing shuffle files. We have to finish
+ // SPARK-8029 first.
+ abortStage(mapStage, generateErrorMessage(mapStage), None)
+ }
+
+ case resultStage: ResultStage if resultStage.activeJob.isDefined =>
+ val numMissingPartitions = resultStage.findMissingPartitions().length
+ if (numMissingPartitions < resultStage.numTasks) {
+ // TODO: support rolling back result tasks.
+ abortStage(resultStage, generateErrorMessage(resultStage), None)
+ }
+
+ case _ =>
+ }
+ }
+
+ // We expect one executor failure to trigger many FetchFailures in rapid succession,
+ // but all of those task failures can typically be handled by a single resubmission of
+ // the failed stage. We avoid flooding the scheduler's event queue with resubmit
+ // messages by checking whether a resubmit is already in the event queue for the
+ // failed stage. If there is already a resubmit enqueued for a different failed
+ // stage, that event would also be sufficient to handle the current failed stage, but
+ // producing a resubmit for each failed stage makes debugging and logging a little
+ // simpler while not producing an overwhelming number of scheduler events.
+ logInfo(
+ s"Resubmitting $mapStage (${mapStage.name}) and " +
+ s"$failedStage (${failedStage.name}) due to fetch failure"
+ )
+ messageScheduler.schedule(
+ new Runnable {
+ override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+ },
+ DAGScheduler.RESUBMIT_TIMEOUT,
+ TimeUnit.MILLISECONDS
+ )
+ }
}
- failedStages += failedStage
- failedStages += mapStage
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
- mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
}
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+ handleExecutorLost(bmAddress.executorId, filesLost = true, Some(task.epoch))
}
}
@@ -1278,16 +1433,15 @@ class DAGScheduler(
// Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
case exceptionFailure: ExceptionFailure =>
- // Do nothing here, left up to the TaskScheduler to decide how to handle user failures
+ // Nothing left to do, already handled above for accumulator updates.
case TaskResultLost =>
// Do nothing here; the TaskScheduler handles these failures and resubmits the task.
- case other =>
+ case _: ExecutorLostFailure | _: TaskKilled | UnknownReason =>
// Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
// will abort the job.
}
- submitWaitingStages()
}
/**
@@ -1295,15 +1449,16 @@ class DAGScheduler(
* modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
*
* We will also assume that we've lost all shuffle blocks associated with the executor if the
- * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed
- * occurred, in which case we presume all shuffle data related to this executor to be lost.
+ * executor serves its own blocks (i.e., we're not using external shuffle), the entire slave
+ * is lost (likely including the shuffle service), or a FetchFailed occurred, in which case we
+ * presume all shuffle data related to this executor to be lost.
*
* Optionally the epoch during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
private[scheduler] def handleExecutorLost(
execId: String,
- fetchFailed: Boolean,
+ filesLost: Boolean,
maybeEpoch: Option[Long] = None) {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
@@ -1311,25 +1466,15 @@ class DAGScheduler(
logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
blockManagerMaster.removeExecutor(execId)
- if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) {
- // TODO: This will be really slow if we keep accumulating shuffle map stages
- for ((shuffleId, stage) <- shuffleToMapStage) {
- stage.removeOutputsOnExecutor(execId)
- mapOutputTracker.registerMapOutputs(
- shuffleId,
- stage.outputLocInMapOutputTrackerFormat(),
- changeEpoch = true)
- }
- if (shuffleToMapStage.isEmpty) {
- mapOutputTracker.incrementEpoch()
- }
+ if (filesLost || !env.blockManager.externalShuffleServiceEnabled) {
+ logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
+ mapOutputTracker.removeOutputsOnExecutor(execId)
clearCacheLocs()
}
} else {
logDebug("Additional executor lost message for " + execId +
"(epoch " + currentEpoch + ")")
}
- submitWaitingStages()
}
private[scheduler] def handleExecutorAdded(execId: String, host: String) {
@@ -1338,36 +1483,42 @@ class DAGScheduler(
logInfo("Host added was in lost list earlier: " + host)
failedEpoch -= execId
}
- submitWaitingStages()
}
- private[scheduler] def handleStageCancellation(stageId: Int) {
+ private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]) {
stageIdToStage.get(stageId) match {
case Some(stage) =>
val jobsThatUseStage: Array[Int] = stage.jobIds.toArray
jobsThatUseStage.foreach { jobId =>
- handleJobCancellation(jobId, s"because Stage $stageId was cancelled")
+ val reasonStr = reason match {
+ case Some(originalReason) =>
+ s"because $originalReason"
+ case None =>
+ s"because Stage $stageId was cancelled"
+ }
+ handleJobCancellation(jobId, Option(reasonStr))
}
case None =>
logInfo("No active jobs to kill for Stage " + stageId)
}
- submitWaitingStages()
}
- private[scheduler] def handleJobCancellation(jobId: Int, reason: String = "") {
+ private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]) {
if (!jobIdToStageIds.contains(jobId)) {
logDebug("Trying to cancel unregistered job " + jobId)
} else {
failJobAndIndependentStages(
- jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason))
+ jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason.getOrElse("")))
}
- submitWaitingStages()
}
/**
* Marks a stage as finished and removes it from the list of running stages.
*/
- private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
+ private def markStageAsFinished(
+ stage: Stage,
+ errorMessage: Option[String] = None,
+ willRetry: Boolean = false): Unit = {
val serviceTime = stage.latestInfo.submissionTime match {
case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
case _ => "Unknown"
@@ -1383,10 +1534,12 @@ class DAGScheduler(
stage.clearFailures()
} else {
stage.latestInfo.stageFailed(errorMessage.get)
- logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
+ logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}")
}
- outputCommitCoordinator.stageEnd(stage.id)
+ if (!willRetry) {
+ outputCommitCoordinator.stageEnd(stage.id)
+ }
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
runningStages -= stage
}
@@ -1458,8 +1611,10 @@ class DAGScheduler(
}
if (ableToCancelStages) {
- job.listener.jobFailed(error)
+ // SPARK-15783 important to cleanup state first, just for tests where we have some asserts
+ // against the state. Otherwise we have a *little* bit of flakiness in the tests.
cleanupStateForJobAndIndependentStages(job)
+ job.listener.jobFailed(error)
listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error)))
}
}
@@ -1479,7 +1634,7 @@ class DAGScheduler(
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
- val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
+ val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
if (!mapStage.isAvailable) {
waitingForVisit.push(mapStage.rdd)
} // Otherwise there's no need to follow the dependency back
@@ -1568,14 +1723,11 @@ class DAGScheduler(
}
def stop() {
- logInfo("Stopping DAGScheduler")
messageScheduler.shutdownNow()
eventProcessLoop.stop()
taskScheduler.stop()
}
- // Start the event thread and register the metrics source at the end of the constructor
- env.metricsSystem.registerSource(metricsSource)
eventProcessLoop.start()
}
@@ -1603,11 +1755,11 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case MapStageSubmitted(jobId, dependency, callSite, listener, properties) =>
dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties)
- case StageCancelled(stageId) =>
- dagScheduler.handleStageCancellation(stageId)
+ case StageCancelled(stageId, reason) =>
+ dagScheduler.handleStageCancellation(stageId, reason)
- case JobCancelled(jobId) =>
- dagScheduler.handleJobCancellation(jobId)
+ case JobCancelled(jobId, reason) =>
+ dagScheduler.handleJobCancellation(jobId, reason)
case JobGroupCancelled(groupId) =>
dagScheduler.handleJobGroupCancelled(groupId)
@@ -1618,8 +1770,12 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case ExecutorAdded(execId, host) =>
dagScheduler.handleExecutorAdded(execId, host)
- case ExecutorLost(execId) =>
- dagScheduler.handleExecutorLost(execId, fetchFailed = false)
+ case ExecutorLost(execId, reason) =>
+ val filesLost = reason match {
+ case SlaveLost(_, true) => true
+ case _ => false
+ }
+ dagScheduler.handleExecutorLost(execId, filesLost)
case BeginEvent(task, taskInfo) =>
dagScheduler.handleBeginEvent(task, taskInfo)
@@ -1627,7 +1783,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case GettingResultEvent(taskInfo) =>
dagScheduler.handleGetTaskResult(taskInfo)
- case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+ case completion: CompletionEvent =>
dagScheduler.handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason, exception) =>
@@ -1644,7 +1800,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
} catch {
case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
}
- dagScheduler.sc.stop()
+ dagScheduler.sc.stopInNewThread()
}
override def onStop(): Unit = {
@@ -1658,4 +1814,7 @@ private[spark] object DAGScheduler {
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
val RESUBMIT_TIMEOUT = 200
+
+ // Number of consecutive stage attempts allowed before a stage is aborted
+ val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4
}
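
The rewritten updateAccumulators() above drops the old Accumulators.add(partialValues) path: each task now reports AccumulatorV2 instances holding only its local partial values, and the driver looks up the registered original via AccumulatorContext.get(id) and folds the update in with merge(), skipping zero-valued updates. A minimal standalone sketch of that merge pattern, using the public LongAccumulator outside the scheduler (object and variable names here are illustrative, not scheduler code):

```scala
import org.apache.spark.util.LongAccumulator

// Standalone illustration only: each task ships back an accumulator holding
// just its local partial value, and the driver folds it into the registered
// "original" with merge(), skipping zero updates (the "avoid UI cruft" check).
object AccumulatorMergeSketch {
  def main(args: Array[String]): Unit = {
    val driverSide = new LongAccumulator   // stands in for AccumulatorContext.get(id)
    val fromTask1 = new LongAccumulator    // per-task partial updates
    val fromTask2 = new LongAccumulator
    val emptyTask = new LongAccumulator    // a task that never touched the accumulator
    fromTask1.add(3L)
    fromTask2.add(4L)

    Seq(fromTask1, fromTask2, emptyTask)
      .filterNot(_.isZero)                 // ignore updates with no value
      .foreach(driverSide.merge(_))
    assert(driverSide.sum == 7L)
  }
}
```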
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index dda3b6cc7f96..cda0585f154a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -19,13 +19,11 @@ package org.apache.spark.scheduler
import java.util.Properties
-import scala.collection.Map
import scala.language.existentials
import org.apache.spark._
-import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.CallSite
+import org.apache.spark.util.{AccumulatorV2, CallSite}
/**
* Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
@@ -55,9 +53,15 @@ private[scheduler] case class MapStageSubmitted(
properties: Properties = null)
extends DAGSchedulerEvent
-private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent
+private[scheduler] case class StageCancelled(
+ stageId: Int,
+ reason: Option[String])
+ extends DAGSchedulerEvent
-private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
+private[scheduler] case class JobCancelled(
+ jobId: Int,
+ reason: Option[String])
+ extends DAGSchedulerEvent
private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
@@ -73,14 +77,14 @@ private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Map[Long, Any],
- taskInfo: TaskInfo,
- taskMetrics: TaskMetrics)
+ accumUpdates: Seq[AccumulatorV2[_, _]],
+ taskInfo: TaskInfo)
extends DAGSchedulerEvent
private[scheduler] case class ExecutorAdded(execId: String, host: String) extends DAGSchedulerEvent
-private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
+private[scheduler] case class ExecutorLost(execId: String, reason: ExecutorLossReason)
+ extends DAGSchedulerEvent
private[scheduler]
case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable])
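
The event case classes above change shape: cancellations now carry an optional human-readable reason, executor loss carries its ExecutorLossReason, and CompletionEvent carries raw AccumulatorV2 updates instead of a Map plus TaskMetrics. A small construction sketch to make the new signatures concrete; these types are private[scheduler], so this would only compile inside that package, and the object name is made up:

```scala
package org.apache.spark.scheduler

// Illustrative only: shows how the reshaped events are built after this change.
private[scheduler] object EventShapeSketch {
  def examples(lossReason: ExecutorLossReason): Seq[DAGSchedulerEvent] = Seq(
    StageCancelled(stageId = 3, reason = Some("cancelled from the web UI")),
    JobCancelled(jobId = 7, reason = None),
    ExecutorLost("exec-1", lossReason))
}
```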
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 000a021a528c..a7dbf87915b2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -19,19 +19,21 @@ package org.apache.spark.scheduler
import java.io._
import java.net.URI
+import java.nio.charset.StandardCharsets
+import java.util.Locale
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import com.google.common.base.Charsets
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path}
import org.apache.hadoop.fs.permission.FsPermission
import org.json4s.JsonAST.JValue
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.{Logging, SparkConf, SPARK_VERSION}
+import org.apache.spark.{SPARK_VERSION, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.{JsonProtocol, Utils}
@@ -77,14 +79,6 @@ private[spark] class EventLoggingListener(
// Only defined if the file system scheme is not local
private var hadoopDataStream: Option[FSDataOutputStream] = None
- // The Hadoop APIs have changed over time, so we use reflection to figure out
- // the correct method to use to flush a hadoop data stream. See SPARK-1518
- // for details.
- private val hadoopFlushMethod = {
- val cls = classOf[FSDataOutputStream]
- scala.util.Try(cls.getMethod("hflush")).getOrElse(cls.getMethod("sync"))
- }
-
private var writer: Option[PrintWriter] = None
// For testing. Keep track of all JSON serialized events that have been logged.
@@ -97,8 +91,8 @@ private[spark] class EventLoggingListener(
* Creates the log file in the configured log directory.
*/
def start() {
- if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDir) {
- throw new IllegalArgumentException(s"Log directory $logBaseDir does not exist.")
+ if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDirectory) {
+ throw new IllegalArgumentException(s"Log directory $logBaseDir is not a directory.")
}
val workingPath = logPath + IN_PROGRESS
@@ -107,11 +101,8 @@ private[spark] class EventLoggingListener(
val defaultFs = FileSystem.getDefaultUri(hadoopConf).getScheme
val isDefaultLocal = defaultFs == null || defaultFs == "file"
- if (shouldOverwrite && fileSystem.exists(path)) {
+ if (shouldOverwrite && fileSystem.delete(path, true)) {
logWarning(s"Event log $path already exists. Overwriting...")
- if (!fileSystem.delete(path, true)) {
- logWarning(s"Error deleting $path")
- }
}
/* The Hadoop LocalFileSystem (r1.0.4) has known issues with syncing (HADOOP-7844).
@@ -147,7 +138,7 @@ private[spark] class EventLoggingListener(
// scalastyle:on println
if (flushLogger) {
writer.foreach(_.flush())
- hadoopDataStream.foreach(hadoopFlushMethod.invoke(_))
+ hadoopDataStream.foreach(_.hflush())
}
if (testing) {
loggedEvents += eventJson
@@ -163,7 +154,9 @@ private[spark] class EventLoggingListener(
override def onTaskEnd(event: SparkListenerTaskEnd): Unit = logEvent(event)
- override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = logEvent(event)
+ override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = {
+ logEvent(redactEvent(event))
+ }
// Events that trigger a flush
override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
@@ -201,12 +194,34 @@ private[spark] class EventLoggingListener(
logEvent(event, flushLogger = true)
}
+ override def onExecutorBlacklisted(event: SparkListenerExecutorBlacklisted): Unit = {
+ logEvent(event, flushLogger = true)
+ }
+
+ override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = {
+ logEvent(event, flushLogger = true)
+ }
+
+ override def onNodeBlacklisted(event: SparkListenerNodeBlacklisted): Unit = {
+ logEvent(event, flushLogger = true)
+ }
+
+ override def onNodeUnblacklisted(event: SparkListenerNodeUnblacklisted): Unit = {
+ logEvent(event, flushLogger = true)
+ }
+
// No-op because logging every update would be overkill
override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {}
// No-op because logging every update would be overkill
override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { }
+ override def onOtherEvent(event: SparkListenerEvent): Unit = {
+ if (event.logEvent) {
+ logEvent(event, flushLogger = true)
+ }
+ }
+
/**
* Stop logging events. The event log file will be renamed so that it loses the
* ".inprogress" suffix.
@@ -226,6 +241,28 @@ private[spark] class EventLoggingListener(
}
}
fileSystem.rename(new Path(logPath + IN_PROGRESS), target)
+ // touch file to ensure modtime is current across those filesystems where rename()
+ // does not set it, and which support setTimes(); it's a no-op on most object stores
+ try {
+ fileSystem.setTimes(target, System.currentTimeMillis(), -1)
+ } catch {
+ case e: Exception => logDebug(s"failed to set time of $target", e)
+ }
+ }
+
+ private[spark] def redactEvent(
+ event: SparkListenerEnvironmentUpdate): SparkListenerEnvironmentUpdate = {
+ // environmentDetails maps a string descriptor to a set of properties
+ // Similar to:
+ // "JVM Information" -> jvmInformation,
+ // "Spark Properties" -> sparkProperties,
+ // ...
+ // where jvmInformation, sparkProperties, etc. are sequences of tuples.
+ // We go through the various properties and redact sensitive information from them.
+ val redactedProps = event.environmentDetails.map{ case (name, props) =>
+ name -> Utils.redact(sparkConf, props)
+ }
+ SparkListenerEnvironmentUpdate(redactedProps)
}
}
@@ -234,8 +271,6 @@ private[spark] object EventLoggingListener extends Logging {
// Suffix applied to the names of files still being written by applications.
val IN_PROGRESS = ".inprogress"
val DEFAULT_LOG_DIR = "/tmp/spark-events"
- val SPARK_VERSION_KEY = "SPARK_VERSION"
- val COMPRESSION_CODEC_KEY = "COMPRESSION_CODEC"
private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
@@ -251,7 +286,7 @@ private[spark] object EventLoggingListener extends Logging {
def initEventLog(logStream: OutputStream): Unit = {
val metadata = SparkListenerLogStart(SPARK_VERSION)
val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n"
- logStream.write(metadataJson.getBytes(Charsets.UTF_8))
+ logStream.write(metadataJson.getBytes(StandardCharsets.UTF_8))
}
/**
@@ -288,7 +323,7 @@ private[spark] object EventLoggingListener extends Logging {
}
private def sanitize(str: String): String = {
- str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase
+ str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase(Locale.ROOT)
}
/**
@@ -297,12 +332,6 @@ private[spark] object EventLoggingListener extends Logging {
* @return input stream that holds one JSON record per line.
*/
def openEventLog(log: Path, fs: FileSystem): InputStream = {
- // It's not clear whether FileSystem.open() throws FileNotFoundException or just plain
- // IOException when a file does not exist, so try our best to throw a proper exception.
- if (!fs.exists(log)) {
- throw new FileNotFoundException(s"File $log does not exist.")
- }
-
val in = new BufferedInputStream(fs.open(log))
// Compression codec is encoded as an extension, e.g. app_123.lzf
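
The new redactEvent() above walks environmentDetails and passes each property list through Utils.redact() before the environment update is written to the event log. A self-contained sketch of the same idea is below; the key pattern and replacement text are illustrative stand-ins, not Spark's actual redaction configuration:

```scala
// Standalone sketch of the redaction idea: mask values whose keys look
// sensitive before the environment update is logged.
object RedactionSketch {
  private val sensitiveKey = "(?i)secret|password|token".r

  def redact(props: Seq[(String, String)]): Seq[(String, String)] =
    props.map { case (k, v) =>
      if (sensitiveKey.findFirstIn(k).isDefined) (k, "*********(redacted)") else (k, v)
    }

  def main(args: Array[String]): Unit = {
    // environmentDetails-like shape: descriptor -> sequence of key/value tuples
    val env = Map(
      "Spark Properties" -> Seq(
        "spark.ssl.keyPassword" -> "hunter2",
        "spark.app.name" -> "demo"))
    val redacted = env.map { case (name, props) => name -> redact(props) }
    redacted.foreach(println)
  }
}
```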
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala
new file mode 100644
index 000000000000..70553d8be28b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler
+
+import scala.collection.mutable.HashMap
+
+/**
+ * Small helper for tracking failed tasks for blacklisting purposes. Info on all failures on one
+ * executor, within one task set.
+ */
+private[scheduler] class ExecutorFailuresInTaskSet(val node: String) {
+ /**
+ * Mapping from the index of a task in the taskset to the number of times it has failed on this
+ * executor, along with the most recent failure time.
+ */
+ val taskToFailureCountAndFailureTime = HashMap[Int, (Int, Long)]()
+
+ def updateWithFailure(taskIndex: Int, failureTime: Long): Unit = {
+ val (prevFailureCount, prevFailureTime) =
+ taskToFailureCountAndFailureTime.getOrElse(taskIndex, (0, -1L))
+ // these times always come from the driver, so we don't need to worry about skew, but might
+ // as well still be defensive in case there is non-monotonicity in the clock
+ val newFailureTime = math.max(prevFailureTime, failureTime)
+ taskToFailureCountAndFailureTime(taskIndex) = (prevFailureCount + 1, newFailureTime)
+ }
+
+ def numUniqueTasksWithFailures: Int = taskToFailureCountAndFailureTime.size
+
+ /**
+ * Return the number of times this executor has failed on the given task index.
+ */
+ def getNumTaskFailures(index: Int): Int = {
+ taskToFailureCountAndFailureTime.getOrElse(index, (0, 0))._1
+ }
+
+ override def toString(): String = {
+ s"numUniqueTasksWithFailures = $numUniqueTasksWithFailures; " +
+ s"tasksToFailureCount = $taskToFailureCountAndFailureTime"
+ }
+}
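
A brief usage sketch for the helper added above. It is private[scheduler], so this only compiles inside that package; it is shown purely to illustrate the per-executor, per-taskset bookkeeping the blacklist tracking builds on:

```scala
package org.apache.spark.scheduler

// Illustrative only: exercises the failure bookkeeping for a single executor.
private[scheduler] object ExecutorFailuresSketch {
  def main(args: Array[String]): Unit = {
    val failures = new ExecutorFailuresInTaskSet(node = "host-1")
    failures.updateWithFailure(taskIndex = 0, failureTime = 1000L)
    failures.updateWithFailure(taskIndex = 0, failureTime = 900L) // count bumps, time keeps the max
    failures.updateWithFailure(taskIndex = 1, failureTime = 1100L)
    assert(failures.getNumTaskFailures(0) == 2)
    assert(failures.getNumTaskFailures(2) == 0)                   // never failed -> 0
    assert(failures.numUniqueTasksWithFailures == 2)
    println(failures)
  }
}
```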
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
index 47a5cbff4930..46a35b6a2eaf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
@@ -20,7 +20,7 @@ package org.apache.spark.scheduler
import org.apache.spark.executor.ExecutorExitCode
/**
- * Represents an explanation for a executor or whole slave failing or exiting.
+ * Represents an explanation for an executor or whole slave failing or exiting.
*/
private[spark]
class ExecutorLossReason(val message: String) extends Serializable {
@@ -40,6 +40,8 @@ private[spark] object ExecutorExited {
}
}
+private[spark] object ExecutorKilled extends ExecutorLossReason("Executor killed by driver.")
+
/**
* A loss reason that means we don't yet know why the executor exited.
*
@@ -49,6 +51,10 @@ private[spark] object ExecutorExited {
*/
private [spark] object LossReasonPending extends ExecutorLossReason("Pending loss reason.")
+/**
+ * @param _message human readable loss reason
+ * @param workerLost whether the worker is confirmed lost too (i.e. including shuffle service)
+ */
private[spark]
-case class SlaveLost(_message: String = "Slave lost")
+case class SlaveLost(_message: String = "Slave lost", workerLost: Boolean = false)
extends ExecutorLossReason(_message)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala
new file mode 100644
index 000000000000..47f3527a32c0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.SparkContext
+
+/**
+ * A cluster manager interface to plug in an external scheduler.
+ */
+private[spark] trait ExternalClusterManager {
+
+ /**
+ * Check if this cluster manager instance can create scheduler components
+ * for a certain master URL.
+ * @param masterURL the master URL
+ * @return True if the cluster manager can create scheduler components for the given master URL.
+ */
+ def canCreate(masterURL: String): Boolean
+
+ /**
+ * Create a task scheduler instance for the given SparkContext
+ * @param sc SparkContext
+ * @param masterURL the master URL
+ * @return TaskScheduler that will be responsible for task handling
+ */
+ def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler
+
+ /**
+ * Create a scheduler backend for the given SparkContext and scheduler. This is
+ * called after the task scheduler is created using `ExternalClusterManager.createTaskScheduler()`.
+ * @param sc SparkContext
+ * @param masterURL the master URL
+ * @param scheduler TaskScheduler that will be used with the scheduler backend.
+ * @return SchedulerBackend that works with a TaskScheduler
+ */
+ def createSchedulerBackend(sc: SparkContext,
+ masterURL: String,
+ scheduler: TaskScheduler): SchedulerBackend
+
+ /**
+ * Initialize the task scheduler and scheduler backend. This is called after the
+ * scheduler components are created.
+ * @param scheduler TaskScheduler that will be responsible for task handling
+ * @param backend SchedulerBackend that works with a TaskScheduler
+ */
+ def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit
+}
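
For concreteness, a skeleton of what an implementation of this trait might look like. The trait is private[spark] at this point, so the sketch assumes it lives under org.apache.spark.scheduler; the "dummy" master URL, the class names, and the do-nothing backend are all made up, and registration/discovery of the plug-in is out of scope here. Treat it as a shape sketch, not a working cluster manager:

```scala
package org.apache.spark.scheduler

import org.apache.spark.SparkContext

// Shape sketch only: a minimal ExternalClusterManager for a made-up "dummy"
// master URL. The backend below never launches executors; a real cluster
// manager would manage executor lifecycles in its SchedulerBackend.
private[spark] class DummyExternalClusterManager extends ExternalClusterManager {

  override def canCreate(masterURL: String): Boolean = masterURL == "dummy"

  override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler =
    new TaskSchedulerImpl(sc)

  override def createSchedulerBackend(
      sc: SparkContext,
      masterURL: String,
      scheduler: TaskScheduler): SchedulerBackend =
    new DummySchedulerBackend

  override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit =
    scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend)
}

// Minimal SchedulerBackend that satisfies the trait but never runs anything.
private[spark] class DummySchedulerBackend extends SchedulerBackend {
  override def start(): Unit = {}
  override def stop(): Unit = {}
  override def reviveOffers(): Unit = {}
  override def defaultParallelism(): Int = 1
}
```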
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 0e438ab4366d..66ab9a52b778 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -26,9 +26,9 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.util.ReflectionUtils
-import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
/**
* :: DeveloperApi ::
@@ -57,11 +57,10 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// Since we are not doing canonicalization of path, this can be wrong : like relative vs
// absolute path .. which is fine, this is best case effort to remove duplicates - right ?
override def equals(other: Any): Boolean = other match {
- case that: InputFormatInfo => {
+ case that: InputFormatInfo =>
// not checking config - that should be fine, right ?
this.inputFormatClazz == that.inputFormatClazz &&
this.path == that.path
- }
case _ => false
}
@@ -86,10 +85,9 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
}
}
catch {
- case e: ClassNotFoundException => {
+ case e: ClassNotFoundException =>
throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz +
" cannot be found ?", e)
- }
}
}
@@ -103,7 +101,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[
org.apache.hadoop.mapreduce.InputFormat[_, _]]
- val job = new Job(conf)
+ val job = Job.getInstance(conf)
val retval = new ArrayBuffer[SplitInfo]()
val list = instance.getSplits(job)
@@ -155,9 +153,9 @@ object InputFormatInfo {
a) For each host, count number of splits hosted on that host.
b) Decrement the currently allocated containers on that host.
- c) Compute rack info for each host and update rack -> count map based on (b).
+ c) Compute rack info for each host and update rack to count map based on (b).
d) Allocate nodes based on (c)
- e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node
+ e) On the allocation result, ensure that we don't allocate "too many" jobs on a single node
(even if data locality on that is very high) : this is to prevent fragility of job if a
single (or small set of) hosts go down.
@@ -173,7 +171,7 @@ object InputFormatInfo {
for (inputSplit <- formats) {
val splits = inputSplit.findPreferredLocations()
- for (split <- splits){
+ for (split <- splits) {
val location = split.hostLocation
val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo])
set += split
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
index 50c2b9acd609..e0f7c8f02132 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
@@ -23,6 +23,6 @@ package org.apache.spark.scheduler
* job fails (and no further taskSucceeded events will happen).
*/
private[spark] trait JobListener {
- def taskSucceeded(index: Int, result: Any)
- def jobFailed(exception: Exception)
+ def taskSucceeded(index: Int, result: Any): Unit
+ def jobFailed(exception: Exception): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
deleted file mode 100644
index f96eb8ca0ae0..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ /dev/null
@@ -1,277 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import java.io.{File, FileNotFoundException, IOException, PrintWriter}
-import java.text.SimpleDateFormat
-import java.util.{Date, Properties}
-
-import scala.collection.mutable.HashMap
-
-import org.apache.spark._
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.TaskMetrics
-
-/**
- * :: DeveloperApi ::
- * A logger class to record runtime information for jobs in Spark. This class outputs one log file
- * for each Spark job, containing tasks start/stop and shuffle information. JobLogger is a subclass
- * of SparkListener, use addSparkListener to add JobLogger to a SparkContext after the SparkContext
- * is created. Note that each JobLogger only works for one SparkContext
- *
- * NOTE: The functionality of this class is heavily stripped down to accommodate for a general
- * refactor of the SparkListener interface. In its place, the EventLoggingListener is introduced
- * to log application information as SparkListenerEvents. To enable this functionality, set
- * spark.eventLog.enabled to true.
- */
-@DeveloperApi
-@deprecated("Log application information by setting spark.eventLog.enabled.", "1.0.0")
-class JobLogger(val user: String, val logDirName: String) extends SparkListener with Logging {
-
- def this() = this(System.getProperty("user.name", ""),
- String.valueOf(System.currentTimeMillis()))
-
- private val logDir =
- if (System.getenv("SPARK_LOG_DIR") != null) {
- System.getenv("SPARK_LOG_DIR")
- } else {
- "/tmp/spark-%s".format(user)
- }
-
- private val jobIdToPrintWriter = new HashMap[Int, PrintWriter]
- private val stageIdToJobId = new HashMap[Int, Int]
- private val jobIdToStageIds = new HashMap[Int, Seq[Int]]
- private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
- override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
- }
-
- createLogDir()
-
- /** Create a folder for log files, the folder's name is the creation time of jobLogger */
- protected def createLogDir() {
- val dir = new File(logDir + "/" + logDirName + "/")
- if (dir.exists()) {
- return
- }
- if (!dir.mkdirs()) {
- // JobLogger should throw a exception rather than continue to construct this object.
- throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/")
- }
- }
-
- /**
- * Create a log file for one job
- * @param jobId ID of the job
- * @throws FileNotFoundException Fail to create log file
- */
- protected def createLogWriter(jobId: Int) {
- try {
- val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobId)
- jobIdToPrintWriter += (jobId -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
- }
-
- /**
- * Close log file, and clean the stage relationship in stageIdToJobId
- * @param jobId ID of the job
- */
- protected def closeLogWriter(jobId: Int) {
- jobIdToPrintWriter.get(jobId).foreach { fileWriter =>
- fileWriter.close()
- jobIdToStageIds.get(jobId).foreach(_.foreach { stageId =>
- stageIdToJobId -= stageId
- })
- jobIdToPrintWriter -= jobId
- jobIdToStageIds -= jobId
- }
- }
-
- /**
- * Build up the maps that represent stage-job relationships
- * @param jobId ID of the job
- * @param stageIds IDs of the associated stages
- */
- protected def buildJobStageDependencies(jobId: Int, stageIds: Seq[Int]) = {
- jobIdToStageIds(jobId) = stageIds
- stageIds.foreach { stageId => stageIdToJobId(stageId) = jobId }
- }
-
- /**
- * Write info into log file
- * @param jobId ID of the job
- * @param info Info to be recorded
- * @param withTime Controls whether to record time stamp before the info, default is true
- */
- protected def jobLogInfo(jobId: Int, info: String, withTime: Boolean = true) {
- var writeInfo = info
- if (withTime) {
- val date = new Date(System.currentTimeMillis())
- writeInfo = dateFormat.get.format(date) + ": " + info
- }
- // scalastyle:off println
- jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo))
- // scalastyle:on println
- }
-
- /**
- * Write info into log file
- * @param stageId ID of the stage
- * @param info Info to be recorded
- * @param withTime Controls whether to record time stamp before the info, default is true
- */
- protected def stageLogInfo(stageId: Int, info: String, withTime: Boolean = true) {
- stageIdToJobId.get(stageId).foreach(jobId => jobLogInfo(jobId, info, withTime))
- }
-
- /**
- * Record task metrics into job log files, including execution info and shuffle metrics
- * @param stageId Stage ID of the task
- * @param status Status info of the task
- * @param taskInfo Task description info
- * @param taskMetrics Task running metrics
- */
- protected def recordTaskMetrics(stageId: Int, status: String,
- taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
- val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageId +
- " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
- " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
- val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val gcTime = " GC_TIME=" + taskMetrics.jvmGCTime
- val inputMetrics = taskMetrics.inputMetrics match {
- case Some(metrics) =>
- " READ_METHOD=" + metrics.readMethod.toString +
- " INPUT_BYTES=" + metrics.bytesRead
- case None => ""
- }
- val outputMetrics = taskMetrics.outputMetrics match {
- case Some(metrics) =>
- " OUTPUT_BYTES=" + metrics.bytesWritten
- case None => ""
- }
- val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead +
- " LOCAL_BYTES_READ=" + metrics.localBytesRead
- case None => ""
- }
- val writeMetrics = taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) =>
- " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten +
- " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime
- case None => ""
- }
- stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics +
- shuffleReadMetrics + writeMetrics)
- }
-
- /**
- * When stage is submitted, record stage submit info
- * @param stageSubmitted Stage submitted event
- */
- override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- val stageInfo = stageSubmitted.stageInfo
- stageLogInfo(stageInfo.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageInfo.stageId, stageInfo.numTasks))
- }
-
- /**
- * When stage is completed, record stage completion status
- * @param stageCompleted Stage completed event
- */
- override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
- val stageId = stageCompleted.stageInfo.stageId
- if (stageCompleted.stageInfo.failureReason.isEmpty) {
- stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=COMPLETED")
- } else {
- stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=FAILED")
- }
- }
-
- /**
- * When task ends, record task completion status and metrics
- * @param taskEnd Task end event
- */
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val taskInfo = taskEnd.taskInfo
- var taskStatus = "TASK_TYPE=%s".format(taskEnd.taskType)
- val taskMetrics = if (taskEnd.taskMetrics != null) taskEnd.taskMetrics else TaskMetrics.empty
- taskEnd.reason match {
- case Success => taskStatus += " STATUS=SUCCESS"
- recordTaskMetrics(taskEnd.stageId, taskStatus, taskInfo, taskMetrics)
- case Resubmitted =>
- taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
- " STAGE_ID=" + taskEnd.stageId
- stageLogInfo(taskEnd.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) =>
- taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
- taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
- mapId + " REDUCE_ID=" + reduceId
- stageLogInfo(taskEnd.stageId, taskStatus)
- case _ =>
- }
- }
-
- /**
- * When job ends, recording job completion status and close log file
- * @param jobEnd Job end event
- */
- override def onJobEnd(jobEnd: SparkListenerJobEnd) {
- val jobId = jobEnd.jobId
- var info = "JOB_ID=" + jobId
- jobEnd.jobResult match {
- case JobSucceeded => info += " STATUS=SUCCESS"
- case JobFailed(exception) =>
- info += " STATUS=FAILED REASON="
- exception.getMessage.split("\\s+").foreach(info += _ + "_")
- case _ =>
- }
- jobLogInfo(jobId, info.substring(0, info.length - 1).toUpperCase)
- closeLogWriter(jobId)
- }
-
- /**
- * Record job properties into job log file
- * @param jobId ID of the job
- * @param properties Properties of the job
- */
- protected def recordJobProperties(jobId: Int, properties: Properties) {
- if (properties != null) {
- val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
- jobLogInfo(jobId, description, withTime = false)
- }
- }
-
- /**
- * When job starts, record job property and stage graph
- * @param jobStart Job start event
- */
- override def onJobStart(jobStart: SparkListenerJobStart) {
- val jobId = jobStart.jobId
- val properties = jobStart.properties
- createLogWriter(jobId)
- recordJobProperties(jobId, properties)
- buildJobStageDependencies(jobId, jobStart.stageIds)
- jobLogInfo(jobId, "JOB_ID=" + jobId + " STATUS=STARTED")
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
index 4cd6cbe189aa..4a304a078d65 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
@@ -29,5 +29,4 @@ sealed trait JobResult
@DeveloperApi
case object JobSucceeded extends JobResult
-@DeveloperApi
private[spark] case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 382b09422a4a..65d7184231e2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,6 +17,12 @@
package org.apache.spark.scheduler
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.{Future, Promise}
+
+import org.apache.spark.internal.Logging
+
/**
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
@@ -26,19 +32,17 @@ private[spark] class JobWaiter[T](
val jobId: Int,
totalTasks: Int,
resultHandler: (Int, T) => Unit)
- extends JobListener {
-
- private var finishedTasks = 0
-
- // Is the job as a whole finished (succeeded or failed)?
- @volatile
- private var _jobFinished = totalTasks == 0
-
- def jobFinished: Boolean = _jobFinished
+ extends JobListener with Logging {
+ private val finishedTasks = new AtomicInteger(0)
// If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
// partition RDDs), we set the jobResult directly to JobSucceeded.
- private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
+ private val jobPromise: Promise[Unit] =
+ if (totalTasks == 0) Promise.successful(()) else Promise()
+
+ def jobFinished: Boolean = jobPromise.isCompleted
+
+ def completionFuture: Future[Unit] = jobPromise.future
/**
* Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
@@ -46,32 +50,23 @@ private[spark] class JobWaiter[T](
* will fail this job with a SparkException.
*/
def cancel() {
- dagScheduler.cancelJob(jobId)
+ dagScheduler.cancelJob(jobId, None)
}
- override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
- if (_jobFinished) {
- throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+ override def taskSucceeded(index: Int, result: Any): Unit = {
+ // resultHandler call must be synchronized in case resultHandler itself is not thread safe.
+ synchronized {
+ resultHandler(index, result.asInstanceOf[T])
}
- resultHandler(index, result.asInstanceOf[T])
- finishedTasks += 1
- if (finishedTasks == totalTasks) {
- _jobFinished = true
- jobResult = JobSucceeded
- this.notifyAll()
+ if (finishedTasks.incrementAndGet() == totalTasks) {
+ jobPromise.success(())
}
}
- override def jobFailed(exception: Exception): Unit = synchronized {
- _jobFinished = true
- jobResult = JobFailed(exception)
- this.notifyAll()
- }
-
- def awaitResult(): JobResult = synchronized {
- while (!_jobFinished) {
- this.wait()
+ override def jobFailed(exception: Exception): Unit = {
+ if (!jobPromise.tryFailure(exception)) {
+ logWarning("Ignore failure", exception)
}
- return jobResult
}
+
}
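
The rewrite above replaces JobWaiter's wait/notify bookkeeping with an AtomicInteger task counter and a Promise that is completed exactly once (success on the last task, tryFailure on job failure), which is what makes `completionFuture` possible. A minimal, self-contained sketch of the same pattern; `SimpleWaiter` and `awaitCompletion` are hypothetical names, not Spark API:

import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

// Minimal sketch of the Promise-based completion pattern; not Spark API.
class SimpleWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) {
  private val finishedTasks = new AtomicInteger(0)
  private val promise: Promise[Unit] =
    if (totalTasks == 0) Promise.successful(()) else Promise()

  def taskSucceeded(index: Int, result: T): Unit = {
    // The handler may not be thread safe, so calls to it are serialized.
    synchronized { resultHandler(index, result) }
    if (finishedTasks.incrementAndGet() == totalTasks) {
      promise.success(())  // the last task completes the job exactly once
    }
  }

  // tryFailure makes repeated or late failure notifications harmless.
  def jobFailed(exception: Exception): Unit = promise.tryFailure(exception)

  def awaitCompletion(): Unit = Await.result(promise.future, Duration.Inf)
}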
diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
index be23056e7d42..73e9141d344b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -17,24 +17,198 @@
package org.apache.spark.scheduler
-import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent._
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
-import org.apache.spark.util.AsynchronousListenerBus
+import scala.util.DynamicVariable
+
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils
/**
* Asynchronously passes SparkListenerEvents to registered SparkListeners.
*
- * Until start() is called, all posted events are only buffered. Only after this listener bus
+ * Until `start()` is called, all posted events are only buffered. Only after this listener bus
* has started will events be actually propagated to all attached listeners. This listener bus
- * is stopped when it receives a SparkListenerShutdown event, which is posted using stop().
+ * is stopped when `stop()` is called, and it will drop further events after stopping.
*/
-private[spark] class LiveListenerBus
- extends AsynchronousListenerBus[SparkListener, SparkListenerEvent]("SparkListenerBus")
- with SparkListenerBus {
+private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus {
+
+ self =>
+
+ import LiveListenerBus._
+
+ // Cap the capacity of the event queue so we get an explicit error (rather than
+ // an OOM exception) if it's perpetually being added to more quickly than it's being drained.
+ private lazy val EVENT_QUEUE_CAPACITY = validateAndGetQueueSize()
+ private lazy val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY)
+
+ private def validateAndGetQueueSize(): Int = {
+ val queueSize = sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_SIZE)
+ if (queueSize <= 0) {
+ throw new SparkException("spark.scheduler.listenerbus.eventqueue.size must be > 0!")
+ }
+ queueSize
+ }
+
+ // Indicate if `start()` is called
+ private val started = new AtomicBoolean(false)
+ // Indicate if `stop()` is called
+ private val stopped = new AtomicBoolean(false)
+
+ /** A counter for dropped events. It will be reset every time we log it. */
+ private val droppedEventsCounter = new AtomicLong(0L)
+
+ /** When `droppedEventsCounter` was logged last time in milliseconds. */
+ @volatile private var lastReportTimestamp = 0L
+
+ // Indicate if we are processing some event
+ // Guarded by `self`
+ private var processingEvent = false
private val logDroppedEvent = new AtomicBoolean(false)
- override def onDropEvent(event: SparkListenerEvent): Unit = {
+ // A counter that represents the number of events produced and consumed in the queue
+ private val eventLock = new Semaphore(0)
+
+ private val listenerThread = new Thread(name) {
+ setDaemon(true)
+ override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) {
+ LiveListenerBus.withinListenerThread.withValue(true) {
+ while (true) {
+ eventLock.acquire()
+ self.synchronized {
+ processingEvent = true
+ }
+ try {
+ val event = eventQueue.poll
+ if (event == null) {
+ // Get out of the while loop and shutdown the daemon thread
+ if (!stopped.get) {
+ throw new IllegalStateException("Polling `null` from eventQueue means" +
+ " the listener bus has been stopped. So `stopped` must be true")
+ }
+ return
+ }
+ postToAll(event)
+ } finally {
+ self.synchronized {
+ processingEvent = false
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Start sending events to attached listeners.
+ *
+ * This first sends out all buffered events posted before this listener bus has started, then
+ * listens for any additional events asynchronously while the listener bus is still running.
+ * This should only be called once.
+ *
+ */
+ def start(): Unit = {
+ if (started.compareAndSet(false, true)) {
+ listenerThread.start()
+ } else {
+ throw new IllegalStateException(s"$name already started!")
+ }
+ }
+
+ def post(event: SparkListenerEvent): Unit = {
+ if (stopped.get) {
+ // Drop further events to make `listenerThread` exit ASAP
+ logDebug(s"$name has already stopped! Dropping event $event")
+ return
+ }
+ val eventAdded = eventQueue.offer(event)
+ if (eventAdded) {
+ eventLock.release()
+ } else {
+ onDropEvent(event)
+ droppedEventsCounter.incrementAndGet()
+ }
+
+ val droppedEvents = droppedEventsCounter.get
+ if (droppedEvents > 0) {
+ // Don't log too frequently
+ if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) {
+ // There may be multiple threads trying to decrease droppedEventsCounter.
+ // Use "compareAndSet" to make sure only one thread can win.
+ // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and
+ // then that thread will update it.
+ if (droppedEventsCounter.compareAndSet(droppedEvents, 0)) {
+ val prevLastReportTimestamp = lastReportTimestamp
+ lastReportTimestamp = System.currentTimeMillis()
+ logWarning(s"Dropped $droppedEvents SparkListenerEvents since " +
+ new java.util.Date(prevLastReportTimestamp))
+ }
+ }
+ }
+ }
+
+ /**
+ * For testing only. Wait until there are no more events in the queue, or until the specified
+ * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue
+ * emptied.
+ * Exposed for testing.
+ */
+ @throws(classOf[TimeoutException])
+ def waitUntilEmpty(timeoutMillis: Long): Unit = {
+ val finishTime = System.currentTimeMillis + timeoutMillis
+ while (!queueIsEmpty) {
+ if (System.currentTimeMillis > finishTime) {
+ throw new TimeoutException(
+ s"The event queue is not empty after $timeoutMillis milliseconds")
+ }
+ /* Sleep rather than using wait/notify, because this is used only for testing and
+ * wait/notify add overhead in the general case. */
+ Thread.sleep(10)
+ }
+ }
+
+ /**
+ * For testing only. Return whether the listener daemon thread is still alive.
+ * Exposed for testing.
+ */
+ def listenerThreadIsAlive: Boolean = listenerThread.isAlive
+
+ /**
+ * Return whether the event queue is empty.
+ *
+ * The use of synchronized here guarantees that all events that once belonged to this queue
+ * have already been processed by all attached listeners, if this returns true.
+ */
+ private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent }
+
+ /**
+ * Stop the listener bus. It will wait until the queued events have been processed, but drop the
+ * new events after stopping.
+ */
+ def stop(): Unit = {
+ if (!started.get()) {
+ throw new IllegalStateException(s"Attempted to stop $name that has not yet started!")
+ }
+ if (stopped.compareAndSet(false, true)) {
+ // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know
+ // `stop` is called.
+ eventLock.release()
+ listenerThread.join()
+ } else {
+ // Keep quiet
+ }
+ }
+
+ /**
+ * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be
+ * notified with the dropped events.
+ *
+ * Note: `onDropEvent` can be called in any thread.
+ */
+ def onDropEvent(event: SparkListenerEvent): Unit = {
if (logDroppedEvent.compareAndSet(false, true)) {
// Only log the following message once to avoid duplicated annoying logs.
logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
@@ -42,5 +216,13 @@ private[spark] class LiveListenerBus
"the rate at which tasks are being started by the scheduler.")
}
}
+}
+private[spark] object LiveListenerBus {
+ // Allows for Context to check whether stop() call is made within listener thread
+ val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
+
+ /** The thread name of Spark listener bus */
+ val name = "SparkListenerBus"
}
+
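
The new LiveListenerBus pairs a bounded LinkedBlockingQueue with a Semaphore: `post()` releases one permit per enqueued event, and `stop()` releases one extra permit so the consumer thread wakes up, polls `null`, and exits. A reduced sketch of that handshake; `MiniBus` is a hypothetical stand-in, not Spark API:

import java.util.concurrent.{LinkedBlockingQueue, Semaphore}

// Sketch of the bounded-queue-plus-semaphore handshake; not Spark API.
class MiniBus(capacity: Int, handle: String => Unit) {
  private val queue = new LinkedBlockingQueue[String](capacity)
  private val eventLock = new Semaphore(0)
  @volatile private var stopped = false

  private val consumer = new Thread("mini-bus") {
    setDaemon(true)
    override def run(): Unit = {
      while (true) {
        eventLock.acquire()        // wait for a posted event or the stop signal
        val event = queue.poll()
        if (event == null) return  // stop() released a permit with an empty queue
        handle(event)
      }
    }
  }
  consumer.start()

  def post(event: String): Boolean = {
    if (stopped) return false
    val added = queue.offer(event)  // non-blocking: a full queue means the event is dropped
    if (added) eventLock.release()
    added
  }

  def stop(): Unit = {
    stopped = true
    eventLock.release()             // wake the consumer so it can observe the empty queue
    consumer.join()
  }
}

Note that `post()` never blocks: when the queue is full the event is simply dropped, which is why the real bus counts dropped events and logs them at most once per minute.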
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 180c8d1827e1..2ec2f2031aa4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -19,8 +19,14 @@ package org.apache.spark.scheduler
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.config
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.collection.BitSet
import org.apache.spark.util.Utils
/**
@@ -120,35 +126,41 @@ private[spark] class CompressedMapStatus(
}
/**
- * A [[MapStatus]] implementation that only stores the average size of non-empty blocks,
- * plus a bitmap for tracking which blocks are empty. During serialization, this bitmap
- * is compressed.
+ * A [[MapStatus]] implementation that stores the accurate size of huge blocks, which are larger
+ * than spark.shuffle.accurateBlockThreshold. It stores the average size of other non-empty blocks,
+ * plus a bitmap for tracking which blocks are empty.
*
* @param loc location where the task is being executed
* @param numNonEmptyBlocks the number of non-empty blocks
* @param emptyBlocks a bitmap tracking which blocks are empty
- * @param avgSize average size of the non-empty blocks
+ * @param avgSize average size of the non-empty and non-huge blocks
+ * @param hugeBlockSizes sizes of huge blocks by their reduceId.
*/
private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
private[this] var numNonEmptyBlocks: Int,
- private[this] var emptyBlocks: BitSet,
- private[this] var avgSize: Long)
+ private[this] var emptyBlocks: RoaringBitmap,
+ private[this] var avgSize: Long,
+ private var hugeBlockSizes: Map[Int, Byte])
extends MapStatus with Externalizable {
// loc could be null when the default constructor is called during deserialization
- require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0,
+ require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1) // For deserialization only
+ protected def this() = this(null, -1, null, -1, null) // For deserialization only
override def location: BlockManagerId = loc
override def getSizeForBlock(reduceId: Int): Long = {
- if (emptyBlocks.get(reduceId)) {
+ assert(hugeBlockSizes != null)
+ if (emptyBlocks.contains(reduceId)) {
0
} else {
- avgSize
+ hugeBlockSizes.get(reduceId) match {
+ case Some(size) => MapStatus.decompressSize(size)
+ case None => avgSize
+ }
}
}
@@ -156,13 +168,26 @@ private[spark] class HighlyCompressedMapStatus private (
loc.writeExternal(out)
emptyBlocks.writeExternal(out)
out.writeLong(avgSize)
+ out.writeInt(hugeBlockSizes.size)
+ hugeBlockSizes.foreach { kv =>
+ out.writeInt(kv._1)
+ out.writeByte(kv._2)
+ }
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
loc = BlockManagerId(in)
- emptyBlocks = new BitSet
+ emptyBlocks = new RoaringBitmap()
emptyBlocks.readExternal(in)
avgSize = in.readLong()
+ val count = in.readInt()
+ val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]()
+ (0 until count).foreach { _ =>
+ val block = in.readInt()
+ val size = in.readByte()
+ hugeBlockSizesArray += Tuple2(block, size)
+ }
+ hugeBlockSizes = hugeBlockSizesArray.toMap
}
}
@@ -172,27 +197,42 @@ private[spark] object HighlyCompressedMapStatus {
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
var numNonEmptyBlocks: Int = 0
- var totalSize: Long = 0
+ var numSmallBlocks: Int = 0
+ var totalSmallBlockSize: Long = 0
// From a compression standpoint, it shouldn't matter whether we track empty or non-empty
// blocks. From a performance standpoint, we benefit from tracking empty blocks because
// we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
+ val emptyBlocks = new RoaringBitmap()
val totalNumBlocks = uncompressedSizes.length
- val emptyBlocks = new BitSet(totalNumBlocks)
+ val threshold = Option(SparkEnv.get)
+ .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD))
+ .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get)
+ val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]()
while (i < totalNumBlocks) {
- var size = uncompressedSizes(i)
+ val size = uncompressedSizes(i)
if (size > 0) {
numNonEmptyBlocks += 1
- totalSize += size
+ // Huge blocks are not included in the calculation for average size, thus size for smaller
+ // blocks is more accurate.
+ if (size < threshold) {
+ totalSmallBlockSize += size
+ numSmallBlocks += 1
+ } else {
+ hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i)))
+ }
} else {
- emptyBlocks.set(i)
+ emptyBlocks.add(i)
}
i += 1
}
- val avgSize = if (numNonEmptyBlocks > 0) {
- totalSize / numNonEmptyBlocks
+ val avgSize = if (numSmallBlocks > 0) {
+ totalSmallBlockSize / numSmallBlocks
} else {
0
}
- new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize)
+ emptyBlocks.trim()
+ emptyBlocks.runOptimize()
+ new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
+ hugeBlockSizesArray.toMap)
}
}
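
With this change, `getSizeForBlock` returns the recorded size for blocks at or above `spark.shuffle.accurateBlockThreshold` and the average of the remaining small, non-empty blocks otherwise. A simplified sketch of that split using plain collections instead of RoaringBitmap and skipping the byte-level size compression; names are illustrative:

object BlockSizeSketch {
  // Simplified version of the split above: blocks at or above the threshold keep an
  // individual size, the remaining non-empty blocks share one average.
  def summarize(uncompressedSizes: Array[Long], threshold: Long)
      : (Set[Int], Map[Int, Long], Long) = {
    val empty = uncompressedSizes.zipWithIndex.collect { case (0L, i) => i }.toSet
    val huge = uncompressedSizes.zipWithIndex
      .collect { case (size, i) if size >= threshold => i -> size }.toMap
    val small = uncompressedSizes.filter(size => size > 0 && size < threshold)
    val avgSize = if (small.nonEmpty) small.sum / small.length else 0L
    (empty, huge, avgSize)
  }

  def sizeForBlock(reduceId: Int, empty: Set[Int], huge: Map[Int, Long], avgSize: Long): Long =
    if (empty.contains(reduceId)) 0L else huge.getOrElse(reduceId, avgSize)
}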
diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
index 4d146678174f..b382d623806e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -20,12 +20,18 @@ package org.apache.spark.scheduler
import scala.collection.mutable
import org.apache.spark._
-import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.util.{RpcUtils, ThreadUtils}
private sealed trait OutputCommitCoordinationMessage extends Serializable
private case object StopCoordinator extends OutputCommitCoordinationMessage
-private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int)
+private case class AskPermissionToCommitOutput(
+ stage: Int,
+ stageAttempt: Int,
+ partition: Int,
+ attemptNumber: Int)
/**
* Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
@@ -43,28 +49,34 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
// Initialized by SparkEnv
var coordinatorRef: Option[RpcEndpointRef] = None
- private type StageId = Int
- private type PartitionId = Int
- private type TaskAttemptNumber = Int
+ // Class used to identify a committer. The task ID for a committer is implicitly defined by
+ // the partition being processed, but the coordinator needs to keep track of both the stage
+ // attempt and the task attempt, because in some situations the same task may be running
+ // concurrently in two different attempts of the same stage.
+ private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)
- private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1
+ private case class StageState(numPartitions: Int) {
+ val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null)
+ val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]()
+ }
/**
- * Map from active stages's id => partition id => task attempt with exclusive lock on committing
- * output for that partition.
+ * Map from an active stage's id => authorized task attempts for each partition id, which hold an
+ * exclusive lock on committing task output for that partition, as well as any known failed
+ * attempts in the stage.
*
* Entries are added to the top-level map when stages start and are removed when they finish
* (either successfully or unsuccessfully).
*
* Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
*/
- private val authorizedCommittersByStage = mutable.Map[StageId, Array[TaskAttemptNumber]]()
+ private val stageStates = mutable.Map[Int, StageState]()
/**
* Returns whether the OutputCommitCoordinator's internal data structures are all empty.
*/
def isEmpty: Boolean = {
- authorizedCommittersByStage.isEmpty
+ stageStates.isEmpty
}
/**
@@ -81,13 +93,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
* @return true if this task is authorized to commit, false otherwise
*/
def canCommit(
- stage: StageId,
- partition: PartitionId,
- attemptNumber: TaskAttemptNumber): Boolean = {
- val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber)
+ stage: Int,
+ stageAttempt: Int,
+ partition: Int,
+ attemptNumber: Int): Boolean = {
+ val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber)
coordinatorRef match {
case Some(endpointRef) =>
- endpointRef.askWithRetry[Boolean](msg)
+ ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg),
+ RpcUtils.askRpcTimeout(conf).duration)
case None =>
logError(
"canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?")
@@ -96,48 +110,54 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
}
/**
- * Called by the DAGScheduler when a stage starts.
+ * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't
+ * yet been initialized.
*
* @param stage the stage id.
* @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e.
* the maximum possible value of `context.partitionId`).
*/
- private[scheduler] def stageStart(
- stage: StageId,
- maxPartitionId: Int): Unit = {
- val arr = new Array[TaskAttemptNumber](maxPartitionId + 1)
- java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER)
- synchronized {
- authorizedCommittersByStage(stage) = arr
+ private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized {
+ stageStates.get(stage) match {
+ case Some(state) =>
+ require(state.authorizedCommitters.length == maxPartitionId + 1)
+ logInfo(s"Reusing state from previous attempt of stage $stage.")
+
+ case _ =>
+ stageStates(stage) = new StageState(maxPartitionId + 1)
}
}
// Called by DAGScheduler
- private[scheduler] def stageEnd(stage: StageId): Unit = synchronized {
- authorizedCommittersByStage.remove(stage)
+ private[scheduler] def stageEnd(stage: Int): Unit = synchronized {
+ stageStates.remove(stage)
}
// Called by DAGScheduler
private[scheduler] def taskCompleted(
- stage: StageId,
- partition: PartitionId,
- attemptNumber: TaskAttemptNumber,
+ stage: Int,
+ stageAttempt: Int,
+ partition: Int,
+ attemptNumber: Int,
reason: TaskEndReason): Unit = synchronized {
- val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, {
+ val stageState = stageStates.getOrElse(stage, {
logDebug(s"Ignoring task completion for completed stage")
return
})
reason match {
case Success =>
// The task output has been committed successfully
- case denied: TaskCommitDenied =>
- logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " +
- s"attempt: $attemptNumber")
- case otherReason =>
- if (authorizedCommitters(partition) == attemptNumber) {
+ case _: TaskCommitDenied =>
+ logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " +
+ s"partition: $partition, attempt: $attemptNumber")
+ case _ =>
+ // Mark the attempt as failed to blacklist from future commit protocol
+ val taskId = TaskIdentifier(stageAttempt, attemptNumber)
+ stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId
+ if (stageState.authorizedCommitters(partition) == taskId) {
logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " +
s"partition=$partition) failed; clearing lock")
- authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
+ stageState.authorizedCommitters(partition) = null
}
}
}
@@ -146,34 +166,48 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
if (isDriver) {
coordinatorRef.foreach(_ send StopCoordinator)
coordinatorRef = None
- authorizedCommittersByStage.clear()
+ stageStates.clear()
}
}
// Marked private[scheduler] instead of private so this can be mocked in tests
private[scheduler] def handleAskPermissionToCommit(
- stage: StageId,
- partition: PartitionId,
- attemptNumber: TaskAttemptNumber): Boolean = synchronized {
- authorizedCommittersByStage.get(stage) match {
- case Some(authorizedCommitters) =>
- authorizedCommitters(partition) match {
- case NO_AUTHORIZED_COMMITTER =>
- logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " +
- s"partition=$partition")
- authorizedCommitters(partition) = attemptNumber
- true
- case existingCommitter =>
- logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " +
- s"partition=$partition; existingCommitter = $existingCommitter")
- false
+ stage: Int,
+ stageAttempt: Int,
+ partition: Int,
+ attemptNumber: Int): Boolean = synchronized {
+ stageStates.get(stage) match {
+ case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) =>
+ logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+ s"task attempt $attemptNumber already marked as failed.")
+ false
+ case Some(state) =>
+ val existing = state.authorizedCommitters(partition)
+ if (existing == null) {
+ logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " +
+ s"task attempt $attemptNumber")
+ state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber)
+ true
+ } else {
+ logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+ s"already committed by $existing")
+ false
}
case None =>
- logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" +
- s"partition $partition to commit")
+ logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+ "stage already marked as completed.")
false
}
}
+
+ private def attemptFailed(
+ stageState: StageState,
+ stageAttempt: Int,
+ partition: Int,
+ attempt: Int): Boolean = synchronized {
+ val failInfo = TaskIdentifier(stageAttempt, attempt)
+ stageState.failures.get(partition).exists(_.contains(failInfo))
+ }
}
private[spark] object OutputCommitCoordinator {
@@ -183,6 +217,8 @@ private[spark] object OutputCommitCoordinator {
override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator)
extends RpcEndpoint with Logging {
+ logDebug("init") // force eager creation of logger
+
override def receive: PartialFunction[Any, Unit] = {
case StopCoordinator =>
logInfo("OutputCommitCoordinator stopped!")
@@ -190,9 +226,10 @@ private[spark] object OutputCommitCoordinator {
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case AskPermissionToCommitOutput(stage, partition, attemptNumber) =>
+ case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) =>
context.reply(
- outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber))
+ outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition,
+ attemptNumber))
}
}
}
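
The coordinator now keys its first-committer-wins lock on both the stage attempt and the task attempt, and remembers failed attempts per partition so they can never be authorized later. A condensed, standalone sketch of that decision logic; `CommitArbiter` and `Attempt` are hypothetical names:

import scala.collection.mutable

// Condensed sketch of the first-committer-wins rule; not Spark API.
case class Attempt(stageAttempt: Int, taskAttempt: Int)

class CommitArbiter(numPartitions: Int) {
  private val authorized = Array.fill[Attempt](numPartitions)(null)
  private val failures = mutable.Map[Int, mutable.Set[Attempt]]()

  def attemptFailed(partition: Int, attempt: Attempt): Unit = synchronized {
    failures.getOrElseUpdate(partition, mutable.Set()) += attempt
    if (authorized(partition) == attempt) {
      authorized(partition) = null   // the failed committer releases the lock
    }
  }

  def canCommit(partition: Int, attempt: Attempt): Boolean = synchronized {
    if (failures.get(partition).exists(_.contains(attempt))) {
      false                          // a known-failed attempt is never authorized
    } else if (authorized(partition) == null) {
      authorized(partition) = attempt
      true                           // first committer wins
    } else {
      false                          // another attempt already holds the commit lock
    }
  }
}

Matching the change above, a request from any attempt other than the current lock holder is denied outright; the lock is only released when the authorized attempt is reported as failed.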
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 551e39a81b69..1181371ab425 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -22,39 +22,40 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
- * An Schedulable entity that represent collection of Pools or TaskSetManagers
+ * A Schedulable entity that represents a collection of Pools or TaskSetManagers
*/
-
private[spark] class Pool(
val poolName: String,
val schedulingMode: SchedulingMode,
initMinShare: Int,
initWeight: Int)
- extends Schedulable
- with Logging {
+ extends Schedulable with Logging {
val schedulableQueue = new ConcurrentLinkedQueue[Schedulable]
val schedulableNameToSchedulable = new ConcurrentHashMap[String, Schedulable]
- var weight = initWeight
- var minShare = initMinShare
+ val weight = initWeight
+ val minShare = initMinShare
var runningTasks = 0
- var priority = 0
+ val priority = 0
// A pool's stage id is used to break the tie in scheduling.
var stageId = -1
- var name = poolName
+ val name = poolName
var parent: Pool = null
- var taskSetSchedulingAlgorithm: SchedulingAlgorithm = {
+ private val taskSetSchedulingAlgorithm: SchedulingAlgorithm = {
schedulingMode match {
case SchedulingMode.FAIR =>
new FairSchedulingAlgorithm()
case SchedulingMode.FIFO =>
new FIFOSchedulingAlgorithm()
+ case _ =>
+ val msg = s"Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead."
+ throw new IllegalArgumentException(msg)
}
}
@@ -87,10 +88,10 @@ private[spark] class Pool(
schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason))
}
- override def checkSpeculatableTasks(): Boolean = {
+ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = {
var shouldRevive = false
for (schedulable <- schedulableQueue.asScala) {
- shouldRevive |= schedulable.checkSpeculatableTasks()
+ shouldRevive |= schedulable.checkSpeculatableTasks(minTimeToSpeculation)
}
shouldRevive
}
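
`checkSpeculatableTasks` now threads `minTimeToSpeculation` down the Schedulable tree, and a pool revives offers if any child does while still consulting every child. A tiny sketch of that aggregation; `Spec` and `SpecPool` are hypothetical names:

// Sketch of the recursive speculation check; not Spark API.
trait Spec {
  def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean
}

class SpecPool(children: Seq[Spec]) extends Spec {
  override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = {
    var shouldRevive = false
    for (child <- children) {
      // `|=` keeps visiting every child even after one has asked to revive offers.
      shouldRevive |= child.checkSpeculatableTasks(minTimeToSpeculation)
    }
    shouldRevive
  }
}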
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
index c6d957b65f3f..26a6a3effc9a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
@@ -17,14 +17,16 @@
package org.apache.spark.scheduler
-import java.io.{InputStream, IOException}
+import java.io.{EOFException, InputStream, IOException}
import scala.io.Source
import com.fasterxml.jackson.core.JsonParseException
+import com.fasterxml.jackson.databind.exc.UnrecognizedPropertyException
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.ReplayListenerBus._
import org.apache.spark.util.JsonProtocol
/**
@@ -43,32 +45,69 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging {
* @param sourceName Filename (or other source identifier) from whence @logData is being read
* @param maybeTruncated Indicates whether the log file might be truncated (in some abnormal
* situations the log file might not have finished writing cleanly)
+ * @param eventsFilter Filter function to select JSON event strings in the log data stream that
+ * should be parsed and replayed. When not specified, all event strings in the log data
+ * are parsed and replayed.
*/
def replay(
logData: InputStream,
sourceName: String,
- maybeTruncated: Boolean = false): Unit = {
+ maybeTruncated: Boolean = false,
+ eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = {
+ val lines = Source.fromInputStream(logData).getLines()
+ replay(lines, sourceName, maybeTruncated, eventsFilter)
+ }
+
+ /**
+ * Overloaded variant of [[replay()]] which accepts an iterator of lines instead of an
+ * [[InputStream]]. Exposed for use by custom ApplicationHistoryProvider implementations.
+ */
+ def replay(
+ lines: Iterator[String],
+ sourceName: String,
+ maybeTruncated: Boolean,
+ eventsFilter: ReplayEventsFilter): Unit = {
var currentLine: String = null
- var lineNumber: Int = 1
+ var lineNumber: Int = 0
+
try {
- val lines = Source.fromInputStream(logData).getLines()
- while (lines.hasNext) {
- currentLine = lines.next()
+ val lineEntries = lines
+ .zipWithIndex
+ .filter { case (line, _) => eventsFilter(line) }
+
+ while (lineEntries.hasNext) {
try {
+ val entry = lineEntries.next()
+
+ currentLine = entry._1
+ lineNumber = entry._2 + 1
+
postToAll(JsonProtocol.sparkEventFromJson(parse(currentLine)))
} catch {
+ case e: ClassNotFoundException if KNOWN_REMOVED_CLASSES.contains(e.getMessage) =>
+ // Ignore events generated by Structured Streaming in Spark 2.0.0 and 2.0.1.
+ // It's safe since no place uses them.
+ logWarning(s"Dropped incompatible Structured Streaming log: $currentLine")
+ case e: UnrecognizedPropertyException if e.getMessage != null && e.getMessage.startsWith(
+ "Unrecognized field \"queryStatus\" " +
+ "(class org.apache.spark.sql.streaming.StreamingQueryListener$") =>
+ // Ignore events generated by Structured Streaming in Spark 2.0.2
+ // It's safe since no place uses them.
+ logWarning(s"Dropped incompatible Structured Streaming log: $currentLine")
case jpe: JsonParseException =>
// We can only ignore exception from last line of the file that might be truncated
- if (!maybeTruncated || lines.hasNext) {
+ // the last entry may not be the very last line in the event log, but we treat it
+ // as such in a best effort to replay the given input
+ if (!maybeTruncated || lineEntries.hasNext) {
throw jpe
} else {
logWarning(s"Got JsonParseException from log file $sourceName" +
s" at line $lineNumber, the file might not have finished writing cleanly.")
}
}
- lineNumber += 1
}
} catch {
+ case _: EOFException if maybeTruncated =>
case ioe: IOException =>
throw ioe
case e: Exception =>
@@ -78,3 +117,21 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging {
}
}
+
+
+private[spark] object ReplayListenerBus {
+
+ type ReplayEventsFilter = (String) => Boolean
+
+ // utility filter that selects all event logs during replay
+ val SELECT_ALL_FILTER: ReplayEventsFilter = { (eventString: String) => true }
+
+ /**
+ * Classes that were removed. Structured Streaming doesn't use them any more. However, parsing
+ * old json may fail and we can just ignore these failures.
+ */
+ val KNOWN_REMOVED_CLASSES = Set(
+ "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress",
+ "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated"
+ )
+}
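
`replay()` now accepts a ReplayEventsFilter, a plain `String => Boolean` applied to each raw JSON line before parsing, while `zipWithIndex` preserves the original line number for error reporting even when lines are skipped. A standalone sketch of that filter-then-parse loop; `replayLines` and the example filter are illustrative, not Spark API:

object ReplaySketch {
  // Illustrative filter: keep only task-end events, judged on the raw JSON string.
  val taskEndOnly: String => Boolean = _.contains("SparkListenerTaskEnd")

  // Sketch of the filter-then-parse loop; `parse` stands in for JSON decoding plus postToAll.
  def replayLines(
      lines: Iterator[String],
      filter: String => Boolean)(parse: String => Unit): Unit = {
    val entries = lines.zipWithIndex.filter { case (line, _) => filter(line) }
    while (entries.hasNext) {
      val (line, index) = entries.next()
      try {
        parse(line)
      } catch {
        case e: Exception =>
          // `index` is the position in the original input, so the message stays accurate
          // even though filtered-out lines were never parsed.
          println(s"Failed to parse line ${index + 1}: ${e.getMessage}")
          throw e
      }
    }
  }
}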
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index fb693721a9cb..e36c759a4255 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -17,9 +17,10 @@
package org.apache.spark.scheduler
-import java.nio.ByteBuffer
-
import java.io._
+import java.lang.management.ManagementFactory
+import java.nio.ByteBuffer
+import java.util.Properties
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
@@ -31,6 +32,7 @@ import org.apache.spark.rdd.RDD
* See [[Task]] for more information.
*
* @param stageId id of the stage this task belongs to
+ * @param stageAttemptId attempt id of the stage this task belongs to
* @param taskBinary broadcasted version of the serialized RDD and the function to apply on each
* partition of the given RDD. Once deserialized, the type should be
* (RDD[T], (TaskContext, Iterator[T]) => U).
@@ -38,7 +40,15 @@ import org.apache.spark.rdd.RDD
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
- */
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
+ * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side
+ * and sent to executor side.
+ *
+ * The parameters below are optional:
+ * @param jobId id of the job this task belongs to
+ * @param appId id of the app this task belongs to
+ * @param appAttemptId attempt id of the app this task belongs to
+ */
private[spark] class ResultTask[T, U](
stageId: Int,
stageAttemptId: Int,
@@ -46,8 +56,13 @@ private[spark] class ResultTask[T, U](
partition: Partition,
locs: Seq[TaskLocation],
val outputId: Int,
- internalAccumulators: Seq[Accumulator[Long]])
- extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators)
+ localProperties: Properties,
+ serializedTaskMetrics: Array[Byte],
+ jobId: Option[Int] = None,
+ appId: Option[String] = None,
+ appAttemptId: Option[String] = None)
+ extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics,
+ jobId, appId, appAttemptId)
with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
@@ -56,13 +71,19 @@ private[spark] class ResultTask[T, U](
override def runTask(context: TaskContext): U = {
// Deserialize the RDD and the func using the broadcast variables.
+ val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
+ val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
+ threadMXBean.getCurrentThreadCpuTime
+ } else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
+ _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
+ threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
+ } else 0L
- metrics = Some(context.taskMetrics)
func(context, rdd.iterator(partition, context))
}
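
Both task types now sample `ThreadMXBean.getCurrentThreadCpuTime` around deserialization, guarded by `isCurrentThreadCpuTimeSupported`, and report the delta next to the wall-clock time. A standalone sketch of that measurement pattern; `timed` is a hypothetical helper, not Spark API:

import java.lang.management.ManagementFactory

object TimingSketch {
  // Returns (result, wallClockMillis, cpuNanos); cpuNanos is 0 when the JVM cannot
  // report per-thread CPU time.
  def timed[T](work: => T): (T, Long, Long) = {
    val bean = ManagementFactory.getThreadMXBean
    val startWall = System.currentTimeMillis()
    val startCpu =
      if (bean.isCurrentThreadCpuTimeSupported) bean.getCurrentThreadCpuTime else 0L
    val result = work
    val wall = System.currentTimeMillis() - startWall
    val cpu =
      if (bean.isCurrentThreadCpuTimeSupported) bean.getCurrentThreadCpuTime - startCpu else 0L
    (result, wall, cpu)
  }
}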
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
index ab00bc8f0bf4..b6f88ed0a93a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
@@ -43,6 +43,6 @@ private[spark] trait Schedulable {
def removeSchedulable(schedulable: Schedulable): Unit
def getSchedulableByName(name: String): Schedulable
def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit
- def checkSpeculatableTasks(): Boolean
+ def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean
def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager]
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index 6c5827f75e63..5f3c280ec31e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -18,11 +18,14 @@
package org.apache.spark.scheduler
import java.io.{FileInputStream, InputStream}
-import java.util.{NoSuchElementException, Properties}
+import java.util.{Locale, NoSuchElementException, Properties}
-import scala.xml.XML
+import scala.util.control.NonFatal
+import scala.xml.{Node, XML}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.util.Utils
/**
@@ -33,9 +36,9 @@ import org.apache.spark.util.Utils
private[spark] trait SchedulableBuilder {
def rootPool: Pool
- def buildPools()
+ def buildPools(): Unit
- def addTaskSetManager(manager: Schedulable, properties: Properties)
+ def addTaskSetManager(manager: Schedulable, properties: Properties): Unit
}
private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
@@ -53,7 +56,8 @@ private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
extends SchedulableBuilder with Logging {
- val schedulerAllocFile = conf.getOption("spark.scheduler.allocation.file")
+ val SCHEDULER_ALLOCATION_FILE_PROPERTY = "spark.scheduler.allocation.file"
+ val schedulerAllocFile = conf.getOption(SCHEDULER_ALLOCATION_FILE_PROPERTY)
val DEFAULT_SCHEDULER_FILE = "fairscheduler.xml"
val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.pool"
val DEFAULT_POOL_NAME = "default"
@@ -67,19 +71,35 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
val DEFAULT_WEIGHT = 1
override def buildPools() {
- var is: Option[InputStream] = None
+ var fileData: Option[(InputStream, String)] = None
try {
- is = Option {
- schedulerAllocFile.map { f =>
- new FileInputStream(f)
- }.getOrElse {
- Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
+ fileData = schedulerAllocFile.map { f =>
+ val fis = new FileInputStream(f)
+ logInfo(s"Creating Fair Scheduler pools from $f")
+ Some((fis, f))
+ }.getOrElse {
+ val is = Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
+ if (is != null) {
+ logInfo(s"Creating Fair Scheduler pools from default file: $DEFAULT_SCHEDULER_FILE")
+ Some((is, DEFAULT_SCHEDULER_FILE))
+ } else {
+ logWarning("Fair Scheduler configuration file not found so jobs will be scheduled in " +
+ s"FIFO order. To use fair scheduling, configure pools in $DEFAULT_SCHEDULER_FILE or " +
+ s"set $SCHEDULER_ALLOCATION_FILE_PROPERTY to a file that contains the configuration.")
+ None
}
}
- is.foreach { i => buildFairSchedulerPool(i) }
+ fileData.foreach { case (is, fileName) => buildFairSchedulerPool(is, fileName) }
+ } catch {
+ case NonFatal(t) =>
+ val defaultMessage = "Error while building the fair scheduler pools"
+ val message = fileData.map { case (is, fileName) => s"$defaultMessage from $fileName" }
+ .getOrElse(defaultMessage)
+ logError(message, t)
+ throw t
} finally {
- is.foreach(_.close())
+ fileData.foreach { case (is, fileName) => is.close() }
}
// finally create "default" pool
@@ -91,62 +111,93 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE,
DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
rootPool.addSchedulable(pool)
- logInfo("Created default pool %s, schedulingMode: %s, minShare: %d, weight: %d".format(
+ logInfo("Created default pool: %s, schedulingMode: %s, minShare: %d, weight: %d".format(
DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT))
}
}
- private def buildFairSchedulerPool(is: InputStream) {
+ private def buildFairSchedulerPool(is: InputStream, fileName: String) {
val xml = XML.load(is)
for (poolNode <- (xml \\ POOLS_PROPERTY)) {
val poolName = (poolNode \ POOL_NAME_PROPERTY).text
- var schedulingMode = DEFAULT_SCHEDULING_MODE
- var minShare = DEFAULT_MINIMUM_SHARE
- var weight = DEFAULT_WEIGHT
-
- val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text
- if (xmlSchedulingMode != "") {
- try {
- schedulingMode = SchedulingMode.withName(xmlSchedulingMode)
- } catch {
- case e: NoSuchElementException =>
- logWarning("Error xml schedulingMode, using default schedulingMode")
- }
- }
- val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text
- if (xmlMinShare != "") {
- minShare = xmlMinShare.toInt
- }
+ val schedulingMode = getSchedulingModeValue(poolNode, poolName,
+ DEFAULT_SCHEDULING_MODE, fileName)
+ val minShare = getIntValue(poolNode, poolName, MINIMUM_SHARES_PROPERTY,
+ DEFAULT_MINIMUM_SHARE, fileName)
+ val weight = getIntValue(poolNode, poolName, WEIGHT_PROPERTY,
+ DEFAULT_WEIGHT, fileName)
- val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text
- if (xmlWeight != "") {
- weight = xmlWeight.toInt
- }
+ rootPool.addSchedulable(new Pool(poolName, schedulingMode, minShare, weight))
- val pool = new Pool(poolName, schedulingMode, minShare, weight)
- rootPool.addSchedulable(pool)
- logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format(
+ logInfo("Created pool: %s, schedulingMode: %s, minShare: %d, weight: %d".format(
poolName, schedulingMode, minShare, weight))
}
}
+ private def getSchedulingModeValue(
+ poolNode: Node,
+ poolName: String,
+ defaultValue: SchedulingMode,
+ fileName: String): SchedulingMode = {
+
+ val xmlSchedulingMode =
+ (poolNode \ SCHEDULING_MODE_PROPERTY).text.trim.toUpperCase(Locale.ROOT)
+ val warningMessage = s"Unsupported schedulingMode: $xmlSchedulingMode found in " +
+ s"Fair Scheduler configuration file: $fileName, using " +
+ s"the default schedulingMode: $defaultValue for pool: $poolName"
+ try {
+ if (SchedulingMode.withName(xmlSchedulingMode) != SchedulingMode.NONE) {
+ SchedulingMode.withName(xmlSchedulingMode)
+ } else {
+ logWarning(warningMessage)
+ defaultValue
+ }
+ } catch {
+ case e: NoSuchElementException =>
+ logWarning(warningMessage)
+ defaultValue
+ }
+ }
+
+ private def getIntValue(
+ poolNode: Node,
+ poolName: String,
+ propertyName: String,
+ defaultValue: Int,
+ fileName: String): Int = {
+
+ val data = (poolNode \ propertyName).text.trim
+ try {
+ data.toInt
+ } catch {
+ case e: NumberFormatException =>
+ logWarning(s"Error while loading fair scheduler configuration from $fileName: " +
+ s"$propertyName is blank or invalid: $data, using the default $propertyName: " +
+ s"$defaultValue for pool: $poolName")
+ defaultValue
+ }
+ }
+
override def addTaskSetManager(manager: Schedulable, properties: Properties) {
- var poolName = DEFAULT_POOL_NAME
- var parentPool = rootPool.getSchedulableByName(poolName)
- if (properties != null) {
- poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME)
- parentPool = rootPool.getSchedulableByName(poolName)
- if (parentPool == null) {
- // we will create a new pool that user has configured in app
- // instead of being defined in xml file
- parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE,
- DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
- rootPool.addSchedulable(parentPool)
- logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format(
- poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT))
+ val poolName = if (properties != null) {
+ properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME)
+ } else {
+ DEFAULT_POOL_NAME
}
+ var parentPool = rootPool.getSchedulableByName(poolName)
+ if (parentPool == null) {
+ // we will create a new pool that user has configured in app
+ // instead of being defined in xml file
+ parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE,
+ DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
+ rootPool.addSchedulable(parentPool)
+ logWarning(s"A job was submitted with scheduler pool $poolName, which has not been " +
+ "configured. This can happen when the file that pools are read from isn't set, or " +
+ s"when that file doesn't contain $poolName. Created $poolName with default " +
+ s"configuration (schedulingMode: $DEFAULT_SCHEDULING_MODE, " +
+ s"minShare: $DEFAULT_MINIMUM_SHARE, weight: $DEFAULT_WEIGHT)")
}
parentPool.addSchedulable(manager)
logInfo("Added task set " + manager.name + " tasks to pool " + poolName)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index 8801a761afae..22db3350abfa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -30,8 +30,21 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
- def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit =
+ /**
+ * Requests that an executor kills a running task.
+ *
+ * @param taskId Id of the task.
+ * @param executorId Id of the executor the task is running on.
+ * @param interruptThread Whether the executor should interrupt the task thread.
+ * @param reason The reason for the task kill.
+ */
+ def killTask(
+ taskId: Long,
+ executorId: String,
+ interruptThread: Boolean,
+ reason: String): Unit =
throw new UnsupportedOperationException
+
def isReady(): Boolean = true
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
index 864941d468af..18ebbbe78a5b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
@@ -36,11 +36,7 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm {
val stageId2 = s2.stageId
res = math.signum(stageId1 - stageId2)
}
- if (res < 0) {
- true
- } else {
- false
- }
+ res < 0
}
}
@@ -52,12 +48,12 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
val runningTasks2 = s2.runningTasks
val s1Needy = runningTasks1 < minShare1
val s2Needy = runningTasks2 < minShare2
- val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble
- val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble
+ val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0)
+ val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0)
val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble
val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble
- var compare: Int = 0
+ var compare = 0
if (s1Needy && !s2Needy) {
return true
} else if (!s1Needy && s2Needy) {
@@ -67,7 +63,6 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
} else {
compare = taskToWeightRatio1.compareTo(taskToWeightRatio2)
}
-
if (compare < 0) {
true
} else if (compare > 0) {
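
FairSchedulingAlgorithm first favors a schedulable that is below its minShare, then compares runningTasks/minShare among needy ones and runningTasks/weight otherwise, breaking any remaining tie by name. A standalone restatement of that ordering with a small worked case; `Sched` and `fairLessThan` are illustrative names:

object FairOrderSketch {
  // `Sched` stands in for a Schedulable; only the fields used by the comparator are kept.
  case class Sched(name: String, runningTasks: Int, minShare: Int, weight: Int)

  def fairLessThan(s1: Sched, s2: Sched): Boolean = {
    val s1Needy = s1.runningTasks < s1.minShare
    val s2Needy = s2.runningTasks < s2.minShare
    val minShareRatio1 = s1.runningTasks.toDouble / math.max(s1.minShare, 1.0)
    val minShareRatio2 = s2.runningTasks.toDouble / math.max(s2.minShare, 1.0)
    val taskToWeightRatio1 = s1.runningTasks.toDouble / s1.weight
    val taskToWeightRatio2 = s2.runningTasks.toDouble / s2.weight
    if (s1Needy && !s2Needy) {
      true
    } else if (!s1Needy && s2Needy) {
      false
    } else {
      val compare =
        if (s1Needy) minShareRatio1.compareTo(minShareRatio2)
        else taskToWeightRatio1.compareTo(taskToWeightRatio2)
      if (compare != 0) compare < 0 else s1.name < s2.name
    }
  }

  // A pool still under its minShare is scheduled before one that has met it:
  // fairLessThan(Sched("a", 1, 4, 1), Sched("b", 3, 2, 1)) == true
}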
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index 51416e5ce97f..05f650fbf5df 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -17,9 +17,10 @@
package org.apache.spark.scheduler
-import org.apache.spark.ShuffleDependency
+import scala.collection.mutable.HashSet
+
+import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv}
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.CallSite
/**
@@ -40,19 +41,22 @@ private[spark] class ShuffleMapStage(
parents: List[Stage],
firstJobId: Int,
callSite: CallSite,
- val shuffleDep: ShuffleDependency[_, _, _])
+ val shuffleDep: ShuffleDependency[_, _, _],
+ mapOutputTrackerMaster: MapOutputTrackerMaster)
extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) {
private[this] var _mapStageJobs: List[ActiveJob] = Nil
- private[this] var _numAvailableOutputs: Int = 0
-
/**
- * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
- * and each value in the array is the list of possible [[MapStatus]] for a partition
- * (a single task might run multiple times).
+ * Partitions that either haven't yet been computed, or that were computed on an executor
+ * that has since been lost, so should be re-computed. This variable is used by the
+ * DAGScheduler to determine when a stage has completed. Task successes in either the active
+ * attempt for the stage or in earlier attempts for this stage can cause partition ids to get
+ * removed from pendingPartitions. As a result, this variable may be inconsistent with the pending
+ * tasks in the TaskSetManager for the active attempt for the stage (the partitions stored here
+ * will always be a subset of the partitions that the TaskSetManager thinks are pending).
*/
- private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
+ val pendingPartitions = new HashSet[Int]
override def toString: String = "ShuffleMapStage " + id
@@ -75,69 +79,18 @@ private[spark] class ShuffleMapStage(
/**
* Number of partitions that have shuffle outputs.
* When this reaches [[numPartitions]], this map stage is ready.
- * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`.
*/
- def numAvailableOutputs: Int = _numAvailableOutputs
+ def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId)
/**
* Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
- * This should be the same as `outputLocs.contains(Nil)`.
*/
- def isAvailable: Boolean = _numAvailableOutputs == numPartitions
+ def isAvailable: Boolean = numAvailableOutputs == numPartitions
/** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
override def findMissingPartitions(): Seq[Int] = {
- val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
- assert(missing.size == numPartitions - _numAvailableOutputs,
- s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
- missing
- }
-
- def addOutputLoc(partition: Int, status: MapStatus): Unit = {
- val prevList = outputLocs(partition)
- outputLocs(partition) = status :: prevList
- if (prevList == Nil) {
- _numAvailableOutputs += 1
- }
- }
-
- def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = {
- val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_.location == bmAddress)
- outputLocs(partition) = newList
- if (prevList != Nil && newList == Nil) {
- _numAvailableOutputs -= 1
- }
- }
-
- /**
- * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned
- * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition,
- * that position is filled with null.
- */
- def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = {
- outputLocs.map(_.headOption.orNull)
- }
-
- /**
- * Removes all shuffle outputs associated with this executor. Note that this will also remove
- * outputs which are served by an external shuffle server (if one exists), as they are still
- * registered with this execId.
- */
- def removeOutputsOnExecutor(execId: String): Unit = {
- var becameUnavailable = false
- for (partition <- 0 until numPartitions) {
- val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_.location.executorId == execId)
- outputLocs(partition) = newList
- if (prevList != Nil && newList == Nil) {
- becameUnavailable = true
- _numAvailableOutputs -= 1
- }
- }
- if (becameUnavailable) {
- logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
- this, execId, _numAvailableOutputs, numPartitions, isAvailable))
- }
+ mapOutputTrackerMaster
+ .findMissingPartitions(shuffleDep.shuffleId)
+ .getOrElse(0 until numPartitions)
}
}
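
ShuffleMapStage no longer tracks outputLocs itself: availability and missing partitions are asked of the MapOutputTrackerMaster, and pendingPartitions only reflects what the DAGScheduler still expects from the active attempt. A toy sketch of delegating those questions to a single external registry; `OutputRegistry` is a hypothetical name, not Spark API:

// Toy sketch of delegating availability questions to an external registry of map outputs.
class OutputRegistry(numPartitions: Int) {
  // One slot per map partition; Some(size) means a usable output is registered.
  private val statuses = Array.fill[Option[Long]](numPartitions)(None)

  def registerOutput(partition: Int, size: Long): Unit = statuses(partition) = Some(size)
  def removeOutput(partition: Int): Unit = statuses(partition) = None

  def numAvailableOutputs: Int = statuses.count(_.isDefined)
  def isAvailable: Boolean = numAvailableOutputs == numPartitions
  def findMissingPartitions(): Seq[Int] = (0 until numPartitions).filter(statuses(_).isEmpty)
}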
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index f478f9982afe..7a25c47e2cab 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -17,26 +17,38 @@
package org.apache.spark.scheduler
+import java.lang.management.ManagementFactory
import java.nio.ByteBuffer
+import java.util.Properties
import scala.language.existentials
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleWriter
/**
-* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
-* specified in the ShuffleDependency).
-*
-* See [[org.apache.spark.scheduler.Task]] for more information.
-*
+ * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
+ * specified in the ShuffleDependency).
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
* @param stageId id of the stage this task belongs to
+ * @param stageAttemptId attempt id of the stage this task belongs to
* @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized,
* the type should be (RDD[_], ShuffleDependency[_, _, _]).
* @param partition partition of the RDD this task is associated with
* @param locs preferred task execution locations for locality scheduling
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
+ * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side
+ * and sent to executor side.
+ *
+ * The parameters below are optional:
+ * @param jobId id of the job this task belongs to
+ * @param appId id of the app this task belongs to
+ * @param appAttemptId attempt id of the app this task belongs to
*/
private[spark] class ShuffleMapTask(
stageId: Int,
@@ -44,13 +56,18 @@ private[spark] class ShuffleMapTask(
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation],
- internalAccumulators: Seq[Accumulator[Long]])
- extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators)
+ localProperties: Properties,
+ serializedTaskMetrics: Array[Byte],
+ jobId: Option[Int] = None,
+ appId: Option[String] = None,
+ appAttemptId: Option[String] = None)
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties,
+ serializedTaskMetrics, jobId, appId, appAttemptId)
with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null, new Properties, null)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
@@ -59,13 +76,19 @@ private[spark] class ShuffleMapTask(
override def runTask(context: TaskContext): MapStatus = {
// Deserialize the RDD using the broadcast variable.
+ val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
+ val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
+ threadMXBean.getCurrentThreadCpuTime
+ } else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
+ _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
+ threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
+ } else 0L
- metrics = Some(context.taskMetrics)
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 896f1743332f..bc2e53071668 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -18,19 +18,25 @@
package org.apache.spark.scheduler
import java.util.Properties
+import javax.annotation.Nullable
import scala.collection.Map
-import scala.collection.mutable
-import org.apache.spark.{Logging, TaskEndReason}
+import com.fasterxml.jackson.annotation.JsonTypeInfo
+
+import org.apache.spark.{SparkConf, TaskEndReason}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo}
-import org.apache.spark.util.{Distribution, Utils}
+import org.apache.spark.ui.SparkUI
@DeveloperApi
-sealed trait SparkListenerEvent
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event")
+trait SparkListenerEvent {
+ /* Whether output this event to the event log */
+ protected[spark] def logEvent: Boolean = true
+}
@DeveloperApi
case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null)
@@ -53,7 +59,8 @@ case class SparkListenerTaskEnd(
taskType: String,
reason: TaskEndReason,
taskInfo: TaskInfo,
- taskMetrics: TaskMetrics)
+ // may be null if the task has failed
+ @Nullable taskMetrics: TaskMetrics)
extends SparkListenerEvent
@DeveloperApi
@@ -80,8 +87,13 @@ case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(S
extends SparkListenerEvent
@DeveloperApi
-case class SparkListenerBlockManagerAdded(time: Long, blockManagerId: BlockManagerId, maxMem: Long)
- extends SparkListenerEvent
+case class SparkListenerBlockManagerAdded(
+ time: Long,
+ blockManagerId: BlockManagerId,
+ maxMem: Long,
+ maxOnHeapMem: Option[Long] = None,
+ maxOffHeapMem: Option[Long] = None) extends SparkListenerEvent {
+}
@DeveloperApi
case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockManagerId)
@@ -98,18 +110,40 @@ case class SparkListenerExecutorAdded(time: Long, executorId: String, executorIn
case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String)
extends SparkListenerEvent
+@DeveloperApi
+case class SparkListenerExecutorBlacklisted(
+ time: Long,
+ executorId: String,
+ taskFailures: Int)
+ extends SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerExecutorUnblacklisted(time: Long, executorId: String)
+ extends SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerNodeBlacklisted(
+ time: Long,
+ hostId: String,
+ executorFailures: Int)
+ extends SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerNodeUnblacklisted(time: Long, hostId: String)
+ extends SparkListenerEvent
+
@DeveloperApi
case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends SparkListenerEvent
/**
* Periodic updates from executors.
* @param execId executor id
- * @param taskMetrics sequence of (task id, stage id, stage attempt, metrics)
+ * @param accumUpdates sequence of (taskId, stageId, stageAttemptId, accumUpdates)
*/
@DeveloperApi
case class SparkListenerExecutorMetricsUpdate(
execId: String,
- taskMetrics: Seq[(Long, Int, Int, TaskMetrics)])
+ accumUpdates: Seq[(Long, Int, Int, Seq[AccumulableInfo])])
extends SparkListenerEvent
@DeveloperApi
@@ -131,258 +165,194 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent
private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent
/**
- * :: DeveloperApi ::
- * Interface for listening to events from the Spark scheduler. Note that this is an internal
- * interface which might change in different Spark releases. Java clients should extend
- * {@link JavaSparkListener}
+ * Interface for creating history listeners defined in other modules like SQL, which are used to
+ * rebuild the history UI.
*/
-@DeveloperApi
-trait SparkListener {
+private[spark] trait SparkHistoryListenerFactory {
+ /**
+ * Create listeners used to rebuild the history UI.
+ */
+ def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener]
+}
+
+
+/**
+ * Interface for listening to events from the Spark scheduler. Most applications should probably
+ * extend SparkListener or SparkFirehoseListener directly, rather than implementing this class.
+ *
+ * Note that this is an internal interface which might change in different Spark releases.
+ */
+private[spark] trait SparkListenerInterface {
+
/**
* Called when a stage completes successfully or fails, with information on the completed stage.
*/
- def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { }
+ def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit
/**
* Called when a stage is submitted
*/
- def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { }
+ def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit
/**
* Called when a task starts
*/
- def onTaskStart(taskStart: SparkListenerTaskStart) { }
+ def onTaskStart(taskStart: SparkListenerTaskStart): Unit
/**
* Called when a task begins remotely fetching its result (will not be called for tasks that do
* not need to fetch the result remotely).
*/
- def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
+ def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit
/**
* Called when a task ends
*/
- def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
+ def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit
/**
* Called when a job starts
*/
- def onJobStart(jobStart: SparkListenerJobStart) { }
+ def onJobStart(jobStart: SparkListenerJobStart): Unit
/**
* Called when a job ends
*/
- def onJobEnd(jobEnd: SparkListenerJobEnd) { }
+ def onJobEnd(jobEnd: SparkListenerJobEnd): Unit
/**
* Called when environment properties have been updated
*/
- def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { }
+ def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate): Unit
/**
* Called when a new block manager has joined
*/
- def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded) { }
+ def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit
/**
* Called when an existing block manager has been removed
*/
- def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { }
+ def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit
/**
* Called when an RDD is manually unpersisted by the application
*/
- def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) { }
+ def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit
/**
* Called when the application starts
*/
- def onApplicationStart(applicationStart: SparkListenerApplicationStart) { }
+ def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit
/**
* Called when the application ends
*/
- def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { }
+ def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit
/**
* Called when the driver receives task metrics from an executor in a heartbeat.
*/
- def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { }
+ def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit
/**
* Called when the driver registers a new executor.
*/
- def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) { }
+ def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit
/**
* Called when the driver removes an executor.
*/
- def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { }
+ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit
+
+ /**
+ * Called when the driver blacklists an executor for a Spark application.
+ */
+ def onExecutorBlacklisted(executorBlacklisted: SparkListenerExecutorBlacklisted): Unit
+
+ /**
+ * Called when the driver re-enables a previously blacklisted executor.
+ */
+ def onExecutorUnblacklisted(executorUnblacklisted: SparkListenerExecutorUnblacklisted): Unit
+
+ /**
+ * Called when the driver blacklists a node for a Spark application.
+ */
+ def onNodeBlacklisted(nodeBlacklisted: SparkListenerNodeBlacklisted): Unit
+
+ /**
+ * Called when the driver re-enables a previously blacklisted node.
+ */
+ def onNodeUnblacklisted(nodeUnblacklisted: SparkListenerNodeUnblacklisted): Unit
/**
* Called when the driver receives a block update info.
*/
- def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { }
+ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit
+
+ /**
+ * Called when other events like SQL-specific events are posted.
+ */
+ def onOtherEvent(event: SparkListenerEvent): Unit
}
+
/**
* :: DeveloperApi ::
- * Simple SparkListener that logs a few summary statistics when each stage completes
+ * A default implementation for `SparkListenerInterface` that has no-op implementations for
+ * all callbacks.
+ *
+ * Note that this is an internal interface which might change in different Spark releases.
*/
@DeveloperApi
-class StatsReportListener extends SparkListener with Logging {
-
- import org.apache.spark.scheduler.StatsReportListener._
-
- private val taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]()
-
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val info = taskEnd.taskInfo
- val metrics = taskEnd.taskMetrics
- if (info != null && metrics != null) {
- taskInfoMetrics += ((info, metrics))
- }
- }
-
- override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
- implicit val sc = stageCompleted
- this.logInfo("Finished stage: " + stageCompleted.stageInfo)
- showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics)
-
- // Shuffle write
- showBytesDistribution("shuffle bytes written:",
- (_, metric) => metric.shuffleWriteMetrics.map(_.shuffleBytesWritten), taskInfoMetrics)
-
- // Fetch & I/O
- showMillisDistribution("fetch wait time:",
- (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime), taskInfoMetrics)
- showBytesDistribution("remote bytes read:",
- (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead), taskInfoMetrics)
- showBytesDistribution("task result size:",
- (_, metric) => Some(metric.resultSize), taskInfoMetrics)
-
- // Runtime breakdown
- val runtimePcts = taskInfoMetrics.map { case (info, metrics) =>
- RuntimePercentage(info.duration, metrics)
- }
- showDistribution("executor (non-fetch) time pct: ",
- Distribution(runtimePcts.map(_.executorPct * 100)), "%2.0f %%")
- showDistribution("fetch wait time pct: ",
- Distribution(runtimePcts.flatMap(_.fetchPct.map(_ * 100))), "%2.0f %%")
- showDistribution("other time pct: ", Distribution(runtimePcts.map(_.other * 100)), "%2.0f %%")
- taskInfoMetrics.clear()
- }
+abstract class SparkListener extends SparkListenerInterface {
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { }
-}
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { }
-private[spark] object StatsReportListener extends Logging {
-
- // For profiling, the extremes are more interesting
- val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100)
- val probabilities = percentiles.map(_ / 100.0)
- val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
-
- def extractDoubleDistribution(
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
- getMetric: (TaskInfo, TaskMetrics) => Option[Double]): Option[Distribution] = {
- Distribution(taskInfoMetrics.flatMap { case (info, metric) => getMetric(info, metric) })
- }
-
- // Is there some way to setup the types that I can get rid of this completely?
- def extractLongDistribution(
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
- getMetric: (TaskInfo, TaskMetrics) => Option[Long]): Option[Distribution] = {
- extractDoubleDistribution(
- taskInfoMetrics,
- (info, metric) => { getMetric(info, metric).map(_.toDouble) })
- }
-
- def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
- val stats = d.statCounter
- val quantiles = d.getQuantiles(probabilities).map(formatNumber)
- logInfo(heading + stats)
- logInfo(percentilesHeader)
- logInfo("\t" + quantiles.mkString("\t"))
- }
-
- def showDistribution(
- heading: String,
- dOpt: Option[Distribution],
- formatNumber: Double => String) {
- dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
- }
-
- def showDistribution(heading: String, dOpt: Option[Distribution], format: String) {
- def f(d: Double): String = format.format(d)
- showDistribution(heading, dOpt, f _)
- }
-
- def showDistribution(
- heading: String,
- format: String,
- getMetric: (TaskInfo, TaskMetrics) => Option[Double],
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
- showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format)
- }
-
- def showBytesDistribution(
- heading: String,
- getMetric: (TaskInfo, TaskMetrics) => Option[Long],
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
- showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
- }
-
- def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
- dOpt.foreach { dist => showBytesDistribution(heading, dist) }
- }
-
- def showBytesDistribution(heading: String, dist: Distribution) {
- showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String)
- }
-
- def showMillisDistribution(heading: String, dOpt: Option[Distribution]) {
- showDistribution(heading, dOpt,
- (d => StatsReportListener.millisToString(d.toLong)): Double => String)
- }
-
- def showMillisDistribution(
- heading: String,
- getMetric: (TaskInfo, TaskMetrics) => Option[Long],
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
- showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
- }
-
- val seconds = 1000L
- val minutes = seconds * 60
- val hours = minutes * 60
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { }
- /**
- * Reformat a time interval in milliseconds to a prettier format for output
- */
- def millisToString(ms: Long): String = {
- val (size, units) =
- if (ms > hours) {
- (ms.toDouble / hours, "hours")
- } else if (ms > minutes) {
- (ms.toDouble / minutes, "min")
- } else if (ms > seconds) {
- (ms.toDouble / seconds, "s")
- } else {
- (ms.toDouble, "ms")
- }
- "%.1f %s".format(size, units)
- }
-}
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit = { }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { }
+
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = { }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { }
+
+ override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate): Unit = { }
+
+ override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { }
+
+ override def onBlockManagerRemoved(
+ blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { }
+
+ override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = { }
+
+ override def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit = { }
+
+ override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { }
+
+ override def onExecutorMetricsUpdate(
+ executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = { }
+
+ override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { }
+
+ override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { }
+
+ override def onExecutorBlacklisted(
+ executorBlacklisted: SparkListenerExecutorBlacklisted): Unit = { }
+
+ override def onExecutorUnblacklisted(
+ executorUnblacklisted: SparkListenerExecutorUnblacklisted): Unit = { }
+
+ override def onNodeBlacklisted(
+ nodeBlacklisted: SparkListenerNodeBlacklisted): Unit = { }
+
+ override def onNodeUnblacklisted(
+ nodeUnblacklisted: SparkListenerNodeUnblacklisted): Unit = { }
+
+ override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { }
-private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
-
-private object RuntimePercentage {
- def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
- val denom = totalTime.toDouble
- val fetchTime = metrics.shuffleReadMetrics.map(_.fetchWaitTime)
- val fetch = fetchTime.map(_ / denom)
- val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom
- val other = 1.0 - (exec + fetch.getOrElse(0d))
- RuntimePercentage(exec, fetch, other)
- }
+ override def onOtherEvent(event: SparkListenerEvent): Unit = { }
}
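Editor's note: with SparkListener now an abstract no-op implementation of SparkListenerInterface, user code only overrides the callbacks it cares about, and events without a dedicated callback (e.g. SQL events) arrive via onOtherEvent. A hedged usage sketch against spark-core (the class and its println output are illustrative):

    import org.apache.spark.scheduler._

    class TaskTimingListener extends SparkListener {
      // taskMetrics may be null for failed tasks, as documented on SparkListenerTaskEnd.
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        val metrics = taskEnd.taskMetrics
        if (metrics != null) {
          println(s"task ${taskEnd.taskInfo.id} ran for ${metrics.executorRunTime} ms")
        }
      }

      // Events posted by other modules fall through the bus's match and land here.
      override def onOtherEvent(event: SparkListenerEvent): Unit = {
        println(s"other event: ${event.getClass.getSimpleName}")
      }
    }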
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 04afde33f5aa..3ff363321e8c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -22,9 +22,12 @@ import org.apache.spark.util.ListenerBus
/**
* A [[SparkListenerEvent]] bus that relays [[SparkListenerEvent]]s to its listeners
*/
-private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkListenerEvent] {
+private[spark] trait SparkListenerBus
+ extends ListenerBus[SparkListenerInterface, SparkListenerEvent] {
- override def onPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = {
+ protected override def doPostEvent(
+ listener: SparkListenerInterface,
+ event: SparkListenerEvent): Unit = {
event match {
case stageSubmitted: SparkListenerStageSubmitted =>
listener.onStageSubmitted(stageSubmitted)
@@ -58,9 +61,18 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi
listener.onExecutorAdded(executorAdded)
case executorRemoved: SparkListenerExecutorRemoved =>
listener.onExecutorRemoved(executorRemoved)
+ case executorBlacklisted: SparkListenerExecutorBlacklisted =>
+ listener.onExecutorBlacklisted(executorBlacklisted)
+ case executorUnblacklisted: SparkListenerExecutorUnblacklisted =>
+ listener.onExecutorUnblacklisted(executorUnblacklisted)
+ case nodeBlacklisted: SparkListenerNodeBlacklisted =>
+ listener.onNodeBlacklisted(nodeBlacklisted)
+ case nodeUnblacklisted: SparkListenerNodeUnblacklisted =>
+ listener.onNodeUnblacklisted(nodeUnblacklisted)
case blockUpdated: SparkListenerBlockUpdated =>
listener.onBlockUpdated(blockUpdated)
case logStart: SparkListenerLogStart => // ignore event log metadata
+ case _ => listener.onOtherEvent(event)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
index 1ce83485f024..bc1431835e25 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
@@ -45,18 +45,17 @@ class SplitInfo(
hashCode
}
- // This is practically useless since most of the Split impl's dont seem to implement equals :-(
+ // This is practically useless since most of the Split impl's don't seem to implement equals :-(
// So unless there is identity equality between underlyingSplits, it will always fail even if it
// is pointing to same block.
override def equals(other: Any): Boolean = other match {
- case that: SplitInfo => {
+ case that: SplitInfo =>
this.hostLocation == that.hostLocation &&
this.inputFormatClazz == that.inputFormatClazz &&
this.path == that.path &&
this.length == that.length &&
// other split specific checks (like start for FileSplit)
this.underlyingSplit == that.underlyingSplit
- }
case _ => false
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 7ea24a217bd3..290fd073caf2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -19,7 +19,8 @@ package org.apache.spark.scheduler
import scala.collection.mutable.HashSet
-import org.apache.spark._
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.util.CallSite
@@ -66,32 +67,14 @@ private[scheduler] abstract class Stage(
/** Set of jobs that this stage belongs to. */
val jobIds = new HashSet[Int]
- val pendingPartitions = new HashSet[Int]
-
/** The ID to use for the next new attempt for this stage. */
private var nextAttemptId: Int = 0
val name: String = callSite.shortForm
val details: String = callSite.longForm
- private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty
-
- /** Internal accumulators shared across all tasks in this stage. */
- def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators
-
- /**
- * Re-initialize the internal accumulators associated with this stage.
- *
- * This is called every time the stage is submitted, *except* when a subset of tasks
- * belonging to this stage has already finished. Otherwise, reinitializing the internal
- * accumulators here again will override partial values from the finished tasks.
- */
- def resetInternalAccumulators(): Unit = {
- _internalAccumulators = InternalAccumulator.create(rdd.sparkContext)
- }
-
/**
- * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized
+ * Pointer to the [[StageInfo]] object for the most recent attempt. This needs to be initialized
* here, before any attempts have actually been created, because the DAGScheduler uses this
* StageInfo to tell SparkListeners when a job starts (which happens before any stage attempts
* have been created).
@@ -104,29 +87,20 @@ private[scheduler] abstract class Stage(
* We keep track of each attempt ID that has failed to avoid recording duplicate failures if
* multiple tasks from the same stage attempt fail (SPARK-5945).
*/
- private val fetchFailedAttemptIds = new HashSet[Int]
+ val fetchFailedAttemptIds = new HashSet[Int]
private[scheduler] def clearFailures() : Unit = {
fetchFailedAttemptIds.clear()
}
- /**
- * Check whether we should abort the failedStage due to multiple consecutive fetch failures.
- *
- * This method updates the running set of failed stage attempts and returns
- * true if the number of failures exceeds the allowable number of failures.
- */
- private[scheduler] def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = {
- fetchFailedAttemptIds.add(stageAttemptId)
- fetchFailedAttemptIds.size >= Stage.MAX_CONSECUTIVE_FETCH_FAILURES
- }
-
/** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
def makeNewStageAttempt(
numPartitionsToCompute: Int,
taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = {
+ val metrics = new TaskMetrics
+ metrics.register(rdd.sparkContext)
_latestInfo = StageInfo.fromStage(
- this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences)
+ this, nextAttemptId, Some(numPartitionsToCompute), metrics, taskLocalityPreferences)
nextAttemptId += 1
}
@@ -143,8 +117,3 @@ private[scheduler] abstract class Stage(
/** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
def findMissingPartitions(): Seq[Int]
}
-
-private[scheduler] object Stage {
- // The number of consecutive failures allowed before a stage is aborted
- val MAX_CONSECUTIVE_FETCH_FAILURES = 4
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 24796c14300b..c513ed36d168 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler
import scala.collection.mutable.HashMap
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.RDDInfo
/**
@@ -35,6 +36,7 @@ class StageInfo(
val rddInfos: Seq[RDDInfo],
val parentIds: Seq[Int],
val details: String,
+ val taskMetrics: TaskMetrics = null,
private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) {
/** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
var submissionTime: Option[Long] = None
@@ -42,7 +44,11 @@ class StageInfo(
var completionTime: Option[Long] = None
/** If the stage failed, the reason why. */
var failureReason: Option[String] = None
- /** Terminal values of accumulables updated during this stage. */
+
+ /**
+ * Terminal values of accumulables updated during this stage, including all the user-defined
+ * accumulators.
+ */
val accumulables = HashMap[Long, AccumulableInfo]()
def stageFailed(reason: String) {
@@ -75,6 +81,7 @@ private[spark] object StageInfo {
stage: Stage,
attemptId: Int,
numTasks: Option[Int] = None,
+ taskMetrics: TaskMetrics = null,
taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty
): StageInfo = {
val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd)
@@ -87,6 +94,7 @@ private[spark] object StageInfo {
rddInfos,
stage.parents.map(_.id),
stage.details,
+ taskMetrics,
taskLocalityPreferences)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
new file mode 100644
index 000000000000..3c8cab7504c1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.{Distribution, Utils}
+
+
+/**
+ * :: DeveloperApi ::
+ * Simple SparkListener that logs a few summary statistics when each stage completes.
+ */
+@DeveloperApi
+class StatsReportListener extends SparkListener with Logging {
+
+ import org.apache.spark.scheduler.StatsReportListener._
+
+ private val taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]()
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val info = taskEnd.taskInfo
+ val metrics = taskEnd.taskMetrics
+ if (info != null && metrics != null) {
+ taskInfoMetrics += ((info, metrics))
+ }
+ }
+
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
+ implicit val sc = stageCompleted
+ this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}")
+ showMillisDistribution("task runtime:", (info, _) => info.duration, taskInfoMetrics)
+
+ // Shuffle write
+ showBytesDistribution("shuffle bytes written:",
+ (_, metric) => metric.shuffleWriteMetrics.bytesWritten, taskInfoMetrics)
+
+ // Fetch & I/O
+ showMillisDistribution("fetch wait time:",
+ (_, metric) => metric.shuffleReadMetrics.fetchWaitTime, taskInfoMetrics)
+ showBytesDistribution("remote bytes read:",
+ (_, metric) => metric.shuffleReadMetrics.remoteBytesRead, taskInfoMetrics)
+ showBytesDistribution("task result size:",
+ (_, metric) => metric.resultSize, taskInfoMetrics)
+
+ // Runtime breakdown
+ val runtimePcts = taskInfoMetrics.map { case (info, metrics) =>
+ RuntimePercentage(info.duration, metrics)
+ }
+ showDistribution("executor (non-fetch) time pct: ",
+ Distribution(runtimePcts.map(_.executorPct * 100)), "%2.0f %%")
+ showDistribution("fetch wait time pct: ",
+ Distribution(runtimePcts.flatMap(_.fetchPct.map(_ * 100))), "%2.0f %%")
+ showDistribution("other time pct: ", Distribution(runtimePcts.map(_.other * 100)), "%2.0f %%")
+ taskInfoMetrics.clear()
+ }
+
+ private def getStatusDetail(info: StageInfo): String = {
+ val failureReason = info.failureReason.map("(" + _ + ")").getOrElse("")
+ val timeTaken = info.submissionTime.map(
+ x => info.completionTime.getOrElse(System.currentTimeMillis()) - x
+ ).getOrElse("-")
+
+ s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " +
+ s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " +
+ s"Took: $timeTaken msec"
+ }
+
+}
+
+private[spark] object StatsReportListener extends Logging {
+
+ // For profiling, the extremes are more interesting
+ val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100)
+ val probabilities = percentiles.map(_ / 100.0)
+ val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
+
+ def extractDoubleDistribution(
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
+ getMetric: (TaskInfo, TaskMetrics) => Double): Option[Distribution] = {
+ Distribution(taskInfoMetrics.map { case (info, metric) => getMetric(info, metric) })
+ }
+
+ // Is there some way to set up the types so that I can get rid of this completely?
+ def extractLongDistribution(
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
+ getMetric: (TaskInfo, TaskMetrics) => Long): Option[Distribution] = {
+ extractDoubleDistribution(
+ taskInfoMetrics,
+ (info, metric) => { getMetric(info, metric).toDouble })
+ }
+
+ def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
+ val stats = d.statCounter
+ val quantiles = d.getQuantiles(probabilities).map(formatNumber)
+ logInfo(heading + stats)
+ logInfo(percentilesHeader)
+ logInfo("\t" + quantiles.mkString("\t"))
+ }
+
+ def showDistribution(
+ heading: String,
+ dOpt: Option[Distribution],
+ formatNumber: Double => String) {
+ dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
+ }
+
+ def showDistribution(heading: String, dOpt: Option[Distribution], format: String) {
+ def f(d: Double): String = format.format(d)
+ showDistribution(heading, dOpt, f _)
+ }
+
+ def showDistribution(
+ heading: String,
+ format: String,
+ getMetric: (TaskInfo, TaskMetrics) => Double,
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format)
+ }
+
+ def showBytesDistribution(
+ heading: String,
+ getMetric: (TaskInfo, TaskMetrics) => Long,
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
+ }
+
+ def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
+ dOpt.foreach { dist => showBytesDistribution(heading, dist) }
+ }
+
+ def showBytesDistribution(heading: String, dist: Distribution) {
+ showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String)
+ }
+
+ def showMillisDistribution(heading: String, dOpt: Option[Distribution]) {
+ showDistribution(heading, dOpt,
+ (d => StatsReportListener.millisToString(d.toLong)): Double => String)
+ }
+
+ def showMillisDistribution(
+ heading: String,
+ getMetric: (TaskInfo, TaskMetrics) => Long,
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
+ }
+
+ val seconds = 1000L
+ val minutes = seconds * 60
+ val hours = minutes * 60
+
+ /**
+ * Reformat a time interval in milliseconds to a prettier format for output
+ */
+ def millisToString(ms: Long): String = {
+ val (size, units) =
+ if (ms > hours) {
+ (ms.toDouble / hours, "hours")
+ } else if (ms > minutes) {
+ (ms.toDouble / minutes, "min")
+ } else if (ms > seconds) {
+ (ms.toDouble / seconds, "s")
+ } else {
+ (ms.toDouble, "ms")
+ }
+ "%.1f %s".format(size, units)
+ }
+}
+
+private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
+
+private object RuntimePercentage {
+ def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
+ val denom = totalTime.toDouble
+ val fetchTime = Some(metrics.shuffleReadMetrics.fetchWaitTime)
+ val fetch = fetchTime.map(_ / denom)
+ val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom
+ val other = 1.0 - (exec + fetch.getOrElse(0d))
+ RuntimePercentage(exec, fetch, other)
+ }
+}
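Editor's note: StatsReportListener is functionally unchanged by the move into its own file; it still logs per-stage summary distributions when a stage completes. A short usage sketch (assuming a local SparkContext; addSparkListener is the DeveloperApi registration path, and the "spark.extraListeners" configuration is the declarative alternative):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.scheduler.StatsReportListener

    val conf = new SparkConf().setAppName("stats-demo").setMaster("local[2]")
    val sc = new SparkContext(conf)
    sc.addSparkListener(new StatsReportListener)   // or conf.set("spark.extraListeners", ...)
    // Each completed stage logs runtime, shuffle, and fetch-time distributions.
    sc.parallelize(1 to 1000000, 8).map(_ * 2).count()
    sc.stop()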
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 4fb32ba8cb18..f536fc2a5f0a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,24 +17,21 @@
package org.apache.spark.scheduler
-import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
+import java.util.Properties
-import scala.collection.mutable.HashMap
-
-import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext}
+import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.memory.TaskMemoryManager
-import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.util.ByteBufferInputStream
-import org.apache.spark.util.Utils
-
+import org.apache.spark.internal.config.APP_CALLER_CONTEXT
+import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.util._
/**
* A unit of execution. We have two kinds of Task's in Spark:
- * - [[org.apache.spark.scheduler.ShuffleMapTask]]
- * - [[org.apache.spark.scheduler.ResultTask]]
+ *
+ * - [[org.apache.spark.scheduler.ShuffleMapTask]]
+ * - [[org.apache.spark.scheduler.ResultTask]]
*
* A Spark job consists of one or more stages. The very last stage in a job consists of multiple
* ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
@@ -42,59 +39,110 @@ import org.apache.spark.util.Utils
* and divides the task output to multiple buckets (based on the task's partitioner).
*
* @param stageId id of the stage this task belongs to
+ * @param stageAttemptId attempt id of the stage this task belongs to
* @param partitionId index of the partition within the RDD
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
+ * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side
+ * and sent to the executor side.
+ *
+ * The parameters below are optional:
+ * @param jobId id of the job this task belongs to
+ * @param appId id of the app this task belongs to
+ * @param appAttemptId attempt id of the app this task belongs to
*/
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
val partitionId: Int,
- internalAccumulators: Seq[Accumulator[Long]]) extends Serializable {
+ @transient var localProperties: Properties = new Properties,
+ // The default value is only used in tests.
+ serializedTaskMetrics: Array[Byte] =
+ SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(),
+ val jobId: Option[Int] = None,
+ val appId: Option[String] = None,
+ val appAttemptId: Option[String] = None) extends Serializable {
- /**
- * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
- * local value.
- */
- type AccumulatorUpdates = Map[Long, Any]
+ @transient lazy val metrics: TaskMetrics =
+ SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics))
/**
- * Called by [[Executor]] to run this task.
+ * Called by [[org.apache.spark.executor.Executor]] to run this task.
*
* @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
* @param attemptNumber how many times this task has been attempted (0 for the first attempt)
* @return the result of the task along with updates of Accumulators.
*/
final def run(
- taskAttemptId: Long,
- attemptNumber: Int,
- metricsSystem: MetricsSystem)
- : (T, AccumulatorUpdates) = {
+ taskAttemptId: Long,
+ attemptNumber: Int,
+ metricsSystem: MetricsSystem): T = {
+ SparkEnv.get.blockManager.registerTask(taskAttemptId)
context = new TaskContextImpl(
stageId,
+ stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
partitionId,
taskAttemptId,
attemptNumber,
taskMemoryManager,
+ localProperties,
metricsSystem,
- internalAccumulators,
- runningLocally = false)
+ metrics)
TaskContext.setTaskContext(context)
- context.taskMetrics.setHostname(Utils.localHostName())
- context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
taskThread = Thread.currentThread()
- if (_killed) {
- kill(interruptThread = false)
+
+ if (_reasonIfKilled != null) {
+ kill(interruptThread = false, _reasonIfKilled)
}
+
+ new CallerContext(
+ "TASK",
+ SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
+ appId,
+ appAttemptId,
+ jobId,
+ Option(stageId),
+ Option(stageAttemptId),
+ Option(taskAttemptId),
+ Option(attemptNumber)).setCurrentContext()
+
try {
- (runTask(context), context.collectAccumulators())
+ runTask(context)
+ } catch {
+ case e: Throwable =>
+ // Catch all errors; run task failure callbacks, and rethrow the exception.
+ try {
+ context.markTaskFailed(e)
+ } catch {
+ case t: Throwable =>
+ e.addSuppressed(t)
+ }
+ context.markTaskCompleted(Some(e))
+ throw e
} finally {
- context.markTaskCompleted()
try {
- Utils.tryLogNonFatalError {
- // Release memory used by this thread for unrolling blocks
- SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
- }
+ // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
+ // one is no-op.
+ context.markTaskCompleted(None)
} finally {
- TaskContext.unset()
+ try {
+ Utils.tryLogNonFatalError {
+ // Release memory used by this thread for unrolling blocks
+ SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+ SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
+ MemoryMode.OFF_HEAP)
+ // Notify any tasks waiting for execution memory to be freed to wake up and try to
+ // acquire memory again. This prevents the scenario where a task sleeps forever
+ // because there are no other tasks left to notify it. Since this is safe to do but may
+ // not be strictly necessary, we should revisit whether we can remove this in the
+ // future.
+ val memoryManager = SparkEnv.get.memoryManager
+ memoryManager.synchronized { memoryManager.notifyAll() }
+ }
+ } finally {
+ // Though we unset the ThreadLocal here, the context member variable itself is still
+ // queried directly in the TaskRunner to check for FetchFailedExceptions.
+ TaskContext.unset()
+ }
}
}
}
@@ -109,32 +157,48 @@ private[spark] abstract class Task[T](
def preferredLocations: Seq[TaskLocation] = Nil
- // Map output tracker epoch. Will be set by TaskScheduler.
+ // Map output tracker epoch. Will be set by TaskSetManager.
var epoch: Long = -1
- var metrics: Option[TaskMetrics] = None
-
// Task context, to be initialized in run().
- @transient protected var context: TaskContextImpl = _
+ @transient var context: TaskContextImpl = _
// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
- // A flag to indicate whether the task is killed. This is used in case context is not yet
- // initialized when kill() is invoked.
- @volatile @transient private var _killed = false
+ // If non-null, this task has been killed and the reason is as specified. This is used in case
+ // context is not yet initialized when kill() is invoked.
+ @volatile @transient private var _reasonIfKilled: String = null
protected var _executorDeserializeTime: Long = 0
+ protected var _executorDeserializeCpuTime: Long = 0
/**
- * Whether the task has been killed.
+ * If defined, this task has been killed and this option contains the reason.
*/
- def killed: Boolean = _killed
+ def reasonIfKilled: Option[String] = Option(_reasonIfKilled)
/**
* Returns the amount of time spent deserializing the RDD and function to be run.
*/
def executorDeserializeTime: Long = _executorDeserializeTime
+ def executorDeserializeCpuTime: Long = _executorDeserializeCpuTime
+
+ /**
+ * Collect the latest values of accumulators used in this task. If the task failed,
+ * filter out the accumulators whose values should not be included on failures.
+ */
+ def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = {
+ if (context != null) {
+ // Note: internal accumulators representing task metrics always count failed values
+ context.taskMetrics.nonZeroInternalAccums() ++
+ // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not
+ // filter them out.
+ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues)
+ } else {
+ Seq.empty
+ }
+ }
/**
* Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
@@ -142,88 +206,14 @@ private[spark] abstract class Task[T](
* be called multiple times.
* If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread.
*/
- def kill(interruptThread: Boolean) {
- _killed = true
+ def kill(interruptThread: Boolean, reason: String) {
+ require(reason != null)
+ _reasonIfKilled = reason
if (context != null) {
- context.markInterrupted()
+ context.markInterrupted(reason)
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()
}
}
}
-
-/**
- * Handles transmission of tasks and their dependencies, because this can be slightly tricky. We
- * need to send the list of JARs and files added to the SparkContext with each task to ensure that
- * worker nodes find out about it, but we can't make it part of the Task because the user's code in
- * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by
- * first writing out its dependencies.
- */
-private[spark] object Task {
- /**
- * Serialize a task and the current app dependencies (files and JARs added to the SparkContext)
- */
- def serializeWithDependencies(
- task: Task[_],
- currentFiles: HashMap[String, Long],
- currentJars: HashMap[String, Long],
- serializer: SerializerInstance)
- : ByteBuffer = {
-
- val out = new ByteArrayOutputStream(4096)
- val dataOut = new DataOutputStream(out)
-
- // Write currentFiles
- dataOut.writeInt(currentFiles.size)
- for ((name, timestamp) <- currentFiles) {
- dataOut.writeUTF(name)
- dataOut.writeLong(timestamp)
- }
-
- // Write currentJars
- dataOut.writeInt(currentJars.size)
- for ((name, timestamp) <- currentJars) {
- dataOut.writeUTF(name)
- dataOut.writeLong(timestamp)
- }
-
- // Write the task itself and finish
- dataOut.flush()
- val taskBytes = serializer.serialize(task).array()
- out.write(taskBytes)
- ByteBuffer.wrap(out.toByteArray)
- }
-
- /**
- * Deserialize the list of dependencies in a task serialized with serializeWithDependencies,
- * and return the task itself as a serialized ByteBuffer. The caller can then update its
- * ClassLoaders and deserialize the task.
- *
- * @return (taskFiles, taskJars, taskBytes)
- */
- def deserializeWithDependencies(serializedTask: ByteBuffer)
- : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
-
- val in = new ByteBufferInputStream(serializedTask)
- val dataIn = new DataInputStream(in)
-
- // Read task's files
- val taskFiles = new HashMap[String, Long]()
- val numFiles = dataIn.readInt()
- for (i <- 0 until numFiles) {
- taskFiles(dataIn.readUTF()) = dataIn.readLong()
- }
-
- // Read task's JARs
- val taskJars = new HashMap[String, Long]()
- val numJars = dataIn.readInt()
- for (i <- 0 until numJars) {
- taskJars(dataIn.readUTF()) = dataIn.readLong()
- }
-
- // Create a sub-buffer for the rest of the data, which is the serialized Task object
- val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task
- (taskFiles, taskJars, subBuffer)
- }
-}
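Editor's note: the new cleanup path in Task.run releases unroll memory and then calls notifyAll() on the memory manager so that tasks blocked waiting for execution memory can retry. A minimal, self-contained sketch of that wait/notify hand-off (a simplified stand-in, not Spark's MemoryManager):

    // Simplified pool: release() frees bytes and wakes blocked acquirers, mirroring the
    // memoryManager.synchronized { memoryManager.notifyAll() } call in Task.run's finally block.
    object MemoryPool {
      private var freeBytes = 64L * 1024 * 1024

      def acquire(bytes: Long): Unit = synchronized {
        while (freeBytes < bytes) { wait() }   // block until enough memory has been freed
        freeBytes -= bytes
      }

      def release(bytes: Long): Unit = synchronized {
        freeBytes += bytes
        notifyAll()                            // wake every waiter; each re-checks its condition
      }
    }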
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index 1c7c81c488c3..c98b87148e40 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -17,13 +17,32 @@
package org.apache.spark.scheduler
+import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
+import java.util.Properties
-import org.apache.spark.util.SerializableBuffer
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{HashMap, Map}
+
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}
/**
* Description of a task that gets passed onto executors to be executed, usually created by
- * [[TaskSetManager.resourceOffer]].
+ * `TaskSetManager.resourceOffer`.
+ *
+ * TaskDescriptions and the associated Task need to be serialized carefully for two reasons:
+ *
+ * (1) When a TaskDescription is received by an Executor, the Executor needs to first get the
+ * list of JARs and files and add these to the classpath, and set the properties, before
+ * deserializing the Task object (serializedTask). This is why the Properties are included
+ * in the TaskDescription, even though they're also in the serialized task.
+ * (2) Because a TaskDescription is serialized and sent to an executor for each task, efficient
+ * serialization (both in terms of serialization time and serialized buffer size) is
+ * important. For this reason, we serialize TaskDescriptions ourselves with the
+ * TaskDescription.encode and TaskDescription.decode methods. This results in a smaller
+ * serialized size because it avoids serializing unnecessary fields in the Map objects
+ * (which can introduce significant overhead when the maps are small).
*/
private[spark] class TaskDescription(
val taskId: Long,
@@ -31,13 +50,95 @@ private[spark] class TaskDescription(
val executorId: String,
val name: String,
val index: Int, // Index within this task's TaskSet
- _serializedTask: ByteBuffer)
- extends Serializable {
+ val addedFiles: Map[String, Long],
+ val addedJars: Map[String, Long],
+ val properties: Properties,
+ val serializedTask: ByteBuffer) {
+
+ override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index)
+}
- // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer
- private val buffer = new SerializableBuffer(_serializedTask)
+private[spark] object TaskDescription {
+ private def serializeStringLongMap(map: Map[String, Long], dataOut: DataOutputStream): Unit = {
+ dataOut.writeInt(map.size)
+ for ((key, value) <- map) {
+ dataOut.writeUTF(key)
+ dataOut.writeLong(value)
+ }
+ }
- def serializedTask: ByteBuffer = buffer.value
+ def encode(taskDescription: TaskDescription): ByteBuffer = {
+ val bytesOut = new ByteBufferOutputStream(4096)
+ val dataOut = new DataOutputStream(bytesOut)
- override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index)
+ dataOut.writeLong(taskDescription.taskId)
+ dataOut.writeInt(taskDescription.attemptNumber)
+ dataOut.writeUTF(taskDescription.executorId)
+ dataOut.writeUTF(taskDescription.name)
+ dataOut.writeInt(taskDescription.index)
+
+ // Write files.
+ serializeStringLongMap(taskDescription.addedFiles, dataOut)
+
+ // Write jars.
+ serializeStringLongMap(taskDescription.addedJars, dataOut)
+
+ // Write properties.
+ dataOut.writeInt(taskDescription.properties.size())
+ taskDescription.properties.asScala.foreach { case (key, value) =>
+ dataOut.writeUTF(key)
+ // SPARK-19796 -- writeUTF doesn't work for long strings, which can happen for property values
+ val bytes = value.getBytes(StandardCharsets.UTF_8)
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+
+ // Write the task. The task is already serialized, so write it directly to the byte buffer.
+ Utils.writeByteBuffer(taskDescription.serializedTask, bytesOut)
+
+ dataOut.close()
+ bytesOut.close()
+ bytesOut.toByteBuffer
+ }
+
+ private def deserializeStringLongMap(dataIn: DataInputStream): HashMap[String, Long] = {
+ val map = new HashMap[String, Long]()
+ val mapSize = dataIn.readInt()
+ for (i <- 0 until mapSize) {
+ map(dataIn.readUTF()) = dataIn.readLong()
+ }
+ map
+ }
+
+ def decode(byteBuffer: ByteBuffer): TaskDescription = {
+ val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer))
+ val taskId = dataIn.readLong()
+ val attemptNumber = dataIn.readInt()
+ val executorId = dataIn.readUTF()
+ val name = dataIn.readUTF()
+ val index = dataIn.readInt()
+
+ // Read files.
+ val taskFiles = deserializeStringLongMap(dataIn)
+
+ // Read jars.
+ val taskJars = deserializeStringLongMap(dataIn)
+
+ // Read properties.
+ val properties = new Properties()
+ val numProperties = dataIn.readInt()
+ for (i <- 0 until numProperties) {
+ val key = dataIn.readUTF()
+ val valueLength = dataIn.readInt()
+ val valueBytes = new Array[Byte](valueLength)
+ dataIn.readFully(valueBytes)
+ properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8))
+ }
+
+ // Create a sub-buffer containing the serialized task (to be deserialized later).
+ val serializedTask = byteBuffer.slice()
+
+ new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars,
+ properties, serializedTask)
+ }
}
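Editor's note: a hedged round-trip sketch of the hand-rolled format documented above. TaskDescription is private[spark], so this only compiles from Spark-internal code or a test in the org.apache.spark.scheduler package; the concrete values are illustrative:

    import java.nio.ByteBuffer
    import java.util.Properties
    import scala.collection.mutable.HashMap

    val properties = new Properties()
    properties.setProperty("spark.job.description", "round-trip demo")

    val original = new TaskDescription(
      taskId = 7L,
      attemptNumber = 0,
      executorId = "exec-1",
      name = "task 7.0 in stage 1.0",
      index = 0,
      addedFiles = new HashMap[String, Long](),
      addedJars = HashMap("hdfs:///libs/app.jar" -> 1L),
      properties = properties,
      serializedTask = ByteBuffer.wrap(Array[Byte](1, 2, 3)))

    val decoded = TaskDescription.decode(TaskDescription.encode(original))
    assert(decoded.taskId == original.taskId)
    assert(decoded.addedJars("hdfs:///libs/app.jar") == 1L)
    assert(decoded.properties.getProperty("spark.job.description") == "round-trip demo")
    assert(decoded.serializedTask.remaining() == 3)   // the task bytes occupy the tail of the buffer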
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index f113c2b1b843..9843eab4f134 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -17,8 +17,8 @@
package org.apache.spark.scheduler
-import scala.collection.mutable.ListBuffer
-
+import org.apache.spark.TaskState
+import org.apache.spark.TaskState.TaskState
import org.apache.spark.annotation.DeveloperApi
/**
@@ -28,6 +28,10 @@ import org.apache.spark.annotation.DeveloperApi
@DeveloperApi
class TaskInfo(
val taskId: Long,
+ /**
+ * The index of this task within its task set. Not necessarily the same as the ID of the RDD
+ * partition that the task is computing.
+ */
val index: Int,
val attemptNumber: Int,
val launchTime: Long,
@@ -48,7 +52,13 @@ class TaskInfo(
* accumulable to be updated multiple times in a single task or for two accumulables with the
* same name but different IDs to exist in a task.
*/
- val accumulables = ListBuffer[AccumulableInfo]()
+ def accumulables: Seq[AccumulableInfo] = _accumulables
+
+ private[this] var _accumulables: Seq[AccumulableInfo] = Nil
+
+ private[spark] def setAccumulables(newAccumulables: Seq[AccumulableInfo]): Unit = {
+ _accumulables = newAccumulables
+ }
/**
* The time when the task has completed successfully (including the time to remotely fetch
@@ -58,24 +68,28 @@ class TaskInfo(
var failed = false
- private[spark] def markGettingResult(time: Long = System.currentTimeMillis) {
- gettingResultTime = time
- }
+ var killed = false
- private[spark] def markSuccessful(time: Long = System.currentTimeMillis) {
- finishTime = time
+ private[spark] def markGettingResult(time: Long) {
+ gettingResultTime = time
}
- private[spark] def markFailed(time: Long = System.currentTimeMillis) {
+ private[spark] def markFinished(state: TaskState, time: Long) {
+ // finishTime should be set larger than 0, otherwise "finished" below will return false.
+ assert(time > 0)
finishTime = time
- failed = true
+ if (state == TaskState.FAILED) {
+ failed = true
+ } else if (state == TaskState.KILLED) {
+ killed = true
+ }
}
def gettingResult: Boolean = gettingResultTime != 0
def finished: Boolean = finishTime != 0
- def successful: Boolean = finished && !failed
+ def successful: Boolean = finished && !failed && !killed
def running: Boolean = !finished
@@ -88,6 +102,8 @@ class TaskInfo(
}
} else if (failed) {
"FAILED"
+ } else if (killed) {
+ "KILLED"
} else if (successful) {
"SUCCESS"
} else {
@@ -95,9 +111,6 @@ class TaskInfo(
}
}
- @deprecated("Use attemptNumber", "1.6.0")
- def attempt: Int = attemptNumber
-
def id: String = s"$index.$attemptNumber"
def duration: Long = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
index 1eb6c1614fc0..06b52935c696 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
@@ -64,18 +64,18 @@ private[spark] object TaskLocation {
/**
* Create a TaskLocation from a string returned by getPreferredLocations.
- * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the
- * location is cached.
+ * These strings have the form executor_[hostname]_[executorid], [hostname], or
+ * hdfs_cache_[hostname], depending on whether the location is cached.
*/
def apply(str: String): TaskLocation = {
val hstr = str.stripPrefix(inMemoryLocationTag)
if (hstr.equals(str)) {
if (str.startsWith(executorLocationTag)) {
- val splits = str.split("_")
- if (splits.length != 3) {
- throw new IllegalArgumentException("Illegal executor location format: " + str)
- }
- new ExecutorCacheTaskLocation(splits(1), splits(2))
+ val hostAndExecutorId = str.stripPrefix(executorLocationTag)
+ val splits = hostAndExecutorId.split("_", 2)
+ require(splits.length == 2, "Illegal executor location format: " + str)
+ val Array(host, executorId) = splits
+ new ExecutorCacheTaskLocation(host, executorId)
} else {
new HostTaskLocation(str)
}
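Editor's note: the rewritten parser accepts the three string forms listed in the updated comment. A hedged sketch of the expected mappings (TaskLocation and the location case classes are private[spark]; HDFSCacheTaskLocation is assumed here to be the case class behind the hdfs_cache_ prefix):

    // Plain hostname and cached-HDFS forms.
    assert(TaskLocation("host1") == HostTaskLocation("host1"))
    assert(TaskLocation("hdfs_cache_host1") == HDFSCacheTaskLocation("host1"))

    // Executor form: the prefix is stripped and the rest is split into at most two fields,
    // so the executor id itself may contain underscores.
    assert(TaskLocation("executor_host1_3") == ExecutorCacheTaskLocation("host1", "3"))
    assert(TaskLocation("executor_host1_driver_1") == ExecutorCacheTaskLocation("host1", "driver_1"))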
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index b82c7f3fa54f..366b92c5f2ad 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -20,13 +20,12 @@ package org.apache.spark.scheduler
import java.io._
import java.nio.ByteBuffer
-import scala.collection.Map
-import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkEnv
-import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.storage.BlockId
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AccumulatorV2, Utils}
// Task result. Also contains updates to accumulator variables.
private[spark] sealed trait TaskResult[T]
@@ -36,31 +35,24 @@ private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
extends TaskResult[T] with Serializable
/** A TaskResult that contains the task's return value and accumulator updates. */
-private[spark]
-class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long, Any],
- var metrics: TaskMetrics)
+private[spark] class DirectTaskResult[T](
+ var valueBytes: ByteBuffer,
+ var accumUpdates: Seq[AccumulatorV2[_, _]])
extends TaskResult[T] with Externalizable {
private var valueObjectDeserialized = false
private var valueObject: T = _
- def this() = this(null.asInstanceOf[ByteBuffer], null, null)
+ def this() = this(null.asInstanceOf[ByteBuffer], null)
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-
- out.writeInt(valueBytes.remaining);
+ out.writeInt(valueBytes.remaining)
Utils.writeByteBuffer(valueBytes, out)
-
out.writeInt(accumUpdates.size)
- for ((key, value) <- accumUpdates) {
- out.writeLong(key)
- out.writeObject(value)
- }
- out.writeObject(metrics)
+ accumUpdates.foreach(out.writeObject)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
-
val blen = in.readInt()
val byteVal = new Array[Byte](blen)
in.readFully(byteVal)
@@ -68,15 +60,14 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
val numUpdates = in.readInt
if (numUpdates == 0) {
- accumUpdates = null
+ accumUpdates = Seq()
} else {
- val _accumUpdates = mutable.Map[Long, Any]()
+ val _accumUpdates = new ArrayBuffer[AccumulatorV2[_, _]]
for (i <- 0 until numUpdates) {
- _accumUpdates(in.readLong()) = in.readObject()
+ _accumUpdates += in.readObject.asInstanceOf[AccumulatorV2[_, _]]
}
accumUpdates = _accumUpdates
}
- metrics = in.readObject().asInstanceOf[TaskMetrics]
valueObjectDeserialized = false
}
@@ -87,14 +78,14 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
*
* After the first time, `value()` is trivial and just returns the deserialized `valueObject`.
*/
- def value(): T = {
+ def value(resultSer: SerializerInstance = null): T = {
if (valueObjectDeserialized) {
valueObject
} else {
// This should not run when holding a lock because it may cost dozens of seconds for a large
- // value.
- val resultSer = SparkEnv.get.serializer.newInstance()
- valueObject = resultSer.deserialize(valueBytes)
+ // value
+ val ser = if (resultSer == null) SparkEnv.get.serializer.newInstance() else resultSer
+ valueObject = ser.deserialize(valueBytes)
valueObjectDeserialized = true
valueObject
}
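Editor's note: the new value(resultSer) overload lets the caller supply a serializer so the expensive first deserialization can happen without holding the scheduler lock; later calls return the cached object and need no serializer at all. A minimal sketch (private[spark] class; JavaSerializer stands in for whatever serializer is actually configured):

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.JavaSerializer

    val ser = new JavaSerializer(new SparkConf()).newInstance()
    val result = new DirectTaskResult[Int](ser.serialize(42), accumUpdates = Seq.empty)

    val first = result.value(ser)   // pays the deserialization cost with the given serializer
    val second = result.value()     // cached; no serializer (and no SparkEnv) needed
    assert(first == 42 && second == 42)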
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 46a6f6537e2e..a284f7956cd3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -18,15 +18,16 @@
package org.apache.spark.scheduler
import java.nio.ByteBuffer
-import java.util.concurrent.RejectedExecutionException
+import java.util.concurrent.{ExecutorService, RejectedExecutionException}
import scala.language.existentials
import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
+import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils}
/**
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
@@ -35,17 +36,28 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
extends Logging {
private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4)
- private val getTaskResultExecutor = ThreadUtils.newDaemonFixedThreadPool(
- THREADS, "task-result-getter")
+ // Exposed for testing.
+ protected val getTaskResultExecutor: ExecutorService =
+ ThreadUtils.newDaemonFixedThreadPool(THREADS, "task-result-getter")
+
+ // Exposed for testing.
protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
sparkEnv.closureSerializer.newInstance()
}
}
+ protected val taskResultSerializer = new ThreadLocal[SerializerInstance] {
+ override def initialValue(): SerializerInstance = {
+ sparkEnv.serializer.newInstance()
+ }
+ }
+
def enqueueSuccessfulTask(
- taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ taskSetManager: TaskSetManager,
+ tid: Long,
+ serializedData: ByteBuffer): Unit = {
getTaskResultExecutor.execute(new Runnable {
override def run(): Unit = Utils.logUncaughtExceptions {
try {
@@ -57,7 +69,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
// deserialize "value" without holding any lock so that it won't block other threads.
// We should call it here, so that when it's called again in
// "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value.
- directResult.value()
+ directResult.value(taskResultSerializer.get())
(directResult, serializedData.limit())
case IndirectTaskResult(blockId, size) =>
if (!taskSetManager.canFetchMoreResults(size)) {
@@ -77,12 +89,27 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
return
}
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
- serializedTaskResult.get)
+ serializedTaskResult.get.toByteBuffer)
+ // force deserialization of referenced value
+ deserializedResult.value(taskResultSerializer.get())
sparkEnv.blockManager.master.removeBlock(blockId)
(deserializedResult, size)
}
- result.metrics.setResultSize(size)
+ // Set the task result size in the accumulator updates received from the executors.
+ // We need to do this here on the driver because if we did this on the executors then
+ // we would have to serialize the result again after updating the size.
+ result.accumUpdates = result.accumUpdates.map { a =>
+ if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
+ val acc = a.asInstanceOf[LongAccumulator]
+ assert(acc.sum == 0L, "task result size should not have been set on the executors")
+ acc.setValue(size.toLong)
+ acc
+ } else {
+ a
+ }
+ }
+
scheduler.handleSuccessfulTask(taskSetManager, tid, result)
} catch {
case cnf: ClassNotFoundException =>
@@ -99,25 +126,29 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
serializedData: ByteBuffer) {
- var reason : TaskEndReason = UnknownReason
+ var reason : TaskFailedReason = UnknownReason
try {
getTaskResultExecutor.execute(new Runnable {
override def run(): Unit = Utils.logUncaughtExceptions {
+ val loader = Utils.getContextOrSparkClassLoader
try {
if (serializedData != null && serializedData.limit() > 0) {
- reason = serializer.get().deserialize[TaskEndReason](
- serializedData, Utils.getSparkClassLoader)
+ reason = serializer.get().deserialize[TaskFailedReason](
+ serializedData, loader)
}
} catch {
case cnd: ClassNotFoundException =>
// Log an error but keep going here -- the task failed, so not catastrophic
// if we can't deserialize the reason.
- val loader = Utils.getContextOrSparkClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
- case ex: Exception => {}
+ case ex: Exception => // No-op
+ } finally {
+ // If there's an error while deserializing the TaskEndReason, this Runnable
+ // will die. Still tell the scheduler about the task failure, to avoid a hang
+ // where the scheduler thinks the task is still running.
+ scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
}
- scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
}
})
} catch {
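
The comment added to enqueueSuccessfulTask explains why the result-size accumulator is filled in on the driver: setting it on the executor would force the already-serialized result to be serialized again. Below is a rough standalone sketch of just that rewrite step; the Accum/LongAccum types and the accumulator name are illustrative stand-ins, not Spark's AccumulatorV2 API.

```scala
// Illustrative stand-ins only; not Spark's AccumulatorV2 / LongAccumulator API.
sealed trait Accum { def name: Option[String] }
final case class LongAccum(name: Option[String], sum: Long) extends Accum
final case class OtherAccum(name: Option[String], payload: String) extends Accum

object ResultSizeDemo {
  // Assumed accumulator name, for illustration only.
  val RESULT_SIZE = "internal.metrics.resultSize"

  // Mirrors the rewrite in enqueueSuccessfulTask: only the result-size accumulator is
  // replaced with the size measured on the driver; every other update passes through.
  def setResultSize(updates: Seq[Accum], size: Long): Seq[Accum] = updates.map {
    case a: LongAccum if a.name.contains(RESULT_SIZE) =>
      assert(a.sum == 0L, "result size should not have been set on the executor")
      a.copy(sum = size)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val updates = Seq(OtherAccum(Some("records"), "42"), LongAccum(Some(RESULT_SIZE), 0L))
    println(setResultSize(updates, 1234L))
  }
}
```
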
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index cb9a3008107d..3de7d1f7de22 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -18,8 +18,8 @@
package org.apache.spark.scheduler
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.AccumulatorV2
/**
* Low-level task scheduler interface, currently implemented exclusively by
@@ -52,7 +52,14 @@ private[spark] trait TaskScheduler {
def submitTasks(taskSet: TaskSet): Unit
// Cancel a stage.
- def cancelTasks(stageId: Int, interruptThread: Boolean)
+ def cancelTasks(stageId: Int, interruptThread: Boolean): Unit
+
+ /**
+ * Kills a task attempt.
+ *
+ * @return Whether the task was successfully killed.
+ */
+ def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean
// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit
@@ -65,8 +72,10 @@ private[spark] trait TaskScheduler {
* alive. Return true if the driver knows about the given block manager. Otherwise, return false,
* indicating that the block manager should re-register.
*/
- def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
- blockManagerId: BlockManagerId): Boolean
+ def executorHeartbeatReceived(
+ execId: String,
+ accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
+ blockManagerId: BlockManagerId): Boolean
/**
* Get an application ID associated with the job.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 43d7d80b7aae..bc0d4700bb76 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -18,52 +18,73 @@
package org.apache.spark.scheduler
import java.nio.ByteBuffer
-import java.util.{TimerTask, Timer}
+import java.util.{Locale, Timer, TimerTask}
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.language.postfixOps
+import scala.collection.Set
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.util.Random
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.scheduler.TaskLocality.TaskLocality
-import org.apache.spark.util.{ThreadUtils, Utils}
-import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler.local.LocalSchedulerBackend
import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils}
/**
* Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
- * It can also work with a local setup by using a LocalBackend and setting isLocal to true.
- * It handles common logic, like determining a scheduling order across jobs, waking up to launch
- * speculative tasks, etc.
+ * It can also work with a local setup by using a `LocalSchedulerBackend` and setting
+ * isLocal to true. It handles common logic, like determining a scheduling order across jobs, waking
+ * up to launch speculative tasks, etc.
*
* Clients should first call initialize() and start(), then submit task sets through the
* runTasks method.
*
- * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
+ * THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
- * SchedulerBackends synchronize on themselves when they want to send events here, and then
+ * [[SchedulerBackend]]s synchronize on themselves when they want to send events here, and then
* acquire a lock on us, so we need to make sure that we don't try to lock the backend while
* we are holding a lock on ourselves.
*/
-private[spark] class TaskSchedulerImpl(
+private[spark] class TaskSchedulerImpl private[scheduler](
val sc: SparkContext,
val maxTaskFailures: Int,
+ private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker],
isLocal: Boolean = false)
- extends TaskScheduler with Logging
-{
- def this(sc: SparkContext) = this(sc, sc.conf.getInt("spark.task.maxFailures", 4))
+ extends TaskScheduler with Logging {
+
+ import TaskSchedulerImpl._
+
+ def this(sc: SparkContext) = {
+ this(
+ sc,
+ sc.conf.get(config.MAX_TASK_FAILURES),
+ TaskSchedulerImpl.maybeCreateBlacklistTracker(sc))
+ }
+
+ def this(sc: SparkContext, maxTaskFailures: Int, isLocal: Boolean) = {
+ this(
+ sc,
+ maxTaskFailures,
+ TaskSchedulerImpl.maybeCreateBlacklistTracker(sc),
+ isLocal = isLocal)
+ }
val conf = sc.conf
// How often to check for speculative tasks
val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms")
+ // Duplicate copies of a task will only be launched if the original copy has been running for
+ // at least this amount of time. This is to avoid the overhead of launching speculative copies
+ // of tasks that are very short.
+ val MIN_TIME_TO_SPECULATION = 100
+
private val speculationScheduler =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-scheduler-speculation")
@@ -77,6 +98,7 @@ private[spark] class TaskSchedulerImpl(
// on this class.
private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
+ // Protected by `this`
private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
val taskIdToExecutorId = new HashMap[Long, String]
@@ -87,12 +109,16 @@ private[spark] class TaskSchedulerImpl(
// Incrementing task IDs
val nextTaskId = new AtomicLong(0)
- // Which executor IDs we have executors on
- val activeExecutorIds = new HashSet[String]
+ // IDs of the tasks running on each executor
+ private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]
+
+ def runningTasksByExecutors: Map[String, Int] = synchronized {
+ executorIdToRunningTaskIds.toMap.mapValues(_.size)
+ }
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
- protected val executorsByHost = new HashMap[String, HashSet[String]]
+ protected val hostToExecutors = new HashMap[String, HashSet[String]]
protected val hostsByRack = new HashMap[String, HashSet[String]]
@@ -103,18 +129,20 @@ private[spark] class TaskSchedulerImpl(
var backend: SchedulerBackend = null
- val mapOutputTracker = SparkEnv.get.mapOutputTracker
+ val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
- var schedulableBuilder: SchedulableBuilder = null
- var rootPool: Pool = null
+ private var schedulableBuilder: SchedulableBuilder = null
// default scheduler is FIFO
- private val schedulingModeConf = conf.get("spark.scheduler.mode", "FIFO")
- val schedulingMode: SchedulingMode = try {
- SchedulingMode.withName(schedulingModeConf.toUpperCase)
- } catch {
- case e: java.util.NoSuchElementException =>
- throw new SparkException(s"Unrecognized spark.scheduler.mode: $schedulingModeConf")
- }
+ private val schedulingModeConf = conf.get(SCHEDULER_MODE_PROPERTY, SchedulingMode.FIFO.toString)
+ val schedulingMode: SchedulingMode =
+ try {
+ SchedulingMode.withName(schedulingModeConf.toUpperCase(Locale.ROOT))
+ } catch {
+ case e: java.util.NoSuchElementException =>
+ throw new SparkException(s"Unrecognized $SCHEDULER_MODE_PROPERTY: $schedulingModeConf")
+ }
+
+ val rootPool: Pool = new Pool("", schedulingMode, 0, 0)
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
@@ -125,14 +153,15 @@ private[spark] class TaskSchedulerImpl(
def initialize(backend: SchedulerBackend) {
this.backend = backend
- // temporarily set rootPool name to empty
- rootPool = new Pool("", schedulingMode, 0, 0)
schedulableBuilder = {
schedulingMode match {
case SchedulingMode.FIFO =>
new FIFOSchedulableBuilder(rootPool)
case SchedulingMode.FAIR =>
new FairSchedulableBuilder(rootPool, conf)
+ case _ =>
+ throw new IllegalArgumentException(s"Unsupported $SCHEDULER_MODE_PROPERTY: " +
+ s"$schedulingMode")
}
}
schedulableBuilder.buildPools()
@@ -145,7 +174,7 @@ private[spark] class TaskSchedulerImpl(
if (!isLocal && conf.getBoolean("spark.speculation", false)) {
logInfo("Starting speculative execution thread")
- speculationScheduler.scheduleAtFixedRate(new Runnable {
+ speculationScheduler.scheduleWithFixedDelay(new Runnable {
override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
checkSpeculatableTasks()
}
@@ -197,7 +226,7 @@ private[spark] class TaskSchedulerImpl(
private[scheduler] def createTaskSetManager(
taskSet: TaskSet,
maxTaskFailures: Int): TaskSetManager = {
- new TaskSetManager(this, taskSet, maxTaskFailures)
+ new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt)
}
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
@@ -212,7 +241,7 @@ private[spark] class TaskSchedulerImpl(
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
- backend.killTask(tid, execId, interruptThread)
+ backend.killTask(tid, execId, interruptThread, reason = "stage cancelled")
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
@@ -220,6 +249,18 @@ private[spark] class TaskSchedulerImpl(
}
}
+ override def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
+ logInfo(s"Killing task $taskId: $reason")
+ val execId = taskIdToExecutorId.get(taskId)
+ if (execId.isDefined) {
+ backend.killTask(taskId, execId.get, interruptThread, reason)
+ true
+ } else {
+ logWarning(s"Could not kill task $taskId because no task with that ID was found.")
+ false
+ }
+ }
+
/**
* Called to indicate that all task attempts (including speculated tasks) associated with the
* given TaskSetManager have completed, so state associated with the TaskSetManager should be
@@ -233,8 +274,8 @@ private[spark] class TaskSchedulerImpl(
}
}
manager.parent.removeSchedulable(manager)
- logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
- .format(manager.taskSet.id, manager.parent.name))
+ logInfo(s"Removed TaskSet ${manager.taskSet.id}, whose tasks have all completed, from pool" +
+ s" ${manager.parent.name}")
}
private def resourceOfferSingleTaskSet(
@@ -242,8 +283,10 @@ private[spark] class TaskSchedulerImpl(
maxLocality: TaskLocality,
shuffledOffers: Seq[WorkerOffer],
availableCpus: Array[Int],
- tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = {
+ tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = {
var launchedTask = false
+ // nodes and executors that are blacklisted for the entire application have already been
+ // filtered out by this point
for (i <- 0 until shuffledOffers.size) {
val execId = shuffledOffers(i).executorId
val host = shuffledOffers(i).host
@@ -254,7 +297,7 @@ private[spark] class TaskSchedulerImpl(
val tid = task.taskId
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
- executorsByHost(host) += execId
+ executorIdToRunningTaskIds(execId).add(tid)
availableCpus(i) -= CPUS_PER_TASK
assert(availableCpus(i) >= 0)
launchedTask = true
@@ -276,16 +319,19 @@ private[spark] class TaskSchedulerImpl(
* sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
* that tasks are balanced across the cluster.
*/
- def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
+ def resourceOffers(offers: IndexedSeq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
// Mark each slave as alive and remember its hostname
// Also track if new executor is added
var newExecAvail = false
for (o <- offers) {
- executorIdToHost(o.executorId) = o.host
- activeExecutorIds += o.executorId
- if (!executorsByHost.contains(o.host)) {
- executorsByHost(o.host) = new HashSet[String]()
+ if (!hostToExecutors.contains(o.host)) {
+ hostToExecutors(o.host) = new HashSet[String]()
+ }
+ if (!executorIdToRunningTaskIds.contains(o.executorId)) {
+ hostToExecutors(o.host) += o.executorId
executorAdded(o.executorId, o.host)
+ executorIdToHost(o.executorId) = o.host
+ executorIdToRunningTaskIds(o.executorId) = HashSet[Long]()
newExecAvail = true
}
for (rack <- getRackForHost(o.host)) {
@@ -293,8 +339,19 @@ private[spark] class TaskSchedulerImpl(
}
}
- // Randomly shuffle offers to avoid always placing tasks on the same set of workers.
- val shuffledOffers = Random.shuffle(offers)
+ // Before making any offers, remove any nodes from the blacklist whose blacklist has expired. Do
+ // this here to avoid a separate thread and added synchronization overhead, and also because
+ // updating the blacklist is only relevant when task offers are being made.
+ blacklistTrackerOpt.foreach(_.applyBlacklistTimeout())
+
+ val filteredOffers = blacklistTrackerOpt.map { blacklistTracker =>
+ offers.filter { offer =>
+ !blacklistTracker.isNodeBlacklisted(offer.host) &&
+ !blacklistTracker.isExecutorBlacklisted(offer.executorId)
+ }
+ }.getOrElse(offers)
+
+ val shuffledOffers = shuffleOffers(filteredOffers)
// Build a list of tasks to assign to each worker.
val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = shuffledOffers.map(o => o.cores).toArray
@@ -310,12 +367,19 @@ private[spark] class TaskSchedulerImpl(
// Take each TaskSet in our scheduling order, and then offer it each node in increasing order
// of locality levels so that it gets a chance to launch local tasks on all of them.
// NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY
- var launchedTask = false
- for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
- do {
- launchedTask = resourceOfferSingleTaskSet(
- taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
- } while (launchedTask)
+ for (taskSet <- sortedTaskSets) {
+ var launchedAnyTask = false
+ var launchedTaskAtCurrentMaxLocality = false
+ for (currentMaxLocality <- taskSet.myLocalityLevels) {
+ do {
+ launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(
+ taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks)
+ launchedAnyTask |= launchedTaskAtCurrentMaxLocality
+ } while (launchedTaskAtCurrentMaxLocality)
+ }
+ if (!launchedAnyTask) {
+ taskSet.abortIfCompletelyBlacklisted(hostToExecutors)
+ }
}
if (tasks.size > 0) {
@@ -324,36 +388,47 @@ private[spark] class TaskSchedulerImpl(
return tasks
}
+ /**
+ * Shuffle offers around to avoid always placing tasks on the same workers. Exposed to allow
+ * overriding in tests, so it can be deterministic.
+ */
+ protected def shuffleOffers(offers: IndexedSeq[WorkerOffer]): IndexedSeq[WorkerOffer] = {
+ Random.shuffle(offers)
+ }
+
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var failedExecutor: Option[String] = None
+ var reason: Option[ExecutorLossReason] = None
synchronized {
try {
- if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
- // We lost this entire executor, so remember that it's gone
- val execId = taskIdToExecutorId(tid)
- if (activeExecutorIds.contains(execId)) {
- removeExecutor(execId,
- SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
- failedExecutor = Some(execId)
- }
- }
taskIdToTaskSetManager.get(tid) match {
case Some(taskSet) =>
- if (TaskState.isFinished(state)) {
- taskIdToTaskSetManager.remove(tid)
- taskIdToExecutorId.remove(tid)
+ if (state == TaskState.LOST) {
+ // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode,
+ // where each executor corresponds to a single task, so mark the executor as failed.
+ val execId = taskIdToExecutorId.getOrElse(tid, throw new IllegalStateException(
+ "taskIdToTaskSetManager.contains(tid) <=> taskIdToExecutorId.contains(tid)"))
+ if (executorIdToRunningTaskIds.contains(execId)) {
+ reason = Some(
+ SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
+ removeExecutor(execId, reason.get)
+ failedExecutor = Some(execId)
+ }
}
- if (state == TaskState.FINISHED) {
- taskSet.removeRunningTask(tid)
- taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
- } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+ if (TaskState.isFinished(state)) {
+ cleanupTaskState(tid)
taskSet.removeRunningTask(tid)
- taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
+ if (state == TaskState.FINISHED) {
+ taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+ } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+ taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
+ }
}
case None =>
logError(
("Ignoring update with state %s for TID %s because its task set is gone (this is " +
- "likely the result of receiving duplicate task finished status updates)")
+ "likely the result of receiving duplicate task finished status updates) or its " +
+ "executor has been marked as failed.")
.format(state, tid))
}
} catch {
@@ -362,7 +437,8 @@ private[spark] class TaskSchedulerImpl(
}
// Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor.isDefined) {
- dagScheduler.executorLost(failedExecutor.get)
+ assert(reason.isDefined)
+ dagScheduler.executorLost(failedExecutor.get, reason.get)
backend.reviveOffers()
}
}
@@ -374,17 +450,18 @@ private[spark] class TaskSchedulerImpl(
*/
override def executorHeartbeatReceived(
execId: String,
- taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
+ accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
blockManagerId: BlockManagerId): Boolean = {
-
- val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
- taskMetrics.flatMap { case (id, metrics) =>
+ // (taskId, stageId, stageAttemptId, accumUpdates)
+ val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized {
+ accumUpdates.flatMap { case (id, updates) =>
+ val accInfos = updates.map(acc => acc.toInfo(Some(acc.value), None))
taskIdToTaskSetManager.get(id).map { taskSetMgr =>
- (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+ (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, accInfos)
}
}
}
- dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
+ dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId)
}
def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized {
@@ -402,9 +479,9 @@ private[spark] class TaskSchedulerImpl(
taskSetManager: TaskSetManager,
tid: Long,
taskState: TaskState,
- reason: TaskEndReason): Unit = synchronized {
+ reason: TaskFailedReason): Unit = synchronized {
taskSetManager.handleFailedTask(tid, taskState, reason)
- if (!taskSetManager.isZombie && taskState != TaskState.KILLED) {
+ if (!taskSetManager.isZombie && !taskSetManager.someAttemptSucceeded(tid)) {
// Need to revive offers again now that the task set manager state has been updated to
// reflect failed tasks that need to be re-run.
backend.reviveOffers()
@@ -451,7 +528,7 @@ private[spark] class TaskSchedulerImpl(
def checkSpeculatableTasks() {
var shouldRevive = false
synchronized {
- shouldRevive = rootPool.checkSpeculatableTasks()
+ shouldRevive = rootPool.checkSpeculatableTasks(MIN_TIME_TO_SPECULATION)
}
if (shouldRevive) {
backend.reviveOffers()
@@ -462,47 +539,79 @@ private[spark] class TaskSchedulerImpl(
var failedExecutor: Option[String] = None
synchronized {
- if (activeExecutorIds.contains(executorId)) {
+ if (executorIdToRunningTaskIds.contains(executorId)) {
val hostPort = executorIdToHost(executorId)
- logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
+ logExecutorLoss(executorId, hostPort, reason)
removeExecutor(executorId, reason)
failedExecutor = Some(executorId)
} else {
- executorIdToHost.get(executorId) match {
- case Some(_) =>
- // If the host mapping still exists, it means we don't know the loss reason for the
- // executor. So call removeExecutor() to update tasks running on that executor when
- // the real loss reason is finally known.
- removeExecutor(executorId, reason)
-
- case None =>
- // We may get multiple executorLost() calls with different loss reasons. For example,
- // one may be triggered by a dropped connection from the slave while another may be a
- // report of executor termination from Mesos. We produce log messages for both so we
- // eventually report the termination reason.
- logError("Lost an executor " + executorId + " (already removed): " + reason)
- }
+ executorIdToHost.get(executorId) match {
+ case Some(hostPort) =>
+ // If the host mapping still exists, it means we don't know the loss reason for the
+ // executor. So call removeExecutor() to update tasks running on that executor when
+ // the real loss reason is finally known.
+ logExecutorLoss(executorId, hostPort, reason)
+ removeExecutor(executorId, reason)
+
+ case None =>
+ // We may get multiple executorLost() calls with different loss reasons. For example,
+ // one may be triggered by a dropped connection from the slave while another may be a
+ // report of executor termination from Mesos. We produce log messages for both so we
+ // eventually report the termination reason.
+ logError(s"Lost an executor $executorId (already removed): $reason")
+ }
}
}
// Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor.isDefined) {
- dagScheduler.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get, reason)
backend.reviveOffers()
}
}
+ private def logExecutorLoss(
+ executorId: String,
+ hostPort: String,
+ reason: ExecutorLossReason): Unit = reason match {
+ case LossReasonPending =>
+ logDebug(s"Executor $executorId on $hostPort lost, but reason not yet known.")
+ case ExecutorKilled =>
+ logInfo(s"Executor $executorId on $hostPort killed by driver.")
+ case _ =>
+ logError(s"Lost executor $executorId on $hostPort: $reason")
+ }
+
+ /**
+ * Cleans up the TaskScheduler's state for tracking the given task.
+ */
+ private def cleanupTaskState(tid: Long): Unit = {
+ taskIdToTaskSetManager.remove(tid)
+ taskIdToExecutorId.remove(tid).foreach { executorId =>
+ executorIdToRunningTaskIds.get(executorId).foreach { _.remove(tid) }
+ }
+ }
+
/**
* Remove an executor from all our data structures and mark it as lost. If the executor's loss
* reason is not yet known, do not yet remove its association with its host nor update the status
* of any running tasks, since the loss reason defines whether we'll fail those tasks.
*/
private def removeExecutor(executorId: String, reason: ExecutorLossReason) {
- activeExecutorIds -= executorId
+ // The tasks on the lost executor may not send any more status updates (because the executor
+ // has been lost), so they should be cleaned up here.
+ executorIdToRunningTaskIds.remove(executorId).foreach { taskIds =>
+ logDebug("Cleaning up TaskScheduler state for tasks " +
+ s"${taskIds.mkString("[", ",", "]")} on failed executor $executorId")
+ // We do not notify the TaskSetManager of the task failures because that will
+ // happen below in the rootPool.executorLost() call.
+ taskIds.foreach(cleanupTaskState)
+ }
+
val host = executorIdToHost(executorId)
- val execs = executorsByHost.getOrElse(host, new HashSet)
+ val execs = hostToExecutors.getOrElse(host, new HashSet)
execs -= executorId
if (execs.isEmpty) {
- executorsByHost -= host
+ hostToExecutors -= host
for (rack <- getRackForHost(host); hosts <- hostsByRack.get(rack)) {
hosts -= host
if (hosts.isEmpty) {
@@ -515,6 +624,7 @@ private[spark] class TaskSchedulerImpl(
executorIdToHost -= executorId
rootPool.executorLost(executorId, host, reason)
}
+ blacklistTrackerOpt.foreach(_.handleRemovedExecutor(executorId))
}
def executorAdded(execId: String, host: String) {
@@ -522,11 +632,11 @@ private[spark] class TaskSchedulerImpl(
}
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
- executorsByHost.get(host).map(_.toSet)
+ hostToExecutors.get(host).map(_.toSet)
}
def hasExecutorsAliveOnHost(host: String): Boolean = synchronized {
- executorsByHost.contains(host)
+ hostToExecutors.contains(host)
}
def hasHostAliveOnRack(rack: String): Boolean = synchronized {
@@ -534,7 +644,19 @@ private[spark] class TaskSchedulerImpl(
}
def isExecutorAlive(execId: String): Boolean = synchronized {
- activeExecutorIds.contains(execId)
+ executorIdToRunningTaskIds.contains(execId)
+ }
+
+ def isExecutorBusy(execId: String): Boolean = synchronized {
+ executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty)
+ }
+
+ /**
+ * Get a snapshot of the currently blacklisted nodes for the entire application. This is
+ * thread-safe -- it can be called without a lock on the TaskScheduler.
+ */
+ def nodeBlacklist(): scala.collection.immutable.Set[String] = {
+ blacklistTrackerOpt.map(_.nodeBlacklist()).getOrElse(scala.collection.immutable.Set())
}
// By default, rack is unknown
@@ -545,6 +667,11 @@ private[spark] class TaskSchedulerImpl(
return
}
while (!backend.isReady) {
+ // Might take a while for backend to be ready if it is waiting on resources.
+ if (sc.stopped.get) {
+ // For example: the master removes the application for some reason
+ throw new IllegalStateException("Spark context stopped while waiting for backend")
+ }
synchronized {
this.wait(100)
}
@@ -566,20 +693,40 @@ private[spark] class TaskSchedulerImpl(
}
}
+ /**
+ * Marks the task as completed in all TaskSetManagers for the given stage.
+ *
+ * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
+ * If an earlier attempt of a stage completes a task, we should ensure that the later attempts
+ * do not also submit those same tasks. That also means that a task completion from an earlier
+ * attempt can lead to the entire stage getting marked as successful.
+ */
+ private[scheduler] def markPartitionCompletedInAllTaskSets(
+ stageId: Int,
+ partitionId: Int,
+ taskInfo: TaskInfo) = {
+ taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
+ tsm.markPartitionCompleted(partitionId, taskInfo)
+ }
+ }
+
}
private[spark] object TaskSchedulerImpl {
+
+ val SCHEDULER_MODE_PROPERTY = "spark.scheduler.mode"
+
/**
* Used to balance containers across hosts.
*
* Accepts a map of hosts to resource offers for that host, and returns a prioritized list of
- * resource offers representing the order in which the offers should be used. The resource
+ * resource offers representing the order in which the offers should be used. The resource
* offers are ordered such that we'll allocate one container on each host before allocating a
* second container on any host, and so on, in order to reduce the damage if a host fails.
*
- * For example, given <h1, [o1, o2, o3]>, <h2, [o4]>, <h3, [o5, o6]>, returns
- * [o1, o5, o4, 02, o6, o3]
+ * For example, given {@literal <h1, [o1, o2, o3]>}, {@literal <h2, [o4]>} and
+ * {@literal <h3, [o5, o6]>}, returns {@literal [o1, o5, o4, o2, o6, o3]}.
*/
def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
val _keyList = new ArrayBuffer[K](map.size)
@@ -597,10 +744,10 @@ private[spark] object TaskSchedulerImpl {
while (found) {
found = false
for (key <- keyList) {
- val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
+ val containerList: ArrayBuffer[T] = map.getOrElse(key, null)
assert(containerList != null)
// Get the index'th entry for this host - if present
- if (index < containerList.size){
+ if (index < containerList.size) {
retval += containerList.apply(index)
found = true
}
@@ -611,4 +758,16 @@ private[spark] object TaskSchedulerImpl {
retval.toList
}
+ private def maybeCreateBlacklistTracker(sc: SparkContext): Option[BlacklistTracker] = {
+ if (BlacklistTracker.isBlacklistEnabled(sc.conf)) {
+ val executorAllocClient: Option[ExecutorAllocationClient] = sc.schedulerBackend match {
+ case b: ExecutorAllocationClient => Some(b)
+ case _ => None
+ }
+ Some(new BlacklistTracker(sc, executorAllocClient))
+ } else {
+ None
+ }
+ }
+
}
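
The prioritizeContainers doc comment above describes a round-robin ordering: one offer from each host before a second offer from any host, to limit the damage if a single host fails. A rough, self-contained sketch of that ordering (not the patched implementation; hosts are simply sorted by descending offer count so the demo is deterministic):

```scala
import scala.collection.mutable.{ArrayBuffer, HashMap}

object PrioritizeDemo {
  // Round-robin interleaving: take the i-th offer from every host before taking
  // the (i+1)-th from any host.
  def prioritize[K, T](map: HashMap[K, ArrayBuffer[T]]): List[T] = {
    val keys = map.keys.toList.sortBy(k => -map(k).size)
    val out = new ArrayBuffer[T]
    var index = 0
    var found = true
    while (found) {
      found = false
      for (k <- keys) {
        val offers = map(k)
        if (index < offers.size) {
          out += offers(index)
          found = true
        }
      }
      index += 1
    }
    out.toList
  }

  def main(args: Array[String]): Unit = {
    val offers = HashMap(
      "h1" -> ArrayBuffer("o1", "o2", "o3"),
      "h2" -> ArrayBuffer("o4"),
      "h3" -> ArrayBuffer("o5", "o6"))
    println(prioritize(offers))  // List(o1, o5, o4, o2, o6, o3)
  }
}
```

For the hosts and offers from the doc comment, this yields the documented ordering [o1, o5, o4, o2, o6, o3].
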
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index be8526ba9b94..517c8991aed7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -29,7 +29,7 @@ private[spark] class TaskSet(
val stageAttemptId: Int,
val priority: Int,
val properties: Properties) {
- val id: String = stageId + "." + stageAttemptId
+ val id: String = stageId + "." + stageAttemptId
override def toString: String = "TaskSet " + id
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala
new file mode 100644
index 000000000000..e815b7e0cf6c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler
+
+import scala.collection.mutable.{HashMap, HashSet}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config
+import org.apache.spark.util.Clock
+
+/**
+ * Handles blacklisting executors and nodes within a taskset. This includes blacklisting specific
+ * (task, executor) / (task, node) pairs, and also completely blacklisting executors and nodes
+ * for the entire taskset.
+ *
+ * It also must store sufficient information in task failures for application level blacklisting,
+ * which is handled by [[BlacklistTracker]]. Note that BlacklistTracker does not know anything
+ * about task failures until a taskset completes successfully.
+ *
+ * THREADING: This class is a helper to [[TaskSetManager]]; as with the methods in
+ * [[TaskSetManager]] this class is designed only to be called from code with a lock on the
+ * TaskScheduler (e.g. its event handlers). It should not be called from other threads.
+ */
+private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, val clock: Clock)
+ extends Logging {
+
+ private val MAX_TASK_ATTEMPTS_PER_EXECUTOR = conf.get(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR)
+ private val MAX_TASK_ATTEMPTS_PER_NODE = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE)
+ private val MAX_FAILURES_PER_EXEC_STAGE = conf.get(config.MAX_FAILURES_PER_EXEC_STAGE)
+ private val MAX_FAILED_EXEC_PER_NODE_STAGE = conf.get(config.MAX_FAILED_EXEC_PER_NODE_STAGE)
+
+ /**
+ * A map from each executor to the task failures on that executor. This is used for blacklisting
+ * within this taskset, and it is also relayed onto [[BlacklistTracker]] for app-level
+ * blacklisting if this taskset completes successfully.
+ */
+ val execToFailures = new HashMap[String, ExecutorFailuresInTaskSet]()
+
+ /**
+ * Map from node to all executors on it with failures. Needed because we want to know about
+ * executors on a node even after they have died. (We don't want to bother tracking the
+ * node -> execs mapping in the usual case when there aren't any failures).
+ */
+ private val nodeToExecsWithFailures = new HashMap[String, HashSet[String]]()
+ private val nodeToBlacklistedTaskIndexes = new HashMap[String, HashSet[Int]]()
+ private val blacklistedExecs = new HashSet[String]()
+ private val blacklistedNodes = new HashSet[String]()
+
+ /**
+ * Return true if this executor is blacklisted for the given task. This does *not*
+ * need to return true if the executor is blacklisted for the entire stage, or blacklisted
+ * for the entire application. That is to keep this method as fast as possible in the inner-loop
+ * of the scheduler, where those filters will have already been applied.
+ */
+ def isExecutorBlacklistedForTask(executorId: String, index: Int): Boolean = {
+ execToFailures.get(executorId).exists { execFailures =>
+ execFailures.getNumTaskFailures(index) >= MAX_TASK_ATTEMPTS_PER_EXECUTOR
+ }
+ }
+
+ def isNodeBlacklistedForTask(node: String, index: Int): Boolean = {
+ nodeToBlacklistedTaskIndexes.get(node).exists(_.contains(index))
+ }
+
+ /**
+ * Return true if this executor is blacklisted for the given stage. Completely ignores whether
+ * the executor is blacklisted for the entire application (or anything to do with the node the
+ * executor is on). That is to keep this method as fast as possible in the inner-loop of the
+ * scheduler, where those filters will already have been applied.
+ */
+ def isExecutorBlacklistedForTaskSet(executorId: String): Boolean = {
+ blacklistedExecs.contains(executorId)
+ }
+
+ def isNodeBlacklistedForTaskSet(node: String): Boolean = {
+ blacklistedNodes.contains(node)
+ }
+
+ private[scheduler] def updateBlacklistForFailedTask(
+ host: String,
+ exec: String,
+ index: Int): Unit = {
+ val execFailures = execToFailures.getOrElseUpdate(exec, new ExecutorFailuresInTaskSet(host))
+ execFailures.updateWithFailure(index, clock.getTimeMillis())
+
+ // Check if this task has also failed on other executors on the same host -- if it's gone
+ // over the limit, blacklist this task from the entire host.
+ val execsWithFailuresOnNode = nodeToExecsWithFailures.getOrElseUpdate(host, new HashSet())
+ execsWithFailuresOnNode += exec
+ val failuresOnHost = execsWithFailuresOnNode.toIterator.flatMap { exec =>
+ execToFailures.get(exec).map { failures =>
+ // We count task attempts here, not the number of unique executors with failures. This is
+ // because jobs are aborted based on the number of task attempts; if we counted unique
+ // executors, it would be hard to configure the limits to ensure that you try another
+ // node before hitting the max number of task failures.
+ failures.getNumTaskFailures(index)
+ }
+ }.sum
+ if (failuresOnHost >= MAX_TASK_ATTEMPTS_PER_NODE) {
+ nodeToBlacklistedTaskIndexes.getOrElseUpdate(host, new HashSet()) += index
+ }
+
+ // Check if enough tasks have failed on the executor to blacklist it for the entire stage.
+ if (execFailures.numUniqueTasksWithFailures >= MAX_FAILURES_PER_EXEC_STAGE) {
+ if (blacklistedExecs.add(exec)) {
+ logInfo(s"Blacklisting executor ${exec} for stage $stageId")
+ // This executor has been pushed into the blacklist for this stage. Let's check if it
+ // pushes the whole node into the blacklist.
+ val blacklistedExecutorsOnNode =
+ execsWithFailuresOnNode.filter(blacklistedExecs.contains(_))
+ if (blacklistedExecutorsOnNode.size >= MAX_FAILED_EXEC_PER_NODE_STAGE) {
+ if (blacklistedNodes.add(host)) {
+ logInfo(s"Blacklisting ${host} for stage $stageId")
+ }
+ }
+ }
+ }
+ }
+}
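
As a reading aid for the thresholds in the new TaskSetBlacklist, the sketch below reimplements the same bookkeeping in a heavily simplified form. The class name, hard-coded limits, and data structures are illustrative only; the real class pulls its limits from SparkConf and also records failure timestamps for the application-level BlacklistTracker.

```scala
import scala.collection.mutable.{HashMap, HashSet}

// Simplified stand-in for TaskSetBlacklist: same thresholds, no SparkConf, no clock.
class TinyTaskSetBlacklist(
    maxTaskAttemptsPerExecutor: Int = 1,
    maxTaskAttemptsPerNode: Int = 2,
    maxFailuresPerExecStage: Int = 2,
    maxFailedExecPerNodeStage: Int = 2) {

  // executor -> (task index -> failure count)
  private val execToTaskFailures = new HashMap[String, HashMap[Int, Int]]
  private val nodeToExecsWithFailures = new HashMap[String, HashSet[String]]
  private val nodeToBlacklistedTasks = new HashMap[String, HashSet[Int]]
  private val blacklistedExecs = new HashSet[String]
  private val blacklistedNodes = new HashSet[String]

  def taskFailed(host: String, exec: String, index: Int): Unit = {
    val failures = execToTaskFailures.getOrElseUpdate(exec, new HashMap())
    failures(index) = failures.getOrElse(index, 0) + 1

    // Task level: once the attempts for this index, summed over the node's executors,
    // reach the per-node limit, the task is blacklisted from the whole node.
    val execsOnNode = nodeToExecsWithFailures.getOrElseUpdate(host, new HashSet())
    execsOnNode += exec
    val failuresOnNode = execsOnNode.toIterator
      .map(e => execToTaskFailures.get(e).flatMap(_.get(index)).getOrElse(0)).sum
    if (failuresOnNode >= maxTaskAttemptsPerNode) {
      nodeToBlacklistedTasks.getOrElseUpdate(host, new HashSet()) += index
    }

    // Stage level: enough distinct failed tasks blacklist the executor, and enough
    // blacklisted executors on one node blacklist the node.
    if (failures.size >= maxFailuresPerExecStage && blacklistedExecs.add(exec)) {
      if (execsOnNode.count(e => blacklistedExecs.contains(e)) >= maxFailedExecPerNodeStage) {
        blacklistedNodes += host
      }
    }
  }

  def isExecutorBlacklistedForTask(exec: String, index: Int): Boolean =
    execToTaskFailures.get(exec).exists(_.getOrElse(index, 0) >= maxTaskAttemptsPerExecutor)

  def isNodeBlacklistedForTask(node: String, index: Int): Boolean =
    nodeToBlacklistedTasks.get(node).exists(_.contains(index))

  def isExecutorBlacklistedForTaskSet(exec: String): Boolean = blacklistedExecs.contains(exec)
  def isNodeBlacklistedForTaskSet(node: String): Boolean = blacklistedNodes.contains(node)
}
```
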
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 114468c48c44..705b8961956b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -19,20 +19,18 @@ package org.apache.spark.scheduler
import java.io.NotSerializableException
import java.nio.ByteBuffer
-import java.util.Arrays
import java.util.concurrent.ConcurrentLinkedQueue
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.math.{min, max}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.math.max
import scala.util.control.NonFatal
import org.apache.spark._
-import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.SchedulingMode._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.util.{Clock, SystemClock, Utils}
+import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils}
+import org.apache.spark.util.collection.MedianHeap
/**
* Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
@@ -53,19 +51,14 @@ private[spark] class TaskSetManager(
sched: TaskSchedulerImpl,
val taskSet: TaskSet,
val maxTaskFailures: Int,
- clock: Clock = new SystemClock())
- extends Schedulable with Logging {
+ blacklistTracker: Option[BlacklistTracker] = None,
+ clock: Clock = new SystemClock()) extends Schedulable with Logging {
- val conf = sched.sc.conf
+ private val conf = sched.sc.conf
- /*
- * Sometimes if an executor is dead or in an otherwise invalid state, the driver
- * does not realize right away leading to repeated task failures. If enabled,
- * this temporarily prevents a task from re-launching on an executor where
- * it just failed.
- */
- private val EXECUTOR_TASK_BLACKLIST_TIMEOUT =
- conf.getLong("spark.scheduler.executorTaskBlacklistTime", 0L)
+ // SPARK-21563 make a copy of the jars/files so they are consistent across the TaskSet
+ private val addedJars = HashMap[String, Long](sched.sc.addedJars.toSeq: _*)
+ private val addedFiles = HashMap[String, Long](sched.sc.addedFiles.toSeq: _*)
// Quantile of tasks at which to start speculation
val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
@@ -74,49 +67,70 @@ private[spark] class TaskSetManager(
// Limit of bytes for total size of results (default is 1GB)
val maxResultSize = Utils.getMaxResultSize(conf)
+ val speculationEnabled = conf.getBoolean("spark.speculation", false)
+
// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
val tasks = taskSet.tasks
+ private[scheduler] val partitionToIndex = tasks.zipWithIndex
+ .map { case (t, idx) => t.partitionId -> idx }.toMap
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
+
+ // For each task, tracks whether a copy of the task has succeeded. A task will also be
+ // marked as "succeeded" if it failed with a fetch failure, in which case it should not
+ // be re-run because the missing map data needs to be regenerated first.
val successful = new Array[Boolean](numTasks)
private val numFailures = new Array[Int](numTasks)
- // key is taskId, value is a Map of executor id to when it failed
- private val failedExecutors = new HashMap[Int, HashMap[String, Long]]()
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
- var tasksSuccessful = 0
+ private[scheduler] var tasksSuccessful = 0
- var weight = 1
- var minShare = 0
+ val weight = 1
+ val minShare = 0
var priority = taskSet.priority
var stageId = taskSet.stageId
- var name = "TaskSet_" + taskSet.stageId.toString
+ val name = "TaskSet_" + taskSet.id
var parent: Pool = null
- var totalResultSize = 0L
- var calculatedTasks = 0
+ private var totalResultSize = 0L
+ private var calculatedTasks = 0
+
+ private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = {
+ blacklistTracker.map { _ =>
+ new TaskSetBlacklist(conf, stageId, clock)
+ }
+ }
- val runningTasksSet = new HashSet[Long]
+ private[scheduler] val runningTasksSet = new HashSet[Long]
override def runningTasks: Int = runningTasksSet.size
+ def someAttemptSucceeded(tid: Long): Boolean = {
+ successful(taskInfos(tid).index)
+ }
+
// True once no more tasks should be launched for this task set manager. TaskSetManagers enter
// the zombie state once at least one attempt of each task has completed successfully, or if the
// task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie
// state until all tasks have finished running; we keep TaskSetManagers that are in the zombie
// state in order to continue to track and account for the running tasks.
// TODO: We should kill any running task attempts when the task set manager becomes a zombie.
- var isZombie = false
+ private[scheduler] var isZombie = false
// Set of pending tasks for each executor. These collections are actually
// treated as stacks, in which new tasks are added to the end of the
// ArrayBuffer and removed from the end. This makes it faster to detect
// tasks that repeatedly fail because whenever a task failed, it is put
- // back at the head of the stack. They are also only cleaned up lazily;
- // when a task is launched, it remains in all the pending lists except
- // the one that it was launched from, but gets removed from them later.
+ // back at the head of the stack. These collections may contain duplicates
+ // for two reasons:
+ // (1): Tasks are only removed lazily; when a task is launched, it remains
+ // in all the pending lists except the one that it was launched from.
+ // (2): Tasks may be re-added to these lists multiple times as a result
+ // of failures.
+ // Duplicates are handled in dequeueTaskFromList, which ensures that a
+ // task hasn't already started running before launching it.
private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
// Set of pending tasks for each host. Similar to pendingTasksForExecutor,
@@ -127,17 +141,22 @@ private[spark] class TaskSetManager(
private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
// Set containing pending tasks with no locality preferences.
- var pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+ private[scheduler] var pendingTasksWithNoPrefs = new ArrayBuffer[Int]
// Set containing all pending tasks (also used as a stack, as above).
- val allPendingTasks = new ArrayBuffer[Int]
+ private val allPendingTasks = new ArrayBuffer[Int]
// Tasks that can be speculated. Since these will be a small fraction of total
// tasks, we'll just hold them in a HashSet.
- val speculatableTasks = new HashSet[Int]
+ private[scheduler] val speculatableTasks = new HashSet[Int]
// Task index, start and finish time for each task attempt (indexed by task ID)
- val taskInfos = new HashMap[Long, TaskInfo]
+ private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]
+
+ // Use a MedianHeap to record durations of successful tasks so we know when to launch
+ // speculative tasks. This is only used when speculation is enabled, to avoid the overhead
+ // of inserting into the heap when the heap won't be used.
+ val successfulTaskDurations = new MedianHeap()
// How frequently to reprint duplicate exceptions in full, in milliseconds
val EXCEPTION_PRINT_INTERVAL =
@@ -146,7 +165,7 @@ private[spark] class TaskSetManager(
// Map of recent exceptions (identified by string representation and top stack frame) to
// duplicate count (how many times the same exception has appeared) and time the full exception
// was printed. This should ideally be an LRU map that can drop old exceptions automatically.
- val recentExceptions = HashMap[String, (Int, Long)]()
+ private val recentExceptions = HashMap[String, (Int, Long)]()
// Figure out the current map output tracker epoch and set it on all tasks
val epoch = sched.mapOutputTracker.getEpoch
@@ -161,59 +180,57 @@ private[spark] class TaskSetManager(
addPendingTask(i)
}
- // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
- var myLocalityLevels = computeValidLocalityLevels()
- var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
+ /**
+ * Track the set of locality levels which are valid given the tasks' locality preferences and
+ * the set of currently available executors. This is updated as executors are added and removed.
+ * This allows a performance optimization of skipping levels that aren't relevant (e.g., skip
+ * PROCESS_LOCAL if no tasks could be run PROCESS_LOCAL for the current set of executors).
+ */
+ private[scheduler] var myLocalityLevels = computeValidLocalityLevels()
+
+ // Time to wait at each level
+ private[scheduler] var localityWaits = myLocalityLevels.map(getLocalityWait)
// Delay scheduling variables: we keep track of our current locality level and the time we
// last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
// We then move down if we manage to launch a "more local" task.
- var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
- var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level
+ private var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
+ private var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level
override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null
override def schedulingMode: SchedulingMode = SchedulingMode.NONE
- var emittedTaskSizeWarning = false
+ private[scheduler] var emittedTaskSizeWarning = false
/** Add a task to all the pending-task lists that it should be on. */
- private def addPendingTask(index: Int) {
- // Utility method that adds `index` to a list only if it's not already there
- def addTo(list: ArrayBuffer[Int]) {
- if (!list.contains(index)) {
- list += index
- }
- }
-
+ private[spark] def addPendingTask(index: Int) {
for (loc <- tasks(index).preferredLocations) {
loc match {
case e: ExecutorCacheTaskLocation =>
- addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer))
- case e: HDFSCacheTaskLocation => {
+ pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer) += index
+ case e: HDFSCacheTaskLocation =>
val exe = sched.getExecutorsAliveOnHost(loc.host)
exe match {
- case Some(set) => {
+ case Some(set) =>
for (e <- set) {
- addTo(pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer))
+ pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer) += index
}
logInfo(s"Pending task $index has a cached location at ${e.host} " +
", where there are executors " + set.mkString(","))
- }
case None => logDebug(s"Pending task $index has a cached location at ${e.host} " +
", but there are no executors alive there.")
}
- }
- case _ => Unit
+ case _ =>
}
- addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+ pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index
for (rack <- sched.getRackForHost(loc.host)) {
- addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
+ pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) += index
}
}
if (tasks(index).preferredLocations == Nil) {
- addTo(pendingTasksWithNoPrefs)
+ pendingTasksWithNoPrefs += index
}
allPendingTasks += index // No point scanning this whole list to find the old task there
@@ -249,12 +266,15 @@ private[spark] class TaskSetManager(
* This method also cleans up any tasks in the list that have already
* been launched, since we want that to happen lazily.
*/
- private def dequeueTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = {
+ private def dequeueTaskFromList(
+ execId: String,
+ host: String,
+ list: ArrayBuffer[Int]): Option[Int] = {
var indexOffset = list.size
while (indexOffset > 0) {
indexOffset -= 1
val index = list(indexOffset)
- if (!executorIsBlacklisted(execId, index)) {
+ if (!isTaskBlacklistedOnExecOrNode(index, execId, host)) {
// This should almost always be list.trimEnd(1) to remove tail
list.remove(indexOffset)
if (copiesRunning(index) == 0 && !successful(index)) {
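
The comments added near the top of TaskSetManager note that pending-task lists may now contain duplicates and that dequeueTaskFromList is where stale entries are skipped. The toy version below isolates that lazy-cleanup scan (plain arrays instead of scheduler state, and no blacklist check):

```scala
import scala.collection.mutable.ArrayBuffer

object LazyPendingListDemo {
  // Scan the pending list from the tail, dropping every entry we look at, and return
  // the first index that is neither running nor already finished.
  def dequeue(
      list: ArrayBuffer[Int],
      copiesRunning: Array[Int],
      successful: Array[Boolean]): Option[Int] = {
    var i = list.size
    while (i > 0) {
      i -= 1
      val index = list(i)
      list.remove(i)
      if (copiesRunning(index) == 0 && !successful(index)) {
        return Some(index)
      }
    }
    None
  }

  def main(args: Array[String]): Unit = {
    val pending = ArrayBuffer(0, 1, 1, 2)    // duplicates are allowed
    val running = Array(0, 1, 0)             // task 1 already has a running copy
    val done = Array(false, false, true)     // task 2 already succeeded
    println(dequeue(pending, running, done)) // Some(0)
    println(pending)                         // ArrayBuffer() -- stale entries were dropped
  }
}
```
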
@@ -270,19 +290,11 @@ private[spark] class TaskSetManager(
taskAttempts(taskIndex).exists(_.host == host)
}
- /**
- * Is this re-execution of a failed task on an executor it already failed in before
- * EXECUTOR_TASK_BLACKLIST_TIMEOUT has elapsed ?
- */
- private def executorIsBlacklisted(execId: String, taskId: Int): Boolean = {
- if (failedExecutors.contains(taskId)) {
- val failed = failedExecutors.get(taskId).get
-
- return failed.contains(execId) &&
- clock.getTimeMillis() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT
+ private def isTaskBlacklistedOnExecOrNode(index: Int, execId: String, host: String): Boolean = {
+ taskSetBlacklistHelperOpt.exists { blacklist =>
+ blacklist.isNodeBlacklistedForTask(host, index) ||
+ blacklist.isExecutorBlacklistedForTask(execId, index)
}
-
- false
}
/**
@@ -296,8 +308,10 @@ private[spark] class TaskSetManager(
{
speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
- def canRunOnHost(index: Int): Boolean =
- !hasAttemptOnHost(index, host) && !executorIsBlacklisted(execId, index)
+ def canRunOnHost(index: Int): Boolean = {
+ !hasAttemptOnHost(index, host) &&
+ !isTaskBlacklistedOnExecOrNode(index, execId, host)
+ }
if (!speculatableTasks.isEmpty) {
// Check for process-local tasks; note that tasks can be process-local
@@ -340,7 +354,7 @@ private[spark] class TaskSetManager(
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
for (rack <- sched.getRackForHost(host)) {
for (index <- speculatableTasks if canRunOnHost(index)) {
- val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
+ val racks = tasks(index).preferredLocations.map(_.host).flatMap(sched.getRackForHost)
if (racks.contains(rack)) {
speculatableTasks -= index
return Some((index, TaskLocality.RACK_LOCAL))
@@ -370,19 +384,19 @@ private[spark] class TaskSetManager(
private def dequeueTask(execId: String, host: String, maxLocality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value, Boolean)] =
{
- for (index <- dequeueTaskFromList(execId, getPendingTasksForExecutor(execId))) {
+ for (index <- dequeueTaskFromList(execId, host, getPendingTasksForExecutor(execId))) {
return Some((index, TaskLocality.PROCESS_LOCAL, false))
}
if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) {
- for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host))) {
+ for (index <- dequeueTaskFromList(execId, host, getPendingTasksForHost(host))) {
return Some((index, TaskLocality.NODE_LOCAL, false))
}
}
if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) {
// Look for noPref tasks after NODE_LOCAL to minimize cross-rack traffic
- for (index <- dequeueTaskFromList(execId, pendingTasksWithNoPrefs)) {
+ for (index <- dequeueTaskFromList(execId, host, pendingTasksWithNoPrefs)) {
return Some((index, TaskLocality.PROCESS_LOCAL, false))
}
}
@@ -390,14 +404,14 @@ private[spark] class TaskSetManager(
if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) {
for {
rack <- sched.getRackForHost(host)
- index <- dequeueTaskFromList(execId, getPendingTasksForRack(rack))
+ index <- dequeueTaskFromList(execId, host, getPendingTasksForRack(rack))
} {
return Some((index, TaskLocality.RACK_LOCAL, false))
}
}
if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) {
- for (index <- dequeueTaskFromList(execId, allPendingTasks)) {
+ for (index <- dequeueTaskFromList(execId, host, allPendingTasks)) {
return Some((index, TaskLocality.ANY, false))
}
}
@@ -425,7 +439,11 @@ private[spark] class TaskSetManager(
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
{
- if (!isZombie) {
+ val offerBlacklisted = taskSetBlacklistHelperOpt.exists { blacklist =>
+ blacklist.isNodeBlacklistedForTaskSet(host) ||
+ blacklist.isExecutorBlacklistedForTaskSet(execId)
+ }
+ if (!isZombie && !offerBlacklisted) {
val curTime = clock.getTimeMillis()
var allowedLocality = maxLocality
@@ -438,66 +456,77 @@ private[spark] class TaskSetManager(
}
}
- dequeueTask(execId, host, allowedLocality) match {
- case Some((index, taskLocality, speculative)) => {
- // Found a task; do some bookkeeping and return a task description
- val task = tasks(index)
- val taskId = sched.newTaskId()
- // Do various bookkeeping
- copiesRunning(index) += 1
- val attemptNum = taskAttempts(index).size
- val info = new TaskInfo(taskId, index, attemptNum, curTime,
- execId, host, taskLocality, speculative)
- taskInfos(taskId) = info
- taskAttempts(index) = info :: taskAttempts(index)
- // Update our locality level for delay scheduling
- // NO_PREF will not affect the variables related to delay scheduling
- if (maxLocality != TaskLocality.NO_PREF) {
- currentLocalityIndex = getLocalityIndex(taskLocality)
- lastLaunchTime = curTime
- }
- // Serialize and return the task
- val startTime = clock.getTimeMillis()
- val serializedTask: ByteBuffer = try {
- Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- } catch {
- // If the task cannot be serialized, then there's no point to re-attempt the task,
- // as it will always fail. So just abort the whole task-set.
- case NonFatal(e) =>
- val msg = s"Failed to serialize task $taskId, not attempting to retry it."
- logError(msg, e)
- abort(s"$msg Exception during serialization: $e")
- throw new TaskNotSerializableException(e)
- }
- if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
- !emittedTaskSizeWarning) {
- emittedTaskSizeWarning = true
- logWarning(s"Stage ${task.stageId} contains a task of very large size " +
- s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " +
- s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.")
- }
- addRunningTask(taskId)
-
- // We used to log the time it takes to serialize the task, but task size is already
- // a good proxy to task serialization time.
- // val timeTaken = clock.getTime() - startTime
- val taskName = s"task ${info.id} in stage ${taskSet.id}"
- logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," +
- s"$taskLocality, ${serializedTask.limit} bytes)")
-
- sched.dagScheduler.taskStarted(task, info)
- return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId,
- taskName, index, serializedTask))
+ dequeueTask(execId, host, allowedLocality).map { case ((index, taskLocality, speculative)) =>
+ // Found a task; do some bookkeeping and return a task description
+ val task = tasks(index)
+ val taskId = sched.newTaskId()
+ // Do various bookkeeping
+ copiesRunning(index) += 1
+ val attemptNum = taskAttempts(index).size
+ val info = new TaskInfo(taskId, index, attemptNum, curTime,
+ execId, host, taskLocality, speculative)
+ taskInfos(taskId) = info
+ taskAttempts(index) = info :: taskAttempts(index)
+ // Update our locality level for delay scheduling
+ // NO_PREF will not affect the variables related to delay scheduling
+ if (maxLocality != TaskLocality.NO_PREF) {
+ currentLocalityIndex = getLocalityIndex(taskLocality)
+ lastLaunchTime = curTime
}
- case _ =>
+ // Serialize and return the task
+ val serializedTask: ByteBuffer = try {
+ ser.serialize(task)
+ } catch {
+ // If the task cannot be serialized, then there's no point to re-attempt the task,
+ // as it will always fail. So just abort the whole task-set.
+ case NonFatal(e) =>
+ val msg = s"Failed to serialize task $taskId, not attempting to retry it."
+ logError(msg, e)
+ abort(s"$msg Exception during serialization: $e")
+ throw new TaskNotSerializableException(e)
+ }
+ if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
+ !emittedTaskSizeWarning) {
+ emittedTaskSizeWarning = true
+ logWarning(s"Stage ${task.stageId} contains a task of very large size " +
+ s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " +
+ s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.")
+ }
+ addRunningTask(taskId)
+
+ // We used to log the time it takes to serialize the task, but task size is already
+ // a good proxy to task serialization time.
+ // val timeTaken = clock.getTime() - startTime
+ val taskName = s"task ${info.id} in stage ${taskSet.id}"
+ logInfo(s"Starting $taskName (TID $taskId, $host, executor ${info.executorId}, " +
+ s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit} bytes)")
+
+ sched.dagScheduler.taskStarted(task, info)
+ new TaskDescription(
+ taskId,
+ attemptNum,
+ execId,
+ taskName,
+ index,
+ addedFiles,
+ addedJars,
+ task.localProperties,
+ serializedTask)
}
+ } else {
+ None
}
- None
}
private def maybeFinishTaskSet() {
if (isZombie && runningTasks == 0) {
sched.taskSetFinished(this)
+ if (tasksSuccessful == numTasks) {
+ blacklistTracker.foreach(_.updateBlacklistForSuccessfulTaskSet(
+ taskSet.stageId,
+ taskSet.stageAttemptId,
+ taskSetBlacklistHelperOpt.get.execToFailures))
+ }
}
}
@@ -557,9 +586,9 @@ private[spark] class TaskSetManager(
// Jump to the next locality level, and reset lastLaunchTime so that the next locality
// wait timer doesn't immediately expire
lastLaunchTime += localityWaits(currentLocalityIndex)
- currentLocalityIndex += 1
- logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex)} after waiting for " +
+ logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex + 1)} after waiting for " +
s"${localityWaits(currentLocalityIndex)}ms")
+ currentLocalityIndex += 1
} else {
return myLocalityLevels(currentLocalityIndex)
}
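The hunk above only reorders the debug log and the index bump, but the surrounding logic is the classic delay-scheduling timer: once the wait for the current locality level expires, jump to the next (less local) level and push the timer base forward. A minimal, self-contained Scala sketch of that idea, with made-up arrays standing in for myLocalityLevels and localityWaits:

    // Hedged sketch of delay scheduling; all names and values below are illustrative only.
    object DelaySchedulingSketch {
      def advanceLocality(
          levels: Array[String],     // e.g. Array("PROCESS_LOCAL", "NODE_LOCAL", "ANY")
          waitsMs: Array[Long],      // wait before giving up on each level
          currentIndex: Int,
          lastLaunchTime: Long,
          now: Long): (Int, Long) = {
        if (currentIndex < levels.length - 1 && now - lastLaunchTime >= waitsMs(currentIndex)) {
          // Move to the next locality level and reset the timer base so the next
          // level's wait does not expire immediately.
          (currentIndex + 1, lastLaunchTime + waitsMs(currentIndex))
        } else {
          (currentIndex, lastLaunchTime)
        }
      }

      def main(args: Array[String]): Unit = {
        val (idx, base) = advanceLocality(
          Array("PROCESS_LOCAL", "NODE_LOCAL", "ANY"), Array(3000L, 3000L, 0L),
          currentIndex = 0, lastLaunchTime = 0L, now = 3500L)
        println((idx, base)) // (1, 3000)
      }
    }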
@@ -580,12 +609,84 @@ private[spark] class TaskSetManager(
index
}
+ /**
+ * Check whether the given task set has been blacklisted to the point that it can't run anywhere.
+ *
+ * It is possible that this taskset has become impossible to schedule *anywhere* due to the
+ * blacklist. The most common scenario would be if there are fewer executors than
+ * spark.task.maxFailures. We need to detect this so we can fail the task set, otherwise the job
+ * will hang.
+ *
+ * There's a tradeoff here: we could make sure all tasks in the task set are schedulable, but that
+ * would add extra time to each iteration of the scheduling loop. Here, we take the approach of
+ * making sure at least one of the unscheduled tasks is schedulable. This means we may not detect
+ * the hang as quickly as we could have, but we'll always detect the hang eventually, and the
+ * method is faster in the typical case. In the worst case, this method can take
+ * O(maxTaskFailures + numTasks) time, but it will be faster when there haven't been any task
+ * failures (this is because the method picks one unscheduled task, and then iterates through each
+ * executor until it finds one that the task isn't blacklisted on).
+ */
+ private[scheduler] def abortIfCompletelyBlacklisted(
+ hostToExecutors: HashMap[String, HashSet[String]]): Unit = {
+ taskSetBlacklistHelperOpt.foreach { taskSetBlacklist =>
+ val appBlacklist = blacklistTracker.get
+ // Only look for unschedulable tasks when at least one executor has registered. Otherwise,
+ // task sets will be (unnecessarily) aborted in cases when no executors have registered yet.
+ if (hostToExecutors.nonEmpty) {
+ // find any task that needs to be scheduled
+ val pendingTask: Option[Int] = {
+ // usually this will just take the last pending task, but because of the lazy removal
+ // from each list, we may need to go deeper in the list. We poll from the end because
+ // failed tasks are put back at the end of allPendingTasks, so we're more likely to find
+ // an unschedulable task this way.
+ val indexOffset = allPendingTasks.lastIndexWhere { indexInTaskSet =>
+ copiesRunning(indexInTaskSet) == 0 && !successful(indexInTaskSet)
+ }
+ if (indexOffset == -1) {
+ None
+ } else {
+ Some(allPendingTasks(indexOffset))
+ }
+ }
+
+ pendingTask.foreach { indexInTaskSet =>
+ // try to find some executor this task can run on. It's possible that some *other*
+ // task isn't schedulable anywhere, but we will discover that in some later call,
+ // when that unschedulable task is the last task remaining.
+ val blacklistedEverywhere = hostToExecutors.forall { case (host, execsOnHost) =>
+ // Check if the task can run on the node
+ val nodeBlacklisted =
+ appBlacklist.isNodeBlacklisted(host) ||
+ taskSetBlacklist.isNodeBlacklistedForTaskSet(host) ||
+ taskSetBlacklist.isNodeBlacklistedForTask(host, indexInTaskSet)
+ if (nodeBlacklisted) {
+ true
+ } else {
+ // Check if the task can run on any of the executors
+ execsOnHost.forall { exec =>
+ appBlacklist.isExecutorBlacklisted(exec) ||
+ taskSetBlacklist.isExecutorBlacklistedForTaskSet(exec) ||
+ taskSetBlacklist.isExecutorBlacklistedForTask(exec, indexInTaskSet)
+ }
+ }
+ }
+ if (blacklistedEverywhere) {
+ val partition = tasks(indexInTaskSet).partitionId
+ abort(s"Aborting $taskSet because task $indexInTaskSet (partition $partition) " +
+ s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior " +
+ s"can be configured via spark.blacklist.*.")
+ }
+ }
+ }
+ }
+ }
+
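To make the comment above concrete, here is a hedged, self-contained sketch of the "is this one pending task blacklisted everywhere?" check. The types and maps are simplified stand-ins, not Spark's TaskSetBlacklist or BlacklistTracker APIs:

    // Hypothetical blacklist view for a single task index; names are invented for illustration.
    object BlacklistEverywhereSketch {
      final case class SimpleBlacklist(badNodes: Set[String], badExecutors: Set[String])

      /** Returns true if the task can run on no registered executor at all. */
      def blacklistedEverywhere(
          hostToExecutors: Map[String, Set[String]],
          blacklist: SimpleBlacklist): Boolean = {
        hostToExecutors.forall { case (host, execs) =>
          // A blacklisted node rules out every executor on it ...
          blacklist.badNodes.contains(host) ||
            // ... otherwise every individual executor on the host must be blacklisted.
            execs.forall(blacklist.badExecutors.contains)
        }
      }

      def main(args: Array[String]): Unit = {
        val cluster = Map("hostA" -> Set("1", "2"), "hostB" -> Set("3"))
        val bl = SimpleBlacklist(badNodes = Set("hostB"), badExecutors = Set("1", "2"))
        // Every executor is ruled out, so the task set would be aborted rather than hang.
        println(blacklistedEverywhere(cluster, bl)) // true
      }
    }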
/**
* Marks the task as getting result and notifies the DAG Scheduler
*/
def handleTaskGettingResult(tid: Long): Unit = {
val info = taskInfos(tid)
- info.markGettingResult()
+ info.markGettingResult(clock.getTimeMillis())
sched.dagScheduler.taskGettingResult(info)
}
@@ -608,25 +709,34 @@ private[spark] class TaskSetManager(
}
/**
- * Marks the task as successful and notifies the DAGScheduler that a task has ended.
+ * Marks a task as successful and notifies the DAGScheduler that the task has ended.
*/
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = {
val info = taskInfos(tid)
val index = info.index
- info.markSuccessful()
+ info.markFinished(TaskState.FINISHED, clock.getTimeMillis())
+ if (speculationEnabled) {
+ successfulTaskDurations.insert(info.duration)
+ }
removeRunningTask(tid)
- // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
- // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
- // "deserialize" the value when holding a lock to avoid blocking other threads. So we call
- // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here.
- // Note: "result.value()" only deserializes the value when it's called at the first time, so
- // here "result.value()" just returns the value and won't block other threads.
- sched.dagScheduler.taskEnded(
- tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics)
+
+ // Kill any other attempts for the same task (since those are unnecessary now that one
+ // attempt completed successfully).
+ for (attemptInfo <- taskAttempts(index) if attemptInfo.running) {
+ logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " +
+ s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " +
+ s"as the attempt ${info.attemptNumber} succeeded on ${info.host}")
+ sched.backend.killTask(
+ attemptInfo.taskId,
+ attemptInfo.executorId,
+ interruptThread = true,
+ reason = "another attempt succeeded")
+ }
if (!successful(index)) {
tasksSuccessful += 1
- logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format(
- info.id, taskSet.id, info.taskId, info.duration, info.host, tasksSuccessful, numTasks))
+ logInfo(s"Finished task ${info.id} in stage ${taskSet.id} (TID ${info.taskId}) in" +
+ s" ${info.duration} ms on ${info.host} (executor ${info.executorId})" +
+ s" ($tasksSuccessful/$numTasks)")
// Mark successful and stop if all the tasks have succeeded.
successful(index) = true
if (tasksSuccessful == numTasks) {
@@ -636,27 +746,51 @@ private[spark] class TaskSetManager(
logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id +
" because task " + index + " has already completed successfully")
}
- failedExecutors.remove(index)
+ // There may be multiple tasksets for this stage -- we let all of them know that the partition
+ // was completed. This may result in some of the tasksets getting completed.
+ sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info)
+ // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
+ // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
+ // "deserialize" the value when holding a lock to avoid blocking other threads. So we call
+ // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here.
+ // Note: "result.value()" only deserializes the value when it's called at the first time, so
+ // here "result.value()" just returns the value and won't block other threads.
+ sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info)
maybeFinishTaskSet()
}
+ private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = {
+ partitionToIndex.get(partitionId).foreach { index =>
+ if (!successful(index)) {
+ if (speculationEnabled && !isZombie) {
+ successfulTaskDurations.insert(taskInfo.duration)
+ }
+ tasksSuccessful += 1
+ successful(index) = true
+ if (tasksSuccessful == numTasks) {
+ isZombie = true
+ }
+ maybeFinishTaskSet()
+ }
+ }
+ }
+
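markPartitionCompletedInAllTaskSets is not shown in this patch, so the following is only a hypothetical sketch of the fan-out it implies: every active attempt of the stage is told about the completed partition so none of them keeps waiting on a duplicate task. All names below are invented for illustration:

    object PartitionCompletionSketch {
      // Minimal stand-in for one task set attempt of a stage.
      final class Attempt(numPartitions: Int) {
        private val done = Array.fill(numPartitions)(false)
        def markPartitionCompleted(partitionId: Int): Unit = done(partitionId) = true
        def completedCount: Int = done.count(identity)
      }

      def markInAllAttempts(attemptsForStage: Seq[Attempt], partitionId: Int): Unit = {
        // Every attempt learns about the completed partition, so a later attempt
        // does not keep a duplicate copy of that task pending.
        attemptsForStage.foreach(_.markPartitionCompleted(partitionId))
      }

      def main(args: Array[String]): Unit = {
        val attempts = Seq(new Attempt(4), new Attempt(4))
        markInAllAttempts(attempts, partitionId = 2)
        println(attempts.map(_.completedCount)) // List(1, 1)
      }
    }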
/**
* Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
* DAG Scheduler.
*/
- def handleFailedTask(tid: Long, state: TaskState, reason: TaskEndReason) {
+ def handleFailedTask(tid: Long, state: TaskState, reason: TaskFailedReason) {
val info = taskInfos(tid)
- if (info.failed) {
+ if (info.failed || info.killed) {
return
}
removeRunningTask(tid)
- info.markFailed()
+ info.markFinished(state, clock.getTimeMillis())
val index = info.index
copiesRunning(index) -= 1
- var taskMetrics : TaskMetrics = null
-
- val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " +
- reason.asInstanceOf[TaskFailedReason].toErrorString
+ var accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty
+ val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}," +
+ s" executor ${info.executorId}): ${reason.toErrorString}"
val failureException: Option[Throwable] = reason match {
case fetchFailed: FetchFailed =>
logWarning(failureReason)
@@ -664,12 +798,12 @@ private[spark] class TaskSetManager(
successful(index) = true
tasksSuccessful += 1
}
- // Not adding to failed executors for FetchFailed.
isZombie = true
None
case ef: ExceptionFailure =>
- taskMetrics = ef.metrics.orNull
+ // ExceptionFailures might have accumulator updates
+ accumUpdates = ef.accums
if (ef.className == classOf[NotSerializableException].getName) {
// If the task result wasn't serializable, there's no point in trying to re-execute it.
logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying"
@@ -699,13 +833,13 @@ private[spark] class TaskSetManager(
logWarning(failureReason)
} else {
logInfo(
- s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " +
- s"${ef.className} (${ef.description}) [duplicate $dupCount]")
+ s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on ${info.host}, executor" +
+ s" ${info.executorId}: ${ef.className} (${ef.description}) [duplicate $dupCount]")
}
ef.exception
case e: ExecutorLostFailure if !e.exitCausedByApp =>
- logInfo(s"Task $tid failed because while it was being computed, its executor" +
+ logInfo(s"Task $tid failed because while it was being computed, its executor " +
"exited for a reason unrelated to the task. Not counting this failure towards the " +
"maximum number of failures for the task.")
None
@@ -713,19 +847,13 @@ private[spark] class TaskSetManager(
case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others
logWarning(failureReason)
None
-
- case e: TaskEndReason =>
- logError("Unknown TaskEndReason: " + e)
- None
}
- // always add to failed executors
- failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()).
- put(info.executorId, clock.getTimeMillis())
- sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics)
- addPendingTask(index)
- if (!isZombie && state != TaskState.KILLED
- && reason.isInstanceOf[TaskFailedReason]
- && reason.asInstanceOf[TaskFailedReason].countTowardsTaskFailures) {
+
+ sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info)
+
+ if (!isZombie && reason.countTowardsTaskFailures) {
+ taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask(
+ info.host, info.executorId, index))
assert (null != failureReason)
numFailures(index) += 1
if (numFailures(index) >= maxTaskFailures) {
@@ -736,6 +864,16 @@ private[spark] class TaskSetManager(
return
}
}
+
+ if (successful(index)) {
+ logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" +
+ s" be re-executed (either because the task failed with a shuffle data fetch failure," +
+ s" so the previous stage needs to be re-run, or because a different copy of the task" +
+ s" has already succeeded).")
+ } else {
+ addPendingTask(index)
+ }
+
maybeFinishTaskSet()
}
@@ -783,7 +921,8 @@ private[spark] class TaskSetManager(
// and we are not using an external shuffle server which could serve the shuffle outputs.
// The reason is the next stage wouldn't be able to fetch the data from this dead executor
// so we would need to rerun these tasks on other executors.
- if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) {
+ if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled
+ && !isZombie) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (successful(index)) {
@@ -793,13 +932,15 @@ private[spark] class TaskSetManager(
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
- sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ sched.dagScheduler.taskEnded(
+ tasks(index), Resubmitted, null, Seq.empty, info)
}
}
}
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
val exitCausedByApp: Boolean = reason match {
case exited: ExecutorExited => exited.exitCausedByApp
+ case ExecutorKilled => false
case _ => true
}
handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp,
@@ -813,10 +954,8 @@ private[spark] class TaskSetManager(
* Check for tasks to be speculated and return true if there are any. This is called periodically
* by the TaskScheduler.
*
- * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
- * we don't scan the whole task set. It might also help to make this sorted by launch time.
*/
- override def checkSpeculatableTasks(): Boolean = {
+ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = {
// Can't speculate if we only have one task, and no need to speculate if the task set is a
// zombie.
if (isZombie || numTasks == 1) {
@@ -825,16 +964,16 @@ private[spark] class TaskSetManager(
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+
if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
val time = clock.getTimeMillis()
- val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
- Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
- val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+ val medianDuration = successfulTaskDurations.median
+ val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
logDebug("Task length threshold for speculation: " + threshold)
- for ((tid, info) <- taskInfos) {
+ for (tid <- runningTasksSet) {
+ val info = taskInfos(tid)
val index = info.index
if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
!speculatableTasks.contains(index)) {
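The threshold arithmetic above is small but easy to get wrong, so here is a hedged standalone sketch of it: a running task becomes a speculation candidate once it has run longer than max(SPECULATION_MULTIPLIER * median, minTimeToSpeculation). The simple sort-based median below is only an illustrative stand-in for whatever structure backs successfulTaskDurations:

    object SpeculationThresholdSketch {
      def speculationThreshold(
          successfulDurationsMs: Seq[Long],
          multiplier: Double = 1.5,          // illustrative, not a Spark config value
          minTimeToSpeculationMs: Long = 100L): Double = {
        val sorted = successfulDurationsMs.sorted
        val median = sorted(sorted.length / 2).toDouble
        math.max(multiplier * median, minTimeToSpeculationMs.toDouble)
      }

      def main(args: Array[String]): Unit = {
        // Median of the successful runs is 40 ms, so the threshold is max(1.5 * 40, 100) = 100 ms.
        println(speculationThreshold(Seq(30L, 40L, 50L)))
      }
    }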
@@ -873,18 +1012,18 @@ private[spark] class TaskSetManager(
private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY}
val levels = new ArrayBuffer[TaskLocality.TaskLocality]
- if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 &&
+ if (!pendingTasksForExecutor.isEmpty &&
pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) {
levels += PROCESS_LOCAL
}
- if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0 &&
+ if (!pendingTasksForHost.isEmpty &&
pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) {
levels += NODE_LOCAL
}
if (!pendingTasksWithNoPrefs.isEmpty) {
levels += NO_PREF
}
- if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0 &&
+ if (!pendingTasksForRack.isEmpty &&
pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) {
levels += RACK_LOCAL
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index f3d0d8547677..6b49bd699a13 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -22,24 +22,33 @@ import java.nio.ByteBuffer
import org.apache.spark.TaskState.TaskState
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.ExecutorLossReason
-import org.apache.spark.util.{SerializableBuffer, Utils}
+import org.apache.spark.util.SerializableBuffer
private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
private[spark] object CoarseGrainedClusterMessages {
- case object RetrieveSparkProps extends CoarseGrainedClusterMessage
+ case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage
+
+ case class SparkAppConfig(
+ sparkProperties: Seq[(String, String)],
+ ioEncryptionKey: Option[Array[Byte]])
+ extends CoarseGrainedClusterMessage
+
+ case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage
// Driver to executors
case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage
- case class KillTask(taskId: Long, executor: String, interruptThread: Boolean)
+ case class KillTask(taskId: Long, executor: String, interruptThread: Boolean, reason: String)
+ extends CoarseGrainedClusterMessage
+
+ case class KillExecutorsOnHost(host: String)
extends CoarseGrainedClusterMessage
sealed trait RegisterExecutorResponse
- case class RegisteredExecutor(hostname: String) extends CoarseGrainedClusterMessage
- with RegisterExecutorResponse
+ case object RegisteredExecutor extends CoarseGrainedClusterMessage with RegisterExecutorResponse
case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
with RegisterExecutorResponse
@@ -48,7 +57,7 @@ private[spark] object CoarseGrainedClusterMessages {
case class RegisterExecutor(
executorId: String,
executorRef: RpcEndpointRef,
- hostPort: String,
+ hostname: String,
cores: Int,
logUrls: Map[String, String])
extends CoarseGrainedClusterMessage
@@ -93,7 +102,8 @@ private[spark] object CoarseGrainedClusterMessages {
case class RequestExecutors(
requestedTotal: Int,
localityAwareTasks: Int,
- hostToLocalTaskCount: Map[String, Int])
+ hostToLocalTaskCount: Map[String, Int],
+ nodeBlacklist: Set[String])
extends CoarseGrainedClusterMessage
// Check if an executor was force-killed but for a reason unrelated to the running tasks.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index f71d98feac05..ab6e42690959 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -19,18 +19,22 @@ package org.apache.spark.scheduler.cluster
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
+import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState}
+import org.apache.spark.internal.Logging
import org.apache.spark.rpc._
-import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME
-import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils}
+import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils, Utils}
/**
- * A scheduler backend that waits for coarse grained executors to connect to it through Akka.
+ * A scheduler backend that waits for coarse-grained executors to connect.
* This backend holds onto each executor for the duration of the Spark job rather than relinquishing
* executors whenever a task is done and asking the scheduler to launch a new executor for
* each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the
@@ -42,49 +46,61 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
extends ExecutorAllocationClient with SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
- var totalCoreCount = new AtomicInteger(0)
+ protected val totalCoreCount = new AtomicInteger(0)
// Total number of executors that are currently registered
- var totalRegisteredExecutors = new AtomicInteger(0)
- val conf = scheduler.sc.conf
- private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+ protected val totalRegisteredExecutors = new AtomicInteger(0)
+ protected val conf = scheduler.sc.conf
+ private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
+ private val defaultAskTimeout = RpcUtils.askRpcTimeout(conf)
// Submit tasks only after (registered resources / total expected resources)
// is equal to at least this value, which is a double between 0 and 1.
- var minRegisteredRatio =
+ private val _minRegisteredRatio =
math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0))
// Submit tasks after maxRegisteredWaitingTime milliseconds
// if minRegisteredRatio has not yet been reached
- val maxRegisteredWaitingTimeMs =
+ private val maxRegisteredWaitingTimeMs =
conf.getTimeAsMs("spark.scheduler.maxRegisteredResourcesWaitingTime", "30s")
- val createTime = System.currentTimeMillis()
+ private val createTime = System.currentTimeMillis()
+ // Accessing `executorDataMap` in `DriverEndpoint.receive/receiveAndReply` doesn't need any
+ // protection. But accessing `executorDataMap` out of `DriverEndpoint.receive/receiveAndReply`
+ // must be protected by `CoarseGrainedSchedulerBackend.this`. Besides, `executorDataMap` should
+ // only be modified in `DriverEndpoint.receive/receiveAndReply` with protection by
+ // `CoarseGrainedSchedulerBackend.this`.
private val executorDataMap = new HashMap[String, ExecutorData]
+ // Number of executors requested from the cluster manager, see [[ExecutorAllocationManager]]
+ @GuardedBy("CoarseGrainedSchedulerBackend.this")
+ private var requestedTotalExecutors = 0
+
// Number of executors requested from the cluster manager that have not registered yet
+ @GuardedBy("CoarseGrainedSchedulerBackend.this")
private var numPendingExecutors = 0
private val listenerBus = scheduler.sc.listenerBus
- // Executors we have requested the cluster manager to kill that have not died yet
- private val executorsPendingToRemove = new HashSet[String]
+ // Executors we have requested the cluster manager to kill that have not died yet; maps
+ // the executor ID to whether it was explicitly killed by the driver (and thus shouldn't
+ // be considered an app-related failure).
+ @GuardedBy("CoarseGrainedSchedulerBackend.this")
+ private val executorsPendingToRemove = new HashMap[String, Boolean]
// A map to store hostname with its possible task number running on it
+ @GuardedBy("CoarseGrainedSchedulerBackend.this")
protected var hostToLocalTaskCount: Map[String, Int] = Map.empty
// The number of pending tasks that have locality preferences
+ @GuardedBy("CoarseGrainedSchedulerBackend.this")
protected var localityAwareTasks = 0
- // Executors that have been lost, but for which we don't yet know the real exit reason.
- protected val executorsPendingLossReason = new HashSet[String]
+ // The maximum executor ID seen so far, used when the application master re-registers
+ @volatile protected var currentExecutorIdCounter = 0
class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
extends ThreadSafeRpcEndpoint with Logging {
- // If this DriverEndpoint is changed to support multiple threads,
- // then this may need to be changed so that we don't share the serializer
- // instance across threads
- private val ser = SparkEnv.get.closureSerializer.newInstance()
-
- override protected def log = CoarseGrainedSchedulerBackend.this.log
+ // Executors that have been lost, but for which we don't yet know the real exit reason.
+ protected val executorsPendingLossReason = new HashSet[String]
protected val addressToExecutorId = new HashMap[RpcAddress, String]
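The @GuardedBy annotations introduced above document a locking convention rather than enforce one. The sketch below shows the pattern in isolation, using a hypothetical class and assuming the javax.annotation.concurrent annotations are on the classpath: every access to the annotated field happens while holding the named lock.

    import javax.annotation.concurrent.GuardedBy

    // Hedged sketch of the @GuardedBy convention; class and field names are invented.
    class RequestBookkeepingSketch {
      @GuardedBy("this")
      private var pendingRequests = 0

      def increment(): Unit = synchronized {
        // Mutation happens under the same lock named in the annotation.
        pendingRequests += 1
      }

      def snapshot(): Int = synchronized { pendingRequests }
    }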
@@ -120,21 +136,36 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
case ReviveOffers =>
makeOffers()
- case KillTask(taskId, executorId, interruptThread) =>
+ case KillTask(taskId, executorId, interruptThread, reason) =>
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
- executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread))
+ executorInfo.executorEndpoint.send(
+ KillTask(taskId, executorId, interruptThread, reason))
case None =>
// Ignoring the task kill since the executor is not registered.
logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
}
+
+ case KillExecutorsOnHost(host) =>
+ scheduler.getExecutorsAliveOnHost(host).foreach { exec =>
+ killExecutors(exec.toSeq, replace = true, force = true)
+ }
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) =>
+ case RegisterExecutor(executorId, executorRef, hostname, cores, logUrls) =>
if (executorDataMap.contains(executorId)) {
- context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
+ executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
+ context.reply(true)
+ } else if (scheduler.nodeBlacklist != null &&
+ scheduler.nodeBlacklist.contains(hostname)) {
+ // If the cluster manager gives us an executor on a blacklisted node (because it
+ // already started allocating those resources before we informed it of our blacklist,
+ // or if it ignored our blacklist), then we reject that executor immediately.
+ logInfo(s"Rejecting $executorId as it has been blacklisted.")
+ executorRef.send(RegisterExecutorFailed(s"Executor is blacklisted: $executorId"))
+ context.reply(true)
} else {
// If the executor's rpc env is not listening for incoming connections, `hostPort`
// will be null, and the client connection should be used to contact the executor.
@@ -147,19 +178,23 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
addressToExecutorId(executorAddress) = executorId
totalCoreCount.addAndGet(cores)
totalRegisteredExecutors.addAndGet(1)
- val data = new ExecutorData(executorRef, executorRef.address, executorAddress.host,
+ val data = new ExecutorData(executorRef, executorAddress, hostname,
cores, cores, logUrls)
// This must be synchronized because variables mutated
// in this block are read when requesting executors
CoarseGrainedSchedulerBackend.this.synchronized {
executorDataMap.put(executorId, data)
+ if (currentExecutorIdCounter < executorId.toInt) {
+ currentExecutorIdCounter = executorId.toInt
+ }
if (numPendingExecutors > 0) {
numPendingExecutors -= 1
logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
}
}
+ executorRef.send(RegisteredExecutor)
// Note: some tests expect the reply to come after we put the executor in the map
- context.reply(RegisteredExecutor(executorAddress.host))
+ context.reply(true)
listenerBus.post(
SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
makeOffers()
@@ -177,21 +212,33 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
context.reply(true)
case RemoveExecutor(executorId, reason) =>
+ // We will remove the executor's state and cannot restore it. However, the connection
+ // between the driver and the executor may still be alive, so the executor will not exit
+ // automatically. Try to tell the executor to stop itself instead. See SPARK-13519.
+ executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor))
removeExecutor(executorId, reason)
context.reply(true)
- case RetrieveSparkProps =>
- context.reply(sparkProperties)
+ case RetrieveSparkAppConfig =>
+ val reply = SparkAppConfig(sparkProperties,
+ SparkEnv.get.securityManager.getIOEncryptionKey())
+ context.reply(reply)
}
// Make fake resource offers on all executors
private def makeOffers() {
- // Filter out executors under killing
- val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
- val workOffers = activeExecutors.map { case (id, executorData) =>
- new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
- }.toSeq
- launchTasks(scheduler.resourceOffers(workOffers))
+ // Make sure no executor is killed while some task is launching on it
+ val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized {
+ // Filter out executors under killing
+ val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
+ val workOffers = activeExecutors.map { case (id, executorData) =>
+ new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
+ }.toIndexedSeq
+ scheduler.resourceOffers(workOffers)
+ }
+ if (!taskDescs.isEmpty) {
+ launchTasks(taskDescs)
+ }
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
@@ -204,12 +251,20 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Make fake resource offers on just one executor
private def makeOffers(executorId: String) {
- // Filter out executors under killing
- if (executorIsAlive(executorId)) {
- val executorData = executorDataMap(executorId)
- val workOffers = Seq(
- new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
- launchTasks(scheduler.resourceOffers(workOffers))
+ // Make sure no executor is killed while some task is launching on it
+ val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized {
+ // Filter out executors under killing
+ if (executorIsAlive(executorId)) {
+ val executorData = executorDataMap(executorId)
+ val workOffers = IndexedSeq(
+ new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
+ scheduler.resourceOffers(workOffers)
+ } else {
+ Seq.empty
+ }
+ }
+ if (!taskDescs.isEmpty) {
+ launchTasks(taskDescs)
}
}
@@ -221,15 +276,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Launch tasks returned by a set of resource offers
private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
for (task <- tasks.flatten) {
- val serializedTask = ser.serialize(task)
- if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
+ val serializedTask = TaskDescription.encode(task)
+ if (serializedTask.limit >= maxRpcMessageSize) {
scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
- "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
- "spark.akka.frameSize or using broadcast variables for large values."
- msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
- AkkaUtils.reservedSizeBytes)
+ "spark.rpc.message.maxSize (%d bytes). Consider increasing " +
+ "spark.rpc.message.maxSize or using broadcast variables for large values."
+ msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize)
taskSetMgr.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
@@ -239,29 +293,41 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
else {
val executorData = executorDataMap(task.executorId)
executorData.freeCores -= scheduler.CPUS_PER_TASK
+
+ logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
+ s"${executorData.executorHost}.")
+
executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
}
}
}
// Remove a disconnected slave from the cluster
- def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = {
+ private def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = {
+ logDebug(s"Asked to remove executor $executorId with reason $reason")
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
// This must be synchronized because variables mutated
// in this block are read when requesting executors
- CoarseGrainedSchedulerBackend.this.synchronized {
+ val killed = CoarseGrainedSchedulerBackend.this.synchronized {
addressToExecutorId -= executorInfo.executorAddress
executorDataMap -= executorId
- executorsPendingToRemove -= executorId
executorsPendingLossReason -= executorId
+ executorsPendingToRemove.remove(executorId).getOrElse(false)
}
totalCoreCount.addAndGet(-executorInfo.totalCores)
totalRegisteredExecutors.addAndGet(-1)
- scheduler.executorLost(executorId, reason)
+ scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason)
listenerBus.post(
SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString))
- case None => logInfo(s"Asked to remove non-existent executor $executorId")
+ case None =>
+ // SPARK-15262: If an executor is still alive even after the scheduler has removed
+ // its metadata, we may receive a heartbeat from that executor and tell its block
+ // manager to reregister itself. If that happens, the block manager master will know
+ // about the executor, but the scheduler will not. Therefore, we should remove the
+ // executor from the block manager when we hit this case.
+ scheduler.sc.env.blockManager.master.removeExecutorAsync(executorId)
+ logInfo(s"Asked to remove non-existent executor $executorId")
}
}
@@ -269,7 +335,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
* Stop making resource offers for the given executor. The executor is marked as lost with
* the loss reason still pending.
*
- * @return Whether executor was alive.
+ * @return Whether executor should be disabled
*/
protected def disableExecutor(executorId: String): Boolean = {
val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized {
@@ -277,7 +343,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
executorsPendingLossReason += executorId
true
} else {
- false
+ // Return true for explicitly killed executors, since we still need to get their pending
+ // loss reasons; for all others return false.
+ executorsPendingToRemove.contains(executorId)
}
}
@@ -295,7 +363,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
var driverEndpoint: RpcEndpointRef = null
- val taskIdsOnSlave = new HashMap[String, HashSet[String]]
+
+ protected def minRegisteredRatio: Double = _minRegisteredRatio
override def start() {
val properties = new ArrayBuffer[(String, String)]
@@ -306,7 +375,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
// TODO (prashant) send conf instead of properties
- driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties))
+ driverEndpoint = createDriverEndpointRef(properties)
+ }
+
+ protected def createDriverEndpointRef(
+ properties: ArrayBuffer[(String, String)]): RpcEndpointRef = {
+ rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties))
}
protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
@@ -317,7 +391,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
try {
if (driverEndpoint != null) {
logInfo("Shutting down all executors")
- driverEndpoint.askWithRetry[Boolean](StopExecutors)
+ driverEndpoint.askSync[Boolean](StopExecutors)
}
} catch {
case e: Exception =>
@@ -329,7 +403,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
stopExecutors()
try {
if (driverEndpoint != null) {
- driverEndpoint.askWithRetry[Boolean](StopDriver)
+ driverEndpoint.askSync[Boolean](StopDriver)
}
} catch {
case e: Exception =>
@@ -337,26 +411,47 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
}
+ /**
+ * Reset the state of CoarseGrainedSchedulerBackend to the initial state. Currently it will only
+ * be called in the yarn-client mode when AM re-registers after a failure.
+ */
+ protected def reset(): Unit = {
+ val executors = synchronized {
+ requestedTotalExecutors = 0
+ numPendingExecutors = 0
+ executorsPendingToRemove.clear()
+ Set() ++ executorDataMap.keys
+ }
+
+ // Remove any lingering executors that should have been removed but are still around. This can
+ // happen if (1) the disconnect event has not yet been received, or (2) the executors died silently.
+ executors.foreach { eid =>
+ removeExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered."))
+ }
+ }
+
override def reviveOffers() {
driverEndpoint.send(ReviveOffers)
}
- override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
- driverEndpoint.send(KillTask(taskId, executorId, interruptThread))
+ override def killTask(
+ taskId: Long, executorId: String, interruptThread: Boolean, reason: String) {
+ driverEndpoint.send(KillTask(taskId, executorId, interruptThread, reason))
}
override def defaultParallelism(): Int = {
conf.getInt("spark.default.parallelism", math.max(totalCoreCount.get(), 2))
}
- // Called by subclasses when notified of a lost worker
- def removeExecutor(executorId: String, reason: ExecutorLossReason) {
- try {
- driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason))
- } catch {
- case e: Exception =>
- throw new SparkException("Error notifying standalone scheduler's driver endpoint", e)
- }
+ /**
+ * Called by subclasses when notified of a lost worker. It just fires the message and returns
+ * at once.
+ */
+ protected def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = {
+ // Only log the failure since we don't care about the result.
+ driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).onFailure { case t =>
+ logError(t.getMessage, t)
+ }(ThreadUtils.sameThread)
}
def sufficientResourcesRegistered(): Boolean = true
@@ -378,25 +473,43 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
/**
* Return the number of executors currently registered with this backend.
*/
- def numExistingExecutors: Int = executorDataMap.size
+ private def numExistingExecutors: Int = executorDataMap.size
+
+ override def getExecutorIds(): Seq[String] = {
+ executorDataMap.keySet.toSeq
+ }
/**
* Request an additional number of executors from the cluster manager.
* @return whether the request is acknowledged.
*/
- final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
+ final override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
if (numAdditionalExecutors < 0) {
throw new IllegalArgumentException(
"Attempted to request a negative number of additional executor(s) " +
s"$numAdditionalExecutors from the cluster manager. Please specify a positive number!")
}
logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager")
- logDebug(s"Number of pending executors is now $numPendingExecutors")
- numPendingExecutors += numAdditionalExecutors
- // Account for executors pending to be added or removed
- val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
- doRequestTotalExecutors(newTotal)
+ val response = synchronized {
+ requestedTotalExecutors += numAdditionalExecutors
+ numPendingExecutors += numAdditionalExecutors
+ logDebug(s"Number of pending executors is now $numPendingExecutors")
+ if (requestedTotalExecutors !=
+ (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) {
+ logDebug(
+ s"""requestExecutors($numAdditionalExecutors): Executor request doesn't match:
+ |requestedTotalExecutors = $requestedTotalExecutors
+ |numExistingExecutors = $numExistingExecutors
+ |numPendingExecutors = $numPendingExecutors
+ |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin)
+ }
+
+ // Account for executors pending to be added or removed
+ doRequestTotalExecutors(requestedTotalExecutors)
+ }
+
+ defaultAskTimeout.awaitResult(response)
}
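requestExecutors now follows a "prepare under the lock, wait outside it" shape: bookkeeping and the asynchronous request happen inside synchronized, but the blocking awaitResult runs after the lock is released so a slow cluster manager cannot stall other threads. A hedged sketch of that pattern with a fake asynchronous call standing in for the real request:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration._

    // All names and the stubbed "cluster manager" call are illustrative only.
    class ExecutorRequestSketch {
      private var requestedTotal = 0

      private def askClusterManager(total: Int): Future[Boolean] =
        Future { total >= 0 } // stand-in for a real asynchronous request

      def requestAdditional(n: Int): Boolean = {
        val response: Future[Boolean] = synchronized {
          requestedTotal += n
          askClusterManager(requestedTotal) // returns immediately with a Future
        }
        // Blocking happens outside the synchronized block.
        Await.result(response, 30.seconds)
      }
    }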
/**
@@ -417,19 +530,25 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
numExecutors: Int,
localityAwareTasks: Int,
hostToLocalTaskCount: Map[String, Int]
- ): Boolean = synchronized {
+ ): Boolean = {
if (numExecutors < 0) {
throw new IllegalArgumentException(
"Attempted to request a negative number of executor(s) " +
s"$numExecutors from the cluster manager. Please specify a positive number!")
}
- this.localityAwareTasks = localityAwareTasks
- this.hostToLocalTaskCount = hostToLocalTaskCount
+ val response = synchronized {
+ this.requestedTotalExecutors = numExecutors
+ this.localityAwareTasks = localityAwareTasks
+ this.hostToLocalTaskCount = hostToLocalTaskCount
- numPendingExecutors =
- math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0)
- doRequestTotalExecutors(numExecutors)
+ numPendingExecutors =
+ math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0)
+
+ doRequestTotalExecutors(numExecutors)
+ }
+
+ defaultAskTimeout.awaitResult(response)
}
/**
@@ -442,55 +561,104 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
* insufficient resources to satisfy the first request. We make the assumption here that the
* cluster manager will eventually fulfill all requests when resources free up.
*
- * @return whether the request is acknowledged.
+ * @return a future whose evaluation indicates whether the request is acknowledged.
*/
- protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false
-
- /**
- * Request that the cluster manager kill the specified executors.
- * @return whether the kill request is acknowledged.
- */
- final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized {
- killExecutors(executorIds, replace = false)
- }
+ protected def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] =
+ Future.successful(false)
/**
* Request that the cluster manager kill the specified executors.
*
+ * When asking the executor to be replaced, the executor loss is considered a failure, and
+ * killed tasks that are running on the executor will count towards the failure limits. If no
+ * replacement is being requested, then the tasks will not count towards the limit.
+ *
* @param executorIds identifiers of executors to kill
- * @param replace whether to replace the killed executors with new ones
- * @return whether the kill request is acknowledged.
+ * @param replace whether to replace the killed executors with new ones, default false
+ * @param force whether to force kill busy executors, default false
+ * @return the ids of the executors acknowledged by the cluster manager to be removed.
*/
- final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized {
+ final override def killExecutors(
+ executorIds: Seq[String],
+ replace: Boolean,
+ force: Boolean): Seq[String] = {
logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
- val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains)
- unknownExecutors.foreach { id =>
- logWarning(s"Executor to kill $id does not exist!")
- }
- // If an executor is already pending to be removed, do not kill it again (SPARK-9795)
- val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) }
- executorsPendingToRemove ++= executorsToKill
-
- // If we do not wish to replace the executors we kill, sync the target number of executors
- // with the cluster manager to avoid allocating new ones. When computing the new target,
- // take into account executors that are pending to be added or removed.
- if (!replace) {
- doRequestTotalExecutors(
- numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)
- } else {
- numPendingExecutors += knownExecutors.size
+ val response = synchronized {
+ val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains)
+ unknownExecutors.foreach { id =>
+ logWarning(s"Executor to kill $id does not exist!")
+ }
+
+ // If an executor is already pending to be removed, do not kill it again (SPARK-9795)
+ // If this executor is busy, do not kill it unless we are told to force kill it (SPARK-9552)
+ val executorsToKill = knownExecutors
+ .filter { id => !executorsPendingToRemove.contains(id) }
+ .filter { id => force || !scheduler.isExecutorBusy(id) }
+ executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace }
+
+ logInfo(s"Actual list of executor(s) to be killed is ${executorsToKill.mkString(", ")}")
+
+ // If we do not wish to replace the executors we kill, sync the target number of executors
+ // with the cluster manager to avoid allocating new ones. When computing the new target,
+ // take into account executors that are pending to be added or removed.
+ val adjustTotalExecutors =
+ if (!replace) {
+ requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0)
+ if (requestedTotalExecutors !=
+ (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) {
+ logDebug(
+ s"""killExecutors($executorIds, $replace, $force): Executor counts do not match:
+ |requestedTotalExecutors = $requestedTotalExecutors
+ |numExistingExecutors = $numExistingExecutors
+ |numPendingExecutors = $numPendingExecutors
+ |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin)
+ }
+ doRequestTotalExecutors(requestedTotalExecutors)
+ } else {
+ numPendingExecutors += knownExecutors.size
+ Future.successful(true)
+ }
+
+ val killExecutors: Boolean => Future[Boolean] =
+ if (!executorsToKill.isEmpty) {
+ _ => doKillExecutors(executorsToKill)
+ } else {
+ _ => Future.successful(false)
+ }
+
+ val killResponse = adjustTotalExecutors.flatMap(killExecutors)(ThreadUtils.sameThread)
+
+ killResponse.flatMap(killSuccessful =>
+ Future.successful (if (killSuccessful) executorsToKill else Seq.empty[String])
+ )(ThreadUtils.sameThread)
}
- doKillExecutors(executorsToKill)
+ defaultAskTimeout.awaitResult(response)
}
/**
* Kill the given list of executors through the cluster manager.
* @return whether the kill request is acknowledged.
*/
- protected def doKillExecutors(executorIds: Seq[String]): Boolean = false
+ protected def doKillExecutors(executorIds: Seq[String]): Future[Boolean] =
+ Future.successful(false)
+ /**
+ * Request that the cluster manager kill all executors on a given host.
+ * @return whether the kill request is acknowledged.
+ */
+ final override def killExecutorsOnHost(host: String): Boolean = {
+ logInfo(s"Requesting to kill any and all executors on host ${host}")
+ // A potential race exists if a new executor attempts to register on a host
+ // that is on the blacklist and is no longer valid. To avoid this race,
+ // all executor registration and killing happens in the event loop. This way, either
+ // an executor will fail to register, or will be killed when all executors on a host
+ // are killed.
+ // Kill all the executors on this host in an event loop to ensure serialization.
+ driverEndpoint.send(KillExecutorsOnHost(host))
+ true
+ }
}
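killExecutorsOnHost relies on the driver endpoint's single-threaded event loop to serialize registration against host-level kills. The sketch below shows that idea with invented message types and a plain single-threaded executor; it is not Spark's RpcEndpoint machinery:

    import java.util.concurrent.Executors

    object EventLoopSerializationSketch {
      sealed trait Msg
      final case class Register(execId: String, host: String) extends Msg
      final case class KillAllOnHost(host: String) extends Msg

      private val loop = Executors.newSingleThreadExecutor()
      // Mutated only from the single loop thread, so register and kill can never interleave.
      private var execsByHost = Map.empty[String, Set[String]].withDefaultValue(Set.empty)

      def post(msg: Msg): Unit = loop.execute(new Runnable {
        override def run(): Unit = msg match {
          case Register(execId, host) => execsByHost += host -> (execsByHost(host) + execId)
          case KillAllOnHost(host)    => execsByHost -= host
        }
      })
    }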
private[spark] object CoarseGrainedSchedulerBackend {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
index 626a2b7d69ab..b25a4bfb501f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler.cluster
-import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef}
/**
* Grouping of data for an executor used by CoarseGrainedSchedulerBackend.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
deleted file mode 100644
index 641638a77d5f..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster
-
-import org.apache.hadoop.fs.{Path, FileSystem}
-
-import org.apache.spark.rpc.RpcAddress
-import org.apache.spark.{Logging, SparkContext, SparkEnv}
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.scheduler.TaskSchedulerImpl
-
-private[spark] class SimrSchedulerBackend(
- scheduler: TaskSchedulerImpl,
- sc: SparkContext,
- driverFilePath: String)
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
- with Logging {
-
- val tmpPath = new Path(driverFilePath + "_tmp")
- val filePath = new Path(driverFilePath)
-
- val maxCores = conf.getInt("spark.simr.executor.cores", 1)
-
- override def start() {
- super.start()
-
- val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName,
- RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt),
- CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
-
- val conf = SparkHadoopUtil.get.newConfiguration(sc.conf)
- val fs = FileSystem.get(conf)
- val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
-
- logInfo("Writing to HDFS file: " + driverFilePath)
- logInfo("Writing Akka address: " + driverUrl)
- logInfo("Writing Spark UI Address: " + appUIAddress)
-
- // Create temporary file to prevent race condition where executors get empty driverUrl file
- val temp = fs.create(tmpPath, true)
- temp.writeUTF(driverUrl)
- temp.writeInt(maxCores)
- temp.writeUTF(appUIAddress)
- temp.close()
-
- // "Atomic" rename
- fs.rename(tmpPath, filePath)
- }
-
- override def stop() {
- val conf = SparkHadoopUtil.get.newConfiguration(sc.conf)
- val fs = FileSystem.get(conf)
- if (!fs.delete(new Path(driverFilePath), false)) {
- logWarning(s"error deleting ${driverFilePath}")
- }
- super.stop()
- }
-
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
deleted file mode 100644
index 05d9bc92f228..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ /dev/null
@@ -1,208 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster
-
-import java.util.concurrent.Semaphore
-
-import org.apache.spark.rpc.RpcAddress
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
-import org.apache.spark.deploy.{ApplicationDescription, Command}
-import org.apache.spark.deploy.client.{AppClient, AppClientListener}
-import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
-import org.apache.spark.scheduler._
-import org.apache.spark.util.Utils
-
-private[spark] class SparkDeploySchedulerBackend(
- scheduler: TaskSchedulerImpl,
- sc: SparkContext,
- masters: Array[String])
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
- with AppClientListener
- with Logging {
-
- private var client: AppClient = null
- private var stopping = false
- private val launcherBackend = new LauncherBackend() {
- override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED)
- }
-
- @volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _
- @volatile private var appId: String = _
-
- private val registrationBarrier = new Semaphore(0)
-
- private val maxCores = conf.getOption("spark.cores.max").map(_.toInt)
- private val totalExpectedCores = maxCores.getOrElse(0)
-
- override def start() {
- super.start()
- launcherBackend.connect()
-
- // The endpoint for executors to talk to us
- val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName,
- RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt),
- CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
- val args = Seq(
- "--driver-url", driverUrl,
- "--executor-id", "{{EXECUTOR_ID}}",
- "--hostname", "{{HOSTNAME}}",
- "--cores", "{{CORES}}",
- "--app-id", "{{APP_ID}}",
- "--worker-url", "{{WORKER_URL}}")
- val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
- .map(Utils.splitCommandString).getOrElse(Seq.empty)
- val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath")
- .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
- val libraryPathEntries = sc.conf.getOption("spark.executor.extraLibraryPath")
- .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
-
- // When testing, expose the parent class path to the child. This is processed by
- // compute-classpath.{cmd,sh} and makes all needed jars available to child processes
- // when the assembly is built with the "*-provided" profiles enabled.
- val testingClassPath =
- if (sys.props.contains("spark.testing")) {
- sys.props("java.class.path").split(java.io.File.pathSeparator).toSeq
- } else {
- Nil
- }
-
- // Start executors with a few necessary configs for registering with the scheduler
- val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf)
- val javaOpts = sparkJavaOpts ++ extraJavaOpts
- val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
- args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts)
- val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
- val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt)
- val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory,
- command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor)
- client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf)
- client.start()
- launcherBackend.setState(SparkAppHandle.State.SUBMITTED)
- waitForRegistration()
- launcherBackend.setState(SparkAppHandle.State.RUNNING)
- }
-
- override def stop(): Unit = synchronized {
- stop(SparkAppHandle.State.FINISHED)
- }
-
- override def connected(appId: String) {
- logInfo("Connected to Spark cluster with app ID " + appId)
- this.appId = appId
- notifyContext()
- launcherBackend.setAppId(appId)
- }
-
- override def disconnected() {
- notifyContext()
- if (!stopping) {
- logWarning("Disconnected from Spark cluster! Waiting for reconnection...")
- }
- }
-
- override def dead(reason: String) {
- notifyContext()
- if (!stopping) {
- launcherBackend.setState(SparkAppHandle.State.KILLED)
- logError("Application has been killed. Reason: " + reason)
- try {
- scheduler.error(reason)
- } finally {
- // Ensure the application terminates, as we can no longer run jobs.
- sc.stop()
- }
- }
- }
-
- override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int,
- memory: Int) {
- logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format(
- fullId, hostPort, cores, Utils.megabytesToString(memory)))
- }
-
- override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) {
- val reason: ExecutorLossReason = exitStatus match {
- case Some(code) => ExecutorExited(code, exitCausedByApp = true, message)
- case None => SlaveLost(message)
- }
- logInfo("Executor %s removed: %s".format(fullId, message))
- removeExecutor(fullId.split("/")(1), reason)
- }
-
- override def sufficientResourcesRegistered(): Boolean = {
- totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio
- }
-
- override def applicationId(): String =
- Option(appId).getOrElse {
- logWarning("Application ID is not initialized yet.")
- super.applicationId
- }
-
- /**
- * Request executors from the Master by specifying the total number desired,
- * including existing pending and running executors.
- *
- * @return whether the request is acknowledged.
- */
- protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
- Option(client) match {
- case Some(c) => c.requestTotalExecutors(requestedTotal)
- case None =>
- logWarning("Attempted to request executors before driver fully initialized.")
- false
- }
- }
-
- /**
- * Kill the given list of executors through the Master.
- * @return whether the kill request is acknowledged.
- */
- protected override def doKillExecutors(executorIds: Seq[String]): Boolean = {
- Option(client) match {
- case Some(c) => c.killExecutors(executorIds)
- case None =>
- logWarning("Attempted to kill executors before driver fully initialized.")
- false
- }
- }
-
- private def waitForRegistration() = {
- registrationBarrier.acquire()
- }
-
- private def notifyContext() = {
- registrationBarrier.release()
- }
-
- private def stop(finalState: SparkAppHandle.State): Unit = synchronized {
- stopping = true
-
- launcherBackend.setState(finalState)
- launcherBackend.close()
-
- super.stop()
- client.stop()
-
- val callback = shutdownCallback
- if (callback != null) {
- callback(this)
- }
- }
-
-}
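Both the deleted SparkDeploySchedulerBackend and its replacement below block in start() until the cluster acknowledges the application, using a Semaphore initialized to zero as a one-shot latch. A reduced sketch of that pattern; the class name and the simulated callback thread are illustrative:

```scala
import java.util.concurrent.Semaphore

// A Semaphore(0) acts as a one-shot latch: acquire() blocks until another
// thread (here, a simulated RPC callback) calls release().
class RegistrationBarrier {
  private val barrier = new Semaphore(0)

  def waitForRegistration(): Unit = barrier.acquire()

  def notifyContext(): Unit = barrier.release()
}

object RegistrationBarrierDemo {
  def main(args: Array[String]): Unit = {
    val b = new RegistrationBarrier
    val callbackThread = new Thread(new Runnable {
      override def run(): Unit = {
        Thread.sleep(100) // simulate the master acknowledging the application
        b.notifyContext()
      }
    })
    callbackThread.start()
    b.waitForRegistration() // start() would block here until connected()/dead() fires
    println("registered")
  }
}
```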
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
new file mode 100644
index 000000000000..22ca14fab85b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.util.concurrent.Semaphore
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.concurrent.Future
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.deploy.{ApplicationDescription, Command}
+import org.apache.spark.deploy.client.{StandaloneAppClient, StandaloneAppClientListener}
+import org.apache.spark.internal.Logging
+import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
+import org.apache.spark.rpc.RpcEndpointAddress
+import org.apache.spark.scheduler._
+import org.apache.spark.util.Utils
+
+/**
+ * A [[SchedulerBackend]] implementation for Spark's standalone cluster manager.
+ */
+private[spark] class StandaloneSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ sc: SparkContext,
+ masters: Array[String])
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
+ with StandaloneAppClientListener
+ with Logging {
+
+ private var client: StandaloneAppClient = null
+ private val stopping = new AtomicBoolean(false)
+ private val launcherBackend = new LauncherBackend() {
+ override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED)
+ }
+
+ @volatile var shutdownCallback: StandaloneSchedulerBackend => Unit = _
+ @volatile private var appId: String = _
+
+ private val registrationBarrier = new Semaphore(0)
+
+ private val maxCores = conf.getOption("spark.cores.max").map(_.toInt)
+ private val totalExpectedCores = maxCores.getOrElse(0)
+
+ override def start() {
+ super.start()
+
+ // SPARK-21159. The scheduler backend should only try to connect to the launcher when in client
+ // mode. In cluster mode, the code that submits the application to the Master needs to connect
+ // to the launcher instead.
+ if (sc.deployMode == "client") {
+ launcherBackend.connect()
+ }
+
+ // The endpoint for executors to talk to us
+ val driverUrl = RpcEndpointAddress(
+ sc.conf.get("spark.driver.host"),
+ sc.conf.get("spark.driver.port").toInt,
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
+ val args = Seq(
+ "--driver-url", driverUrl,
+ "--executor-id", "{{EXECUTOR_ID}}",
+ "--hostname", "{{HOSTNAME}}",
+ "--cores", "{{CORES}}",
+ "--app-id", "{{APP_ID}}",
+ "--worker-url", "{{WORKER_URL}}")
+ val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
+ .map(Utils.splitCommandString).getOrElse(Seq.empty)
+ val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath")
+ .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
+ val libraryPathEntries = sc.conf.getOption("spark.executor.extraLibraryPath")
+ .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
+
+ // When testing, expose the parent class path to the child. This is processed by
+ // compute-classpath.{cmd,sh} and makes all needed jars available to child processes
+ // when the assembly is built with the "*-provided" profiles enabled.
+ val testingClassPath =
+ if (sys.props.contains("spark.testing")) {
+ sys.props("java.class.path").split(java.io.File.pathSeparator).toSeq
+ } else {
+ Nil
+ }
+
+ // Start executors with a few necessary configs for registering with the scheduler
+ val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf)
+ val javaOpts = sparkJavaOpts ++ extraJavaOpts
+ val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
+ args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts)
+ val webUrl = sc.ui.map(_.webUrl).getOrElse("")
+ val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt)
+ // If we're using dynamic allocation, set our initial executor limit to 0 for now.
+ // ExecutorAllocationManager will send the real initial limit to the Master later.
+ val initialExecutorLimit =
+ if (Utils.isDynamicAllocationEnabled(conf)) {
+ Some(0)
+ } else {
+ None
+ }
+ val appDesc = ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
+ webUrl, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit)
+ client = new StandaloneAppClient(sc.env.rpcEnv, masters, appDesc, this, conf)
+ client.start()
+ launcherBackend.setState(SparkAppHandle.State.SUBMITTED)
+ waitForRegistration()
+ launcherBackend.setState(SparkAppHandle.State.RUNNING)
+ }
+
+ override def stop(): Unit = {
+ stop(SparkAppHandle.State.FINISHED)
+ }
+
+ override def connected(appId: String) {
+ logInfo("Connected to Spark cluster with app ID " + appId)
+ this.appId = appId
+ notifyContext()
+ launcherBackend.setAppId(appId)
+ }
+
+ override def disconnected() {
+ notifyContext()
+ if (!stopping.get) {
+ logWarning("Disconnected from Spark cluster! Waiting for reconnection...")
+ }
+ }
+
+ override def dead(reason: String) {
+ notifyContext()
+ if (!stopping.get) {
+ launcherBackend.setState(SparkAppHandle.State.KILLED)
+ logError("Application has been killed. Reason: " + reason)
+ try {
+ scheduler.error(reason)
+ } finally {
+ // Ensure the application terminates, as we can no longer run jobs.
+ sc.stopInNewThread()
+ }
+ }
+ }
+
+ override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int,
+ memory: Int) {
+ logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format(
+ fullId, hostPort, cores, Utils.megabytesToString(memory)))
+ }
+
+ override def executorRemoved(
+ fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean) {
+ val reason: ExecutorLossReason = exitStatus match {
+ case Some(code) => ExecutorExited(code, exitCausedByApp = true, message)
+ case None => SlaveLost(message, workerLost = workerLost)
+ }
+ logInfo("Executor %s removed: %s".format(fullId, message))
+ removeExecutor(fullId.split("/")(1), reason)
+ }
+
+ override def sufficientResourcesRegistered(): Boolean = {
+ totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio
+ }
+
+ override def applicationId(): String =
+ Option(appId).getOrElse {
+ logWarning("Application ID is not initialized yet.")
+ super.applicationId
+ }
+
+ /**
+ * Request executors from the Master by specifying the total number desired,
+ * including existing pending and running executors.
+ *
+ * @return whether the request is acknowledged.
+ */
+ protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = {
+ Option(client) match {
+ case Some(c) => c.requestTotalExecutors(requestedTotal)
+ case None =>
+ logWarning("Attempted to request executors before driver fully initialized.")
+ Future.successful(false)
+ }
+ }
+
+ /**
+ * Kill the given list of executors through the Master.
+ * @return whether the kill request is acknowledged.
+ */
+ protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = {
+ Option(client) match {
+ case Some(c) => c.killExecutors(executorIds)
+ case None =>
+ logWarning("Attempted to kill executors before driver fully initialized.")
+ Future.successful(false)
+ }
+ }
+
+ private def waitForRegistration() = {
+ registrationBarrier.acquire()
+ }
+
+ private def notifyContext() = {
+ registrationBarrier.release()
+ }
+
+ private def stop(finalState: SparkAppHandle.State): Unit = {
+ if (stopping.compareAndSet(false, true)) {
+ try {
+ super.stop()
+ client.stop()
+
+ val callback = shutdownCallback
+ if (callback != null) {
+ callback(this)
+ }
+ } finally {
+ launcherBackend.setState(finalState)
+ launcherBackend.close()
+ }
+ }
+ }
+
+}
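One behavioral change worth noting in the new file: stop() now guards its shutdown logic with an AtomicBoolean and compareAndSet rather than a synchronized flag, so the body runs at most once even if stop is triggered concurrently from user code and the launcher. A small sketch of that idiom in isolation; the class and callback names are placeholders:

```scala
import java.util.concurrent.atomic.AtomicBoolean

// Only the first caller wins the compareAndSet; later (or concurrent) callers
// return immediately, so cleanup and the final-state notification run once.
class IdempotentStop(cleanup: () => Unit, reportFinalState: () => Unit) {
  private val stopping = new AtomicBoolean(false)

  def stop(): Unit = {
    if (stopping.compareAndSet(false, true)) {
      try {
        cleanup()
      } finally {
        reportFinalState() // mirrors notifying the launcher in a finally block
      }
    }
  }
}

object IdempotentStopDemo {
  def main(args: Array[String]): Unit = {
    val s = new IdempotentStop(() => println("cleanup"), () => println("final state reported"))
    s.stop()
    s.stop() // no-op: the flag was already flipped
  }
}
```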
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
deleted file mode 100644
index 80da37b09b59..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ /dev/null
@@ -1,226 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster
-
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.{Future, ExecutionContext}
-
-import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.rpc._
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.scheduler._
-import org.apache.spark.ui.JettyUtils
-import org.apache.spark.util.{ThreadUtils, RpcUtils}
-
-import scala.util.control.NonFatal
-
-/**
- * Abstract Yarn scheduler backend that contains common logic
- * between the client and cluster Yarn scheduler backends.
- */
-private[spark] abstract class YarnSchedulerBackend(
- scheduler: TaskSchedulerImpl,
- sc: SparkContext)
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) {
-
- if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
- minRegisteredRatio = 0.8
- }
-
- protected var totalExpectedExecutors = 0
-
- private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv)
-
- private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint(
- YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint)
-
- private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf)
-
- /**
- * Request executors from the ApplicationMaster by specifying the total number desired.
- * This includes executors already pending or running.
- */
- override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
- yarnSchedulerEndpointRef.askWithRetry[Boolean](
- RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount))
- }
-
- /**
- * Request that the ApplicationMaster kill the specified executors.
- */
- override def doKillExecutors(executorIds: Seq[String]): Boolean = {
- yarnSchedulerEndpointRef.askWithRetry[Boolean](KillExecutors(executorIds))
- }
-
- override def sufficientResourcesRegistered(): Boolean = {
- totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio
- }
-
- /**
- * Add filters to the SparkUI.
- */
- private def addWebUIFilter(
- filterName: String,
- filterParams: Map[String, String],
- proxyBase: String): Unit = {
- if (proxyBase != null && proxyBase.nonEmpty) {
- System.setProperty("spark.ui.proxyBase", proxyBase)
- }
-
- val hasFilter =
- filterName != null && filterName.nonEmpty &&
- filterParams != null && filterParams.nonEmpty
- if (hasFilter) {
- logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
- conf.set("spark.ui.filters", filterName)
- filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
- scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
- }
- }
-
- override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
- new YarnDriverEndpoint(rpcEnv, properties)
- }
-
- /**
- * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected.
- * This endpoint communicates with the executors and queries the AM for an executor's exit
- * status when the executor is disconnected.
- */
- private class YarnDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
- extends DriverEndpoint(rpcEnv, sparkProperties) {
-
- /**
- * When onDisconnected is received at the driver endpoint, the superclass DriverEndpoint
- * handles it by assuming the Executor was lost for a bad reason and removes the executor
- * immediately.
- *
- * In YARN's case however it is crucial to talk to the application master and ask why the
- * executor had exited. If the executor exited for some reason unrelated to the running tasks
- * (e.g., preemption), according to the application master, then we pass that information down
- * to the TaskSetManager to inform the TaskSetManager that tasks on that lost executor should
- * not count towards a job failure.
- */
- override def onDisconnected(rpcAddress: RpcAddress): Unit = {
- addressToExecutorId.get(rpcAddress).foreach { executorId =>
- if (disableExecutor(executorId)) {
- yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress)
- }
- }
- }
- }
-
- /**
- * An [[RpcEndpoint]] that communicates with the ApplicationMaster.
- */
- private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv)
- extends ThreadSafeRpcEndpoint with Logging {
- private var amEndpoint: Option[RpcEndpointRef] = None
-
- private val askAmThreadPool =
- ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool")
- implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool)
-
- private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver(
- executorId: String,
- executorRpcAddress: RpcAddress): Unit = {
- amEndpoint match {
- case Some(am) =>
- val lossReasonRequest = GetExecutorLossReason(executorId)
- val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout)
- future onSuccess {
- case reason: ExecutorLossReason => {
- driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason))
- }
- }
- future onFailure {
- case NonFatal(e) => {
- logWarning(s"Attempted to get executor loss reason" +
- s" for executor id ${executorId} at RPC address ${executorRpcAddress}," +
- s" but got no response. Marking as slave lost.", e)
- driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost()))
- }
- case t => throw t
- }
- case None =>
- logWarning("Attempted to check for an executor loss reason" +
- " before the AM has registered!")
- }
- }
-
- override def receive: PartialFunction[Any, Unit] = {
- case RegisterClusterManager(am) =>
- logInfo(s"ApplicationMaster registered as $am")
- amEndpoint = Option(am)
-
- case AddWebUIFilter(filterName, filterParams, proxyBase) =>
- addWebUIFilter(filterName, filterParams, proxyBase)
-
- case RemoveExecutor(executorId, reason) =>
- logWarning(reason.toString)
- removeExecutor(executorId, reason)
- }
-
-
- override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case r: RequestExecutors =>
- amEndpoint match {
- case Some(am) =>
- Future {
- context.reply(am.askWithRetry[Boolean](r))
- } onFailure {
- case NonFatal(e) =>
- logError(s"Sending $r to AM was unsuccessful", e)
- context.sendFailure(e)
- }
- case None =>
- logWarning("Attempted to request executors before the AM has registered!")
- context.reply(false)
- }
-
- case k: KillExecutors =>
- amEndpoint match {
- case Some(am) =>
- Future {
- context.reply(am.askWithRetry[Boolean](k))
- } onFailure {
- case NonFatal(e) =>
- logError(s"Sending $k to AM was unsuccessful", e)
- context.sendFailure(e)
- }
- case None =>
- logWarning("Attempted to kill executors before the AM has registered!")
- context.reply(false)
- }
- }
-
- override def onDisconnected(remoteAddress: RpcAddress): Unit = {
- if (amEndpoint.exists(_.address == remoteAddress)) {
- logWarning(s"ApplicationMaster has disassociated: $remoteAddress")
- }
- }
-
- override def onStop(): Unit = {
- askAmThreadPool.shutdownNow()
- }
- }
-}
-
-private[spark] object YarnSchedulerBackend {
- val ENDPOINT_NAME = "YarnScheduler"
-}
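The deleted YarnSchedulerEndpoint forwards RequestExecutors and KillExecutors to the ApplicationMaster on a dedicated thread pool and then either replies with the result or propagates the failure to the caller. The same shape, recast with plain Scala Futures and stub types; askAm and the request string are stand-ins, not the real RPC API:

```scala
import java.util.concurrent.Executors

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future}

object ForwardToAmSketch {
  // A dedicated pool keeps the (potentially blocking) ask off the caller's thread.
  private val askAmPool = Executors.newCachedThreadPool()
  implicit private val askAmExecutor: ExecutionContext =
    ExecutionContext.fromExecutor(askAmPool)

  // Stand-in for a blocking ask against the ApplicationMaster endpoint.
  private def askAm(request: String): Boolean = request.nonEmpty

  // Run the ask asynchronously, then either reply with the answer or report the failure.
  def forward(request: String)(reply: Boolean => Unit)(fail: Throwable => Unit): Future[Unit] =
    Future(askAm(request))
      .map(ack => reply(ack))
      .recover { case e => fail(e) }

  def main(args: Array[String]): Unit = {
    val done = forward("RequestExecutors(4)")(ack => println(s"acknowledged: $ack"))(e =>
      println(s"failed: $e"))
    Await.ready(done, 5.seconds)
    askAmPool.shutdown()
  }
}
```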
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
deleted file mode 100644
index d10a77f8e5c7..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ /dev/null
@@ -1,436 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster.mesos
-
-import java.io.File
-import java.util.concurrent.locks.ReentrantLock
-import java.util.{Collections, List => JList}
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable.{HashMap, HashSet}
-
-import com.google.common.collect.HashBiMap
-import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
-import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver}
-
-import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState}
-import org.apache.spark.network.netty.SparkTransportConf
-import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient
-import org.apache.spark.rpc.RpcAddress
-import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl}
-import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
-import org.apache.spark.util.Utils
-
-/**
- * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
- * onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever
- * a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the
- * CoarseGrainedSchedulerBackend mechanism. This class is useful for achieving lower and more
- * predictable latency.
- *
- * Unfortunately this duplicates a bit of MesosSchedulerBackend, but it seems hard to avoid.
- */
-private[spark] class CoarseMesosSchedulerBackend(
- scheduler: TaskSchedulerImpl,
- sc: SparkContext,
- master: String,
- securityManager: SecurityManager)
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
- with MScheduler
- with MesosSchedulerUtils {
-
- val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures
-
- // Maximum number of cores to acquire (TODO: we'll need more flexible controls here)
- val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt
-
- // If shuffle service is enabled, the Spark driver will register with the shuffle service.
- // This is for cleaning up shuffle files reliably.
- private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
-
- // Cores we have acquired with each Mesos task ID
- val coresByTaskId = new HashMap[Int, Int]
- var totalCoresAcquired = 0
-
- val slaveIdsWithExecutors = new HashSet[String]
-
- // Mapping from slave ID to hostname
- private val slaveIdToHost = new HashMap[String, String]
-
- val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String]
- // How many times tasks on each slave failed
- val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int]
-
- /**
- * The total number of executors we aim to have. Undefined when not using dynamic allocation
- * and before the ExecutorAllocationManager calls [[doRequestTotalExecutors]].
- */
- private var executorLimitOption: Option[Int] = None
-
- /**
- * Return the current executor limit, which may be [[Int.MaxValue]]
- * before properly initialized.
- */
- private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue)
-
- private val pendingRemovedSlaveIds = new HashSet[String]
-
- // private lock object protecting mutable state above. Using the intrinsic lock
- // may lead to deadlocks since the superclass might also try to lock
- private val stateLock = new ReentrantLock
-
- val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0)
-
- // Offer constraints
- private val slaveOfferConstraints =
- parseConstraintString(sc.conf.get("spark.mesos.constraints", ""))
-
- // A client for talking to the external shuffle service, if it is enabled.
- private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = {
- if (shuffleServiceEnabled) {
- Some(new MesosExternalShuffleClient(
- SparkTransportConf.fromSparkConf(conf),
- securityManager,
- securityManager.isAuthenticationEnabled(),
- securityManager.isSaslEncryptionEnabled()))
- } else {
- None
- }
- }
-
- var nextMesosTaskId = 0
-
- @volatile var appId: String = _
-
- def newMesosTaskId(): Int = {
- val id = nextMesosTaskId
- nextMesosTaskId += 1
- id
- }
-
- override def start() {
- super.start()
- val driver = createSchedulerDriver(
- master,
- CoarseMesosSchedulerBackend.this,
- sc.sparkUser,
- sc.appName,
- sc.conf,
- sc.ui.map(_.appUIAddress))
- startScheduler(driver)
- }
-
- def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = {
- val executorSparkHome = conf.getOption("spark.mesos.executor.home")
- .orElse(sc.getSparkHome())
- .getOrElse {
- throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
- }
- val environment = Environment.newBuilder()
- val extraClassPath = conf.getOption("spark.executor.extraClassPath")
- extraClassPath.foreach { cp =>
- environment.addVariables(
- Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
- }
- val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "")
-
- // Set the environment variable through a command prefix
- // to append to the existing value of the variable
- val prefixEnv = conf.getOption("spark.executor.extraLibraryPath").map { p =>
- Utils.libraryPathEnvPrefix(Seq(p))
- }.getOrElse("")
-
- environment.addVariables(
- Environment.Variable.newBuilder()
- .setName("SPARK_EXECUTOR_OPTS")
- .setValue(extraJavaOpts)
- .build())
-
- sc.executorEnvs.foreach { case (key, value) =>
- environment.addVariables(Environment.Variable.newBuilder()
- .setName(key)
- .setValue(value)
- .build())
- }
- val command = CommandInfo.newBuilder()
- .setEnvironment(environment)
-
- val uri = conf.getOption("spark.executor.uri")
- .orElse(Option(System.getenv("SPARK_EXECUTOR_URI")))
-
- if (uri.isEmpty) {
- val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath
- command.setValue(
- "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend"
- .format(prefixEnv, runScript) +
- s" --driver-url $driverURL" +
- s" --executor-id ${offer.getSlaveId.getValue}" +
- s" --hostname ${offer.getHostname}" +
- s" --cores $numCores" +
- s" --app-id $appId")
- } else {
- // Grab everything to the first '.'. We'll use that and '*' to
- // glob the directory "correctly".
- val basename = uri.get.split('/').last.split('.').head
- val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString)
- command.setValue(
- s"cd $basename*; $prefixEnv " +
- "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" +
- s" --driver-url $driverURL" +
- s" --executor-id $executorId" +
- s" --hostname ${offer.getHostname}" +
- s" --cores $numCores" +
- s" --app-id $appId")
- command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get))
- }
-
- conf.getOption("spark.mesos.uris").map { uris =>
- setupUris(uris, command)
- }
-
- command.build()
- }
-
- protected def driverURL: String = {
- if (conf.contains("spark.testing")) {
- "driverURL"
- } else {
- sc.env.rpcEnv.uriOf(
- SparkEnv.driverActorSystemName,
- RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt),
- CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
- }
- }
-
- override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
-
- override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
- appId = frameworkId.getValue
- mesosExternalShuffleClient.foreach(_.init(appId))
- logInfo("Registered as framework ID " + appId)
- markRegistered()
- }
-
- override def sufficientResourcesRegistered(): Boolean = {
- totalCoresAcquired >= maxCores * minRegisteredRatio
- }
-
- override def disconnected(d: SchedulerDriver) {}
-
- override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
-
- /**
- * Method called by Mesos to offer resources on slaves. We respond by launching an executor,
- * unless we've already launched more than we wanted to.
- */
- override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
- stateLock.synchronized {
- val filters = Filters.newBuilder().setRefuseSeconds(5).build()
- for (offer <- offers.asScala) {
- val offerAttributes = toAttributeMap(offer.getAttributesList)
- val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes)
- val slaveId = offer.getSlaveId.getValue
- val mem = getResource(offer.getResourcesList, "mem")
- val cpus = getResource(offer.getResourcesList, "cpus").toInt
- val id = offer.getId.getValue
- if (taskIdToSlaveId.size < executorLimit &&
- totalCoresAcquired < maxCores &&
- meetsConstraints &&
- mem >= calculateTotalMemory(sc) &&
- cpus >= 1 &&
- failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES &&
- !slaveIdsWithExecutors.contains(slaveId)) {
- // Launch an executor on the slave
- val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired)
- totalCoresAcquired += cpusToUse
- val taskId = newMesosTaskId()
- taskIdToSlaveId.put(taskId, slaveId)
- slaveIdsWithExecutors += slaveId
- coresByTaskId(taskId) = cpusToUse
- // Gather cpu resources from the available resources and use them in the task.
- val (remainingResources, cpuResourcesToUse) =
- partitionResources(offer.getResourcesList, "cpus", cpusToUse)
- val (_, memResourcesToUse) =
- partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc))
- val taskBuilder = MesosTaskInfo.newBuilder()
- .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build())
- .setSlaveId(offer.getSlaveId)
- .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId))
- .setName("Task " + taskId)
- .addAllResources(cpuResourcesToUse.asJava)
- .addAllResources(memResourcesToUse.asJava)
-
- sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image =>
- MesosSchedulerBackendUtil
- .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder())
- }
-
- // accept the offer and launch the task
- logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
- slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname
- d.launchTasks(
- Collections.singleton(offer.getId),
- Collections.singleton(taskBuilder.build()), filters)
- } else {
- // Decline the offer
- logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
- d.declineOffer(offer.getId)
- }
- }
- }
- }
-
-
- override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
- val taskId = status.getTaskId.getValue.toInt
- val state = status.getState
- logInfo(s"Mesos task $taskId is now $state")
- val slaveId: String = status.getSlaveId.getValue
- stateLock.synchronized {
- // If the shuffle service is enabled, have the driver register with each one of the
- // shuffle services. This allows the shuffle services to clean up state associated with
- // this application when the driver exits. There is currently not a great way to detect
- // this through Mesos, since the shuffle services are set up independently.
- if (TaskState.fromMesos(state).equals(TaskState.RUNNING) &&
- slaveIdToHost.contains(slaveId) &&
- shuffleServiceEnabled) {
- assume(mesosExternalShuffleClient.isDefined,
- "External shuffle client was not instantiated even though shuffle service is enabled.")
- // TODO: Remove this and allow the MesosExternalShuffleService to detect
- // framework termination when new Mesos Framework HTTP API is available.
- val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337)
- val hostname = slaveIdToHost.remove(slaveId).get
- logDebug(s"Connecting to shuffle service on slave $slaveId, " +
- s"host $hostname, port $externalShufflePort for app ${conf.getAppId}")
- mesosExternalShuffleClient.get
- .registerDriverWithShuffleService(hostname, externalShufflePort)
- }
-
- if (TaskState.isFinished(TaskState.fromMesos(state))) {
- val slaveId = taskIdToSlaveId.get(taskId)
- slaveIdsWithExecutors -= slaveId
- taskIdToSlaveId.remove(taskId)
- // Remove the cores we have remembered for this task, if it's in the hashmap
- for (cores <- coresByTaskId.get(taskId)) {
- totalCoresAcquired -= cores
- coresByTaskId -= taskId
- }
- // If it was a failure, mark the slave as failed for blacklisting purposes
- if (TaskState.isFailed(TaskState.fromMesos(state))) {
- failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1
- if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) {
- logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " +
- "is Spark installed on it?")
- }
- }
- executorTerminated(d, slaveId, s"Executor finished with state $state")
- // In case we'd rejected everything before but have now lost a node
- d.reviveOffers()
- }
- }
- }
-
- override def error(d: SchedulerDriver, message: String) {
- logError(s"Mesos error: $message")
- scheduler.error(message)
- }
-
- override def stop() {
- super.stop()
- if (mesosDriver != null) {
- mesosDriver.stop()
- }
- }
-
- override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
-
- /**
- * Called when a slave is lost or a Mesos task finishes. Updates the local view of
- * what tasks are running and removes the terminated slave from the list of pending
- * slave IDs that we might have asked to be killed. It also notifies the driver
- * that an executor was removed.
- */
- private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = {
- stateLock.synchronized {
- if (slaveIdsWithExecutors.contains(slaveId)) {
- val slaveIdToTaskId = taskIdToSlaveId.inverse()
- if (slaveIdToTaskId.containsKey(slaveId)) {
- val taskId: Int = slaveIdToTaskId.get(slaveId)
- taskIdToSlaveId.remove(taskId)
- removeExecutor(sparkExecutorId(slaveId, taskId.toString), SlaveLost(reason))
- }
- // TODO: This assumes one Spark executor per Mesos slave,
- // which may no longer be true after SPARK-5095
- pendingRemovedSlaveIds -= slaveId
- slaveIdsWithExecutors -= slaveId
- }
- }
- }
-
- private def sparkExecutorId(slaveId: String, taskId: String): String = {
- s"$slaveId/$taskId"
- }
-
- override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = {
- logInfo(s"Mesos slave lost: ${slaveId.getValue}")
- executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue)
- }
-
- override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = {
- logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue))
- slaveLost(d, s)
- }
-
- override def applicationId(): String =
- Option(appId).getOrElse {
- logWarning("Application ID is not initialized yet.")
- super.applicationId
- }
-
- override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
- // We don't truly know if we can fulfill the full amount of executors
- // since at coarse grain it depends on the amount of slaves available.
- logInfo("Capping the total amount of executors to " + requestedTotal)
- executorLimitOption = Some(requestedTotal)
- true
- }
-
- override def doKillExecutors(executorIds: Seq[String]): Boolean = {
- if (mesosDriver == null) {
- logWarning("Asked to kill executors before the Mesos driver was started.")
- return false
- }
-
- val slaveIdToTaskId = taskIdToSlaveId.inverse()
- for (executorId <- executorIds) {
- val slaveId = executorId.split("/")(0)
- if (slaveIdToTaskId.containsKey(slaveId)) {
- mesosDriver.killTask(
- TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build())
- pendingRemovedSlaveIds += slaveId
- } else {
- logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler")
- }
- }
- // no need to adjust `executorLimitOption` since the AllocationManager already communicated
- // the desired limit through a call to `doRequestTotalExecutors`.
- // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]]
- true
- }
-}
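resourceOffers() in the deleted file accepts an offer only when every one of several conditions holds: the executor limit, the core budget, the attribute constraints, enough memory, at least one CPU, a low failure count for the slave, and no executor already running there. The same check written as a pure predicate over a simplified offer, which keeps the conditions easy to read and test; SimpleOffer and the parameter names are made up for this sketch:

```scala
object OfferAcceptanceSketch {
  final case class SimpleOffer(slaveId: String, memMb: Double, cpus: Int)

  val MaxSlaveFailures = 2 // mirrors MAX_SLAVE_FAILURES above

  def shouldLaunch(
      offer: SimpleOffer,
      launchedTasks: Int,
      executorLimit: Int,
      totalCoresAcquired: Int,
      maxCores: Int,
      requiredMemMb: Double,
      meetsConstraints: Boolean,
      failuresBySlave: Map[String, Int],
      slavesWithExecutors: Set[String]): Boolean = {
    launchedTasks < executorLimit &&
      totalCoresAcquired < maxCores &&
      meetsConstraints &&
      offer.memMb >= requiredMemMb &&
      offer.cpus >= 1 &&
      failuresBySlave.getOrElse(offer.slaveId, 0) < MaxSlaveFailures &&
      !slavesWithExecutors.contains(offer.slaveId)
  }

  def main(args: Array[String]): Unit = {
    val offer = SimpleOffer(slaveId = "s1", memMb = 4096, cpus = 8)
    val ok = shouldLaunch(offer, launchedTasks = 0, executorLimit = 10,
      totalCoresAcquired = 0, maxCores = 16, requiredMemMb = 2048,
      meetsConstraints = true, failuresBySlave = Map.empty, slavesWithExecutors = Set.empty)
    println(s"launch on ${offer.slaveId}? $ok")
  }
}
```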
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
deleted file mode 100644
index a6d9374eb9e8..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ /dev/null
@@ -1,676 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster.mesos
-
-import java.io.File
-import java.util.concurrent.locks.ReentrantLock
-import java.util.{Collections, Date, List => JList}
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.mesos.Protos.Environment.Variable
-import org.apache.mesos.Protos.TaskStatus.Reason
-import org.apache.mesos.Protos.{TaskState => MesosTaskState, _}
-import org.apache.mesos.{Scheduler, SchedulerDriver}
-import org.apache.spark.deploy.mesos.MesosDriverDescription
-import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse}
-import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.util.Utils
-import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState}
-
-
-/**
- * Tracks the current state of a Mesos Task that runs a Spark driver.
- * @param driverDescription Submitted driver description from
- * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]]
- * @param taskId Mesos TaskID generated for the task
- * @param slaveId Slave ID that the task is assigned to
- * @param mesosTaskStatus The last known task status update.
- * @param startDate The date the task was launched
- */
-private[spark] class MesosClusterSubmissionState(
- val driverDescription: MesosDriverDescription,
- val taskId: TaskID,
- val slaveId: SlaveID,
- var mesosTaskStatus: Option[TaskStatus],
- var startDate: Date,
- var finishDate: Option[Date])
- extends Serializable {
-
- def copy(): MesosClusterSubmissionState = {
- new MesosClusterSubmissionState(
- driverDescription, taskId, slaveId, mesosTaskStatus, startDate, finishDate)
- }
-}
-
-/**
- * Tracks the retry state of a driver, which includes the next time it should be scheduled
- * and necessary information to do exponential backoff.
- * This class is not thread-safe, and we expect the caller to handle synchronizing state.
- * @param lastFailureStatus Last Task status when it failed.
- * @param retries Number of times it has been retried.
- * @param nextRetry Time at which it should be retried next
- * @param waitTime The amount of time driver is scheduled to wait until next retry.
- */
-private[spark] class MesosClusterRetryState(
- val lastFailureStatus: TaskStatus,
- val retries: Int,
- val nextRetry: Date,
- val waitTime: Int) extends Serializable {
- def copy(): MesosClusterRetryState =
- new MesosClusterRetryState(lastFailureStatus, retries, nextRetry, waitTime)
-}
-
-/**
- * The full state of the cluster scheduler, currently being used for displaying
- * information on the UI.
- * @param frameworkId Mesos Framework id for the cluster scheduler.
- * @param masterUrl The Mesos master url
- * @param queuedDrivers All drivers queued to be launched
- * @param launchedDrivers All launched or running drivers
- * @param finishedDrivers All terminated drivers
- * @param pendingRetryDrivers All drivers pending to be retried
- */
-private[spark] class MesosClusterSchedulerState(
- val frameworkId: String,
- val masterUrl: Option[String],
- val queuedDrivers: Iterable[MesosDriverDescription],
- val launchedDrivers: Iterable[MesosClusterSubmissionState],
- val finishedDrivers: Iterable[MesosClusterSubmissionState],
- val pendingRetryDrivers: Iterable[MesosDriverDescription])
-
-/**
- * The full state of a Mesos driver, that is being used to display driver information on the UI.
- */
-private[spark] class MesosDriverState(
- val state: String,
- val description: MesosDriverDescription,
- val submissionState: Option[MesosClusterSubmissionState] = None)
-
-/**
- * A Mesos scheduler that is responsible for launching submitted Spark drivers in cluster mode
- * as Mesos tasks in a Mesos cluster.
- * All drivers are launched asynchronously by the framework and will eventually be run by one
- * of the slaves in the cluster. The results of each driver are stored in the slave's task
- * sandbox, which is accessible by visiting the Mesos UI.
- * This scheduler supports recovery by persisting all of its state and performs task
- * reconciliation on recovery, which fetches the latest state of all drivers from the Mesos master.
- */
-private[spark] class MesosClusterScheduler(
- engineFactory: MesosClusterPersistenceEngineFactory,
- conf: SparkConf)
- extends Scheduler with MesosSchedulerUtils {
- var frameworkUrl: String = _
- private val metricsSystem =
- MetricsSystem.createMetricsSystem("mesos_cluster", conf, new SecurityManager(conf))
- private val master = conf.get("spark.master")
- private val appName = conf.get("spark.app.name")
- private val queuedCapacity = conf.getInt("spark.mesos.maxDrivers", 200)
- private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200)
- private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute
- private val schedulerState = engineFactory.createEngine("scheduler")
- private val stateLock = new ReentrantLock()
- private val finishedDrivers =
- new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers)
- private var frameworkId: String = null
- // Holds all the launched drivers and current launch state, keyed by driver id.
- private val launchedDrivers = new mutable.HashMap[String, MesosClusterSubmissionState]()
- // Holds a map of driver id to expected slave id that is passed to Mesos for reconciliation.
- // All drivers that are loaded after failover are added here, as we need to get the latest
- // state of the tasks from Mesos.
- private val pendingRecover = new mutable.HashMap[String, SlaveID]()
- // Stores all the submitted drivers that haven't been launched.
- private val queuedDrivers = new ArrayBuffer[MesosDriverDescription]()
- // All supervised drivers that are waiting to retry after termination.
- private val pendingRetryDrivers = new ArrayBuffer[MesosDriverDescription]()
- private val queuedDriversState = engineFactory.createEngine("driverQueue")
- private val launchedDriversState = engineFactory.createEngine("launchedDrivers")
- private val pendingRetryDriversState = engineFactory.createEngine("retryList")
- // Flag that marks whether the scheduler is ready to take requests, which is not the case
- // until the scheduler is registered with the Mesos master.
- @volatile protected var ready = false
- private var masterInfo: Option[MasterInfo] = None
-
- def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = {
- val c = new CreateSubmissionResponse
- if (!ready) {
- c.success = false
- c.message = "Scheduler is not ready to take requests"
- return c
- }
-
- stateLock.synchronized {
- if (isQueueFull()) {
- c.success = false
- c.message = "Already reached maximum submission size"
- return c
- }
- c.submissionId = desc.submissionId
- queuedDriversState.persist(desc.submissionId, desc)
- queuedDrivers += desc
- c.success = true
- }
- c
- }
-
- def killDriver(submissionId: String): KillSubmissionResponse = {
- val k = new KillSubmissionResponse
- if (!ready) {
- k.success = false
- k.message = "Scheduler is not ready to take requests"
- return k
- }
- k.submissionId = submissionId
- stateLock.synchronized {
- // We look for the requested driver in the following places:
- // 1. Check if submission is running or launched.
- // 2. Check if it's still queued.
- // 3. Check if it's in the retry list.
- // 4. Check if it has already completed.
- if (launchedDrivers.contains(submissionId)) {
- val task = launchedDrivers(submissionId)
- mesosDriver.killTask(task.taskId)
- k.success = true
- k.message = "Killing running driver"
- } else if (removeFromQueuedDrivers(submissionId)) {
- k.success = true
- k.message = "Removed driver while it's still pending"
- } else if (removeFromPendingRetryDrivers(submissionId)) {
- k.success = true
- k.message = "Removed driver while it's being retried"
- } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) {
- k.success = false
- k.message = "Driver already terminated"
- } else {
- k.success = false
- k.message = "Cannot find driver"
- }
- }
- k
- }
-
- def getDriverStatus(submissionId: String): SubmissionStatusResponse = {
- val s = new SubmissionStatusResponse
- if (!ready) {
- s.success = false
- s.message = "Scheduler is not ready to take requests"
- return s
- }
- s.submissionId = submissionId
- stateLock.synchronized {
- if (queuedDrivers.exists(_.submissionId.equals(submissionId))) {
- s.success = true
- s.driverState = "QUEUED"
- } else if (launchedDrivers.contains(submissionId)) {
- s.success = true
- s.driverState = "RUNNING"
- launchedDrivers(submissionId).mesosTaskStatus.foreach(state => s.message = state.toString)
- } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) {
- s.success = true
- s.driverState = "FINISHED"
- finishedDrivers
- .find(d => d.driverDescription.submissionId.equals(submissionId)).get.mesosTaskStatus
- .foreach(state => s.message = state.toString)
- } else if (pendingRetryDrivers.exists(_.submissionId.equals(submissionId))) {
- val status = pendingRetryDrivers.find(_.submissionId.equals(submissionId))
- .get.retryState.get.lastFailureStatus
- s.success = true
- s.driverState = "RETRYING"
- s.message = status.toString
- } else {
- s.success = false
- s.driverState = "NOT_FOUND"
- }
- }
- s
- }
-
- /**
- * Gets the driver state to be displayed on the Web UI.
- */
- def getDriverState(submissionId: String): Option[MesosDriverState] = {
- stateLock.synchronized {
- queuedDrivers.find(_.submissionId.equals(submissionId))
- .map(d => new MesosDriverState("QUEUED", d))
- .orElse(launchedDrivers.get(submissionId)
- .map(d => new MesosDriverState("RUNNING", d.driverDescription, Some(d))))
- .orElse(finishedDrivers.find(_.driverDescription.submissionId.equals(submissionId))
- .map(d => new MesosDriverState("FINISHED", d.driverDescription, Some(d))))
- .orElse(pendingRetryDrivers.find(_.submissionId.equals(submissionId))
- .map(d => new MesosDriverState("RETRYING", d)))
- }
- }
-
- private def isQueueFull(): Boolean = launchedDrivers.size >= queuedCapacity
-
- /**
- * Recover scheduler state that is persisted.
- * We still need to do task reconciliation to be up to date with the latest task states,
- * as they might have changed while the scheduler was failing over.
- */
- private def recoverState(): Unit = {
- stateLock.synchronized {
- launchedDriversState.fetchAll[MesosClusterSubmissionState]().foreach { state =>
- launchedDrivers(state.taskId.getValue) = state
- pendingRecover(state.taskId.getValue) = state.slaveId
- }
- queuedDriversState.fetchAll[MesosDriverDescription]().foreach(d => queuedDrivers += d)
- // There is a potential timing issue where a queued driver might have been launched,
- // but the scheduler shut down before the queued driver could be removed
- // from the queue. We mitigate this by walking through all queued drivers
- // and removing those that have already been launched.
- queuedDrivers
- .filter(d => launchedDrivers.contains(d.submissionId))
- .foreach(d => removeFromQueuedDrivers(d.submissionId))
- pendingRetryDriversState.fetchAll[MesosDriverDescription]()
- .foreach(s => pendingRetryDrivers += s)
- // TODO: Consider storing finished drivers so we can show them on the UI after
- // failover. For now we clear the history on each recovery.
- finishedDrivers.clear()
- }
- }
-
- /**
- * Starts the cluster scheduler and waits until the scheduler is registered.
- * This also marks the scheduler as ready for requests.
- */
- def start(): Unit = {
- // TODO: Implement leader election to make sure only one framework is running in the cluster.
- val fwId = schedulerState.fetch[String]("frameworkId")
- fwId.foreach { id =>
- frameworkId = id
- }
- recoverState()
- metricsSystem.registerSource(new MesosClusterSchedulerSource(this))
- metricsSystem.start()
- val driver = createSchedulerDriver(
- master,
- MesosClusterScheduler.this,
- Utils.getCurrentUserName(),
- appName,
- conf,
- Some(frameworkUrl),
- Some(true),
- Some(Integer.MAX_VALUE),
- fwId)
-
- startScheduler(driver)
- ready = true
- }
-
- def stop(): Unit = {
- ready = false
- metricsSystem.report()
- metricsSystem.stop()
- mesosDriver.stop(true)
- }
-
- override def registered(
- driver: SchedulerDriver,
- newFrameworkId: FrameworkID,
- masterInfo: MasterInfo): Unit = {
- logInfo("Registered as framework ID " + newFrameworkId.getValue)
- if (newFrameworkId.getValue != frameworkId) {
- frameworkId = newFrameworkId.getValue
- schedulerState.persist("frameworkId", frameworkId)
- }
- markRegistered()
-
- stateLock.synchronized {
- this.masterInfo = Some(masterInfo)
- if (!pendingRecover.isEmpty) {
- // Start task reconciliation if we need to recover.
- val statuses = pendingRecover.collect {
- case (taskId, slaveId) =>
- val newStatus = TaskStatus.newBuilder()
- .setTaskId(TaskID.newBuilder().setValue(taskId).build())
- .setSlaveId(slaveId)
- .setState(MesosTaskState.TASK_STAGING)
- .build()
- launchedDrivers.get(taskId).map(_.mesosTaskStatus.getOrElse(newStatus))
- .getOrElse(newStatus)
- }
- // TODO: Page the status updates to avoid trying to reconcile
- // a large amount of tasks at once.
- driver.reconcileTasks(statuses.toSeq.asJava)
- }
- }
- }
-
- private def buildDriverCommand(desc: MesosDriverDescription): CommandInfo = {
- val appJar = CommandInfo.URI.newBuilder()
- .setValue(desc.jarUrl.stripPrefix("file:").stripPrefix("local:")).build()
- val builder = CommandInfo.newBuilder().addUris(appJar)
- val entries =
- (conf.getOption("spark.executor.extraLibraryPath").toList ++
- desc.command.libraryPathEntries)
- val prefixEnv = if (!entries.isEmpty) {
- Utils.libraryPathEnvPrefix(entries)
- } else {
- ""
- }
- val envBuilder = Environment.newBuilder()
- desc.command.environment.foreach { case (k, v) =>
- envBuilder.addVariables(Variable.newBuilder().setName(k).setValue(v).build())
- }
- // Pass all spark properties to executor.
- val executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ")
- envBuilder.addVariables(
- Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts))
- val dockerDefined = desc.schedulerProperties.contains("spark.mesos.executor.docker.image")
- val executorUri = desc.schedulerProperties.get("spark.executor.uri")
- .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI"))
- // Gets the path to run spark-submit, and the path to the Mesos sandbox.
- val (executable, sandboxPath) = if (dockerDefined) {
- // Application jar is automatically downloaded in the mounted sandbox by Mesos,
- // and the path to the mounted volume is stored in $MESOS_SANDBOX env variable.
- ("./bin/spark-submit", "$MESOS_SANDBOX")
- } else if (executorUri.isDefined) {
- builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build())
- val folderBasename = executorUri.get.split('/').last.split('.').head
- val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit"
- // Sandbox path points to the parent folder as we chdir into the folderBasename.
- (cmdExecutable, "..")
- } else {
- val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home")
- .orElse(conf.getOption("spark.home"))
- .orElse(Option(System.getenv("SPARK_HOME")))
- .getOrElse {
- throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
- }
- val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath
- // Sandbox points to the current directory by default with Mesos.
- (cmdExecutable, ".")
- }
- val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString()
- val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ")
- val appArguments = desc.command.arguments.mkString(" ")
- builder.setValue(s"$executable $cmdOptions $primaryResource $appArguments")
- builder.setEnvironment(envBuilder.build())
- conf.getOption("spark.mesos.uris").map { uris =>
- setupUris(uris, builder)
- }
- desc.schedulerProperties.get("spark.mesos.uris").map { uris =>
- setupUris(uris, builder)
- }
- desc.schedulerProperties.get("spark.submit.pyFiles").map { pyFiles =>
- setupUris(pyFiles, builder)
- }
- builder.build()
- }
-
- private def generateCmdOption(desc: MesosDriverDescription, sandboxPath: String): Seq[String] = {
- var options = Seq(
- "--name", desc.schedulerProperties("spark.app.name"),
- "--master", s"mesos://${conf.get("spark.master")}",
- "--driver-cores", desc.cores.toString,
- "--driver-memory", s"${desc.mem}M")
-
- // Assume empty main class means we're running python
- if (!desc.command.mainClass.equals("")) {
- options ++= Seq("--class", desc.command.mainClass)
- }
-
- desc.schedulerProperties.get("spark.executor.memory").map { v =>
- options ++= Seq("--executor-memory", v)
- }
- desc.schedulerProperties.get("spark.cores.max").map { v =>
- options ++= Seq("--total-executor-cores", v)
- }
- desc.schedulerProperties.get("spark.submit.pyFiles").map { pyFiles =>
- val formattedFiles = pyFiles.split(",")
- .map { path => new File(sandboxPath, path.split("/").last).toString() }
- .mkString(",")
- options ++= Seq("--py-files", formattedFiles)
- }
- options
- }
-
- private class ResourceOffer(val offer: Offer, var cpu: Double, var mem: Double) {
- override def toString(): String = {
- s"Offer id: ${offer.getId.getValue}, cpu: $cpu, mem: $mem"
- }
- }
-
- /**
-   * This method takes all the possible candidates and attempts to schedule them with Mesos offers.
-   * Every time a new task is scheduled, the afterLaunchCallback is called to perform post-scheduling
-   * logic on each task.
- */
- private def scheduleTasks(
- candidates: Seq[MesosDriverDescription],
- afterLaunchCallback: (String) => Boolean,
- currentOffers: List[ResourceOffer],
- tasks: mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]): Unit = {
- for (submission <- candidates) {
- val driverCpu = submission.cores
- val driverMem = submission.mem
- logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem")
- val offerOption = currentOffers.find { o =>
- o.cpu >= driverCpu && o.mem >= driverMem
- }
- if (offerOption.isEmpty) {
- logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " +
- s"cpu: $driverCpu, mem: $driverMem")
- } else {
- val offer = offerOption.get
- offer.cpu -= driverCpu
- offer.mem -= driverMem
- val taskId = TaskID.newBuilder().setValue(submission.submissionId).build()
- val cpuResource = createResource("cpus", driverCpu)
- val memResource = createResource("mem", driverMem)
- val commandInfo = buildDriverCommand(submission)
- val appName = submission.schedulerProperties("spark.app.name")
- val taskInfo = TaskInfo.newBuilder()
- .setTaskId(taskId)
- .setName(s"Driver for $appName")
- .setSlaveId(offer.offer.getSlaveId)
- .setCommand(commandInfo)
- .addResources(cpuResource)
- .addResources(memResource)
- submission.schedulerProperties.get("spark.mesos.executor.docker.image").foreach { image =>
- val container = taskInfo.getContainerBuilder()
- val volumes = submission.schedulerProperties
- .get("spark.mesos.executor.docker.volumes")
- .map(MesosSchedulerBackendUtil.parseVolumesSpec)
- val portmaps = submission.schedulerProperties
- .get("spark.mesos.executor.docker.portmaps")
- .map(MesosSchedulerBackendUtil.parsePortMappingsSpec)
- MesosSchedulerBackendUtil.addDockerInfo(
- container, image, volumes = volumes, portmaps = portmaps)
- taskInfo.setContainer(container.build())
- }
- val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo])
- queuedTasks += taskInfo.build()
- logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " +
- submission.submissionId)
- val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId,
- None, new Date(), None)
- launchedDrivers(submission.submissionId) = newState
- launchedDriversState.persist(submission.submissionId, newState)
- afterLaunchCallback(submission.submissionId)
- }
- }
- }
-
- override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = {
- val currentOffers = offers.asScala.map(o =>
- new ResourceOffer(
- o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem"))
- ).toList
- logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}")
- val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]()
- val currentTime = new Date()
-
- stateLock.synchronized {
- // We first schedule all the supervised drivers that are ready to retry.
- // This list will be empty if none of the drivers are marked as supervise.
- val driversToRetry = pendingRetryDrivers.filter { d =>
- d.retryState.get.nextRetry.before(currentTime)
- }
-
- scheduleTasks(
- copyBuffer(driversToRetry),
- removeFromPendingRetryDrivers,
- currentOffers,
- tasks)
-
- // Then we walk through the queued drivers and try to schedule them.
- scheduleTasks(
- copyBuffer(queuedDrivers),
- removeFromQueuedDrivers,
- currentOffers,
- tasks)
- }
- tasks.foreach { case (offerId, taskInfos) =>
- driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava)
- }
- offers.asScala
- .filter(o => !tasks.keySet.contains(o.getId))
- .foreach(o => driver.declineOffer(o.getId))
- }
-
- private def copyBuffer(
- buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = {
- val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size)
- buffer.copyToBuffer(newBuffer)
- newBuffer
- }
-
- def getSchedulerState(): MesosClusterSchedulerState = {
- stateLock.synchronized {
- new MesosClusterSchedulerState(
- frameworkId,
- masterInfo.map(m => s"http://${m.getIp}:${m.getPort}"),
- copyBuffer(queuedDrivers),
- launchedDrivers.values.map(_.copy()).toList,
- finishedDrivers.map(_.copy()).toList,
- copyBuffer(pendingRetryDrivers))
- }
- }
-
- override def offerRescinded(driver: SchedulerDriver, offerId: OfferID): Unit = {}
- override def disconnected(driver: SchedulerDriver): Unit = {}
- override def reregistered(driver: SchedulerDriver, masterInfo: MasterInfo): Unit = {
- logInfo(s"Framework re-registered with master ${masterInfo.getId}")
- }
- override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {}
- override def error(driver: SchedulerDriver, error: String): Unit = {
- logError("Error received: " + error)
- }
-
- /**
-   * Check if the task state is a recoverable state in which we can relaunch the task.
-   * Task states like TASK_ERROR are not relaunchable since the task wasn't able
-   * to be validated by Mesos.
- */
- private def shouldRelaunch(state: MesosTaskState): Boolean = {
- state == MesosTaskState.TASK_FAILED ||
- state == MesosTaskState.TASK_KILLED ||
- state == MesosTaskState.TASK_LOST
- }
-
- override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = {
- val taskId = status.getTaskId.getValue
- stateLock.synchronized {
- if (launchedDrivers.contains(taskId)) {
- if (status.getReason == Reason.REASON_RECONCILIATION &&
- !pendingRecover.contains(taskId)) {
- // Task has already received update and no longer requires reconciliation.
- return
- }
- val state = launchedDrivers(taskId)
- // Check if the driver is supervise enabled and can be relaunched.
- if (state.driverDescription.supervise && shouldRelaunch(status.getState)) {
- removeFromLaunchedDrivers(taskId)
- state.finishDate = Some(new Date())
- val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState
- val (retries, waitTimeSec) = retryState
- .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) }
- .getOrElse{ (1, 1) }
- val nextRetry = new Date(new Date().getTime + waitTimeSec * 1000L)
-
- val newDriverDescription = state.driverDescription.copy(
- retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec)))
- pendingRetryDrivers += newDriverDescription
- pendingRetryDriversState.persist(taskId, newDriverDescription)
- } else if (TaskState.isFinished(TaskState.fromMesos(status.getState))) {
- removeFromLaunchedDrivers(taskId)
- state.finishDate = Some(new Date())
- if (finishedDrivers.size >= retainedDrivers) {
- val toRemove = math.max(retainedDrivers / 10, 1)
- finishedDrivers.trimStart(toRemove)
- }
- finishedDrivers += state
- }
- state.mesosTaskStatus = Option(status)
- } else {
- logError(s"Unable to find driver $taskId in status update")
- }
- }
- }
-
- override def frameworkMessage(
- driver: SchedulerDriver,
- executorId: ExecutorID,
- slaveId: SlaveID,
- message: Array[Byte]): Unit = {}
-
- override def executorLost(
- driver: SchedulerDriver,
- executorId: ExecutorID,
- slaveId: SlaveID,
- status: Int): Unit = {}
-
- private def removeFromQueuedDrivers(id: String): Boolean = {
- val index = queuedDrivers.indexWhere(_.submissionId.equals(id))
- if (index != -1) {
- queuedDrivers.remove(index)
- queuedDriversState.expunge(id)
- true
- } else {
- false
- }
- }
-
- private def removeFromLaunchedDrivers(id: String): Boolean = {
- if (launchedDrivers.remove(id).isDefined) {
- launchedDriversState.expunge(id)
- true
- } else {
- false
- }
- }
-
- private def removeFromPendingRetryDrivers(id: String): Boolean = {
- val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(id))
- if (index != -1) {
- pendingRetryDrivers.remove(index)
- pendingRetryDriversState.expunge(id)
- true
- } else {
- false
- }
- }
-
- def getQueuedDriversSize: Int = queuedDrivers.size
- def getLaunchedDriversSize: Int = launchedDrivers.size
- def getPendingRetryDriversSize: Int = pendingRetryDrivers.size
-}
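
A note on the supervise/retry path deleted above: when a relaunchable status arrives for a supervised driver, the next wait time is double the previous one, capped at maxRetryWaitTime, and the retry is scheduled that many seconds in the future. The following is a minimal standalone sketch of just that backoff rule; the object and method names are illustrative, not Spark's.

import java.util.Date

// Standalone sketch of the retry backoff used for supervised drivers above.
// Only the arithmetic mirrors the statusUpdate handler in the deleted code;
// RetryBackoffSketch and nextRetry are assumed names for this example.
object RetryBackoffSketch {
  val maxRetryWaitTime = 60 // seconds; an assumed cap for this sketch

  // Given the previous (retries, waitTimeSec), compute the next attempt:
  // the wait doubles each time, capped at maxRetryWaitTime; the first retry waits 1s.
  def nextRetry(prev: Option[(Int, Int)]): (Int, Int, Date) = {
    val (retries, waitTimeSec) = prev
      .map { case (r, w) => (r + 1, math.min(maxRetryWaitTime, w * 2)) }
      .getOrElse((1, 1))
    val at = new Date(new Date().getTime + waitTimeSec * 1000L)
    (retries, waitTimeSec, at)
  }

  def main(args: Array[String]): Unit = {
    println(nextRetry(None))          // first failure: retry #1 after 1s
    println(nextRetry(Some((1, 1))))  // retry #2 after 2s
    println(nextRetry(Some((5, 16)))) // retry #6 after 32s
  }
}
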
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
deleted file mode 100644
index aaffac604a88..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ /dev/null
@@ -1,416 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster.mesos
-
-import java.io.File
-import java.util.{ArrayList => JArrayList, Collections, List => JList}
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable.{HashMap, HashSet}
-
-import org.apache.mesos.{Scheduler => MScheduler, _}
-import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _}
-import org.apache.mesos.protobuf.ByteString
-import org.apache.spark.{SparkContext, SparkException, TaskState}
-import org.apache.spark.executor.MesosExecutorBackend
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.ExecutorInfo
-import org.apache.spark.util.Utils
-
-/**
- * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a
- * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks
- * from multiple apps can run on different cores) and in time (a core can switch ownership).
- */
-private[spark] class MesosSchedulerBackend(
- scheduler: TaskSchedulerImpl,
- sc: SparkContext,
- master: String)
- extends SchedulerBackend
- with MScheduler
- with MesosSchedulerUtils {
-
-  // Stores the slave ids that have launched a Mesos executor.
- val slaveIdToExecutorInfo = new HashMap[String, MesosExecutorInfo]
- val taskIdToSlaveId = new HashMap[Long, String]
-
- // An ExecutorInfo for our tasks
- var execArgs: Array[Byte] = null
-
- var classLoader: ClassLoader = null
-
- // The listener bus to publish executor added/removed events.
- val listenerBus = sc.listenerBus
-
- private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1)
-
- // Offer constraints
- private[this] val slaveOfferConstraints =
- parseConstraintString(sc.conf.get("spark.mesos.constraints", ""))
-
- @volatile var appId: String = _
-
- override def start() {
- classLoader = Thread.currentThread.getContextClassLoader
- val driver = createSchedulerDriver(
- master,
- MesosSchedulerBackend.this,
- sc.sparkUser,
- sc.appName,
- sc.conf,
- sc.ui.map(_.appUIAddress))
- startScheduler(driver)
- }
-
- /**
- * Creates a MesosExecutorInfo that is used to launch a Mesos executor.
-   * @param availableResources Available resources that are offered by Mesos
- * @param execId The executor id to assign to this new executor.
- * @return A tuple of the new mesos executor info and the remaining available resources.
- */
- def createExecutorInfo(
- availableResources: JList[Resource],
- execId: String): (MesosExecutorInfo, JList[Resource]) = {
- val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home")
- .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility
- .getOrElse {
- throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
- }
- val environment = Environment.newBuilder()
- sc.conf.getOption("spark.executor.extraClassPath").foreach { cp =>
- environment.addVariables(
- Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
- }
- val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("")
-
- val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p =>
- Utils.libraryPathEnvPrefix(Seq(p))
- }.getOrElse("")
-
- environment.addVariables(
- Environment.Variable.newBuilder()
- .setName("SPARK_EXECUTOR_OPTS")
- .setValue(extraJavaOpts)
- .build())
- sc.executorEnvs.foreach { case (key, value) =>
- environment.addVariables(Environment.Variable.newBuilder()
- .setName(key)
- .setValue(value)
- .build())
- }
- val command = CommandInfo.newBuilder()
- .setEnvironment(environment)
- val uri = sc.conf.getOption("spark.executor.uri")
- .orElse(Option(System.getenv("SPARK_EXECUTOR_URI")))
-
- val executorBackendName = classOf[MesosExecutorBackend].getName
- if (uri.isEmpty) {
- val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath
- command.setValue(s"$prefixEnv $executorPath $executorBackendName")
- } else {
- // Grab everything to the first '.'. We'll use that and '*' to
- // glob the directory "correctly".
- val basename = uri.get.split('/').last.split('.').head
- command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName")
- command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get))
- }
- val builder = MesosExecutorInfo.newBuilder()
- val (resourcesAfterCpu, usedCpuResources) =
- partitionResources(availableResources, "cpus", mesosExecutorCores)
- val (resourcesAfterMem, usedMemResources) =
- partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc))
-
- builder.addAllResources(usedCpuResources.asJava)
- builder.addAllResources(usedMemResources.asJava)
-
- sc.conf.getOption("spark.mesos.uris").foreach(setupUris(_, command))
-
- val executorInfo = builder
- .setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
- .setCommand(command)
- .setData(ByteString.copyFrom(createExecArg()))
-
- sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image =>
- MesosSchedulerBackendUtil
- .setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder())
- }
-
- (executorInfo.build(), resourcesAfterMem.asJava)
- }
-
- /**
- * Create and serialize the executor argument to pass to Mesos. Our executor arg is an array
- * containing all the spark.* system properties in the form of (String, String) pairs.
- */
- private def createExecArg(): Array[Byte] = {
- if (execArgs == null) {
- val props = new HashMap[String, String]
- for ((key, value) <- sc.conf.getAll) {
- props(key) = value
- }
- // Serialize the map as an array of (String, String) pairs
- execArgs = Utils.serialize(props.toArray)
- }
- execArgs
- }
-
- override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
-
- override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
- inClassLoader() {
- appId = frameworkId.getValue
- logInfo("Registered as framework ID " + appId)
- markRegistered()
- }
- }
-
- private def inClassLoader()(fun: => Unit) = {
- val oldClassLoader = Thread.currentThread.getContextClassLoader
- Thread.currentThread.setContextClassLoader(classLoader)
- try {
- fun
- } finally {
- Thread.currentThread.setContextClassLoader(oldClassLoader)
- }
- }
-
- override def disconnected(d: SchedulerDriver) {}
-
- override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
-
- private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = {
- val builder = new StringBuilder
- tasks.asScala.foreach { t =>
- builder.append("Task id: ").append(t.getTaskId.getValue).append("\n")
- .append("Slave id: ").append(t.getSlaveId.getValue).append("\n")
- .append("Task resources: ").append(t.getResourcesList).append("\n")
- .append("Executor resources: ").append(t.getExecutor.getResourcesList)
- .append("---------------------------------------------\n")
- }
- builder.toString()
- }
-
- /**
- * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
- * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
- * tasks are balanced across the cluster.
- */
- override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
- inClassLoader() {
- // Fail-fast on offers we know will be rejected
- val (usableOffers, unUsableOffers) = offers.asScala.partition { o =>
- val mem = getResource(o.getResourcesList, "mem")
- val cpus = getResource(o.getResourcesList, "cpus")
- val slaveId = o.getSlaveId.getValue
- val offerAttributes = toAttributeMap(o.getAttributesList)
-
-        // check if all constraints are satisfied
- // 1. Attribute constraints
- // 2. Memory requirements
- // 3. CPU requirements - need at least 1 for executor, 1 for task
- val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes)
- val meetsMemoryRequirements = mem >= calculateTotalMemory(sc)
- val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)
-
- val meetsRequirements =
- (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) ||
- (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK)
-
- // add some debug messaging
- val debugstr = if (meetsRequirements) "Accepting" else "Declining"
- val id = o.getId.getValue
- logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
-
- meetsRequirements
- }
-
- // Decline offers we ruled out immediately
- unUsableOffers.foreach(o => d.declineOffer(o.getId))
-
- val workerOffers = usableOffers.map { o =>
- val cpus = if (slaveIdToExecutorInfo.contains(o.getSlaveId.getValue)) {
- getResource(o.getResourcesList, "cpus").toInt
- } else {
- // If the Mesos executor has not been started on this slave yet, set aside a few
- // cores for the Mesos executor by offering fewer cores to the Spark executor
- (getResource(o.getResourcesList, "cpus") - mesosExecutorCores).toInt
- }
- new WorkerOffer(
- o.getSlaveId.getValue,
- o.getHostname,
- cpus)
- }
-
- val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap
- val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap
- val slaveIdToResources = new HashMap[String, JList[Resource]]()
- usableOffers.foreach { o =>
- slaveIdToResources(o.getSlaveId.getValue) = o.getResourcesList
- }
-
- val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]]
-
- val slavesIdsOfAcceptedOffers = HashSet[String]()
-
- // Call into the TaskSchedulerImpl
- val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty)
- acceptedOffers
- .foreach { offer =>
- offer.foreach { taskDesc =>
- val slaveId = taskDesc.executorId
- slavesIdsOfAcceptedOffers += slaveId
- taskIdToSlaveId(taskDesc.taskId) = slaveId
- val (mesosTask, remainingResources) = createMesosTask(
- taskDesc,
- slaveIdToResources(slaveId),
- slaveId)
- mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo])
- .add(mesosTask)
- slaveIdToResources(slaveId) = remainingResources
- }
- }
-
- // Reply to the offers
- val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
-
- mesosTasks.foreach { case (slaveId, tasks) =>
- slaveIdToWorkerOffer.get(slaveId).foreach(o =>
- listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), slaveId,
- // TODO: Add support for log urls for Mesos
- new ExecutorInfo(o.host, o.cores, Map.empty)))
- )
- logTrace(s"Launching Mesos tasks on slave '$slaveId', tasks:\n${getTasksSummary(tasks)}")
- d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters)
- }
-
- // Decline offers that weren't used
- // NOTE: This logic assumes that we only get a single offer for each host in a given batch
- for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) {
- d.declineOffer(o.getId)
- }
- }
- }
-
-  /** Turn a Spark TaskDescription into a Mesos task and also return the resources unused by the task */
- def createMesosTask(
- task: TaskDescription,
- resources: JList[Resource],
- slaveId: String): (MesosTaskInfo, JList[Resource]) = {
- val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build()
- val (executorInfo, remainingResources) = if (slaveIdToExecutorInfo.contains(slaveId)) {
- (slaveIdToExecutorInfo(slaveId), resources)
- } else {
- createExecutorInfo(resources, slaveId)
- }
- slaveIdToExecutorInfo(slaveId) = executorInfo
- val (finalResources, cpuResources) =
- partitionResources(remainingResources, "cpus", scheduler.CPUS_PER_TASK)
- val taskInfo = MesosTaskInfo.newBuilder()
- .setTaskId(taskId)
- .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build())
- .setExecutor(executorInfo)
- .setName(task.name)
- .addAllResources(cpuResources.asJava)
- .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString)
- .build()
- (taskInfo, finalResources.asJava)
- }
-
- override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
- inClassLoader() {
- val tid = status.getTaskId.getValue.toLong
- val state = TaskState.fromMesos(status.getState)
- synchronized {
- if (TaskState.isFailed(TaskState.fromMesos(status.getState))
- && taskIdToSlaveId.contains(tid)) {
- // We lost the executor on this slave, so remember that it's gone
- removeExecutor(taskIdToSlaveId(tid), "Lost executor")
- }
- if (TaskState.isFinished(state)) {
- taskIdToSlaveId.remove(tid)
- }
- }
- scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer)
- }
- }
-
- override def error(d: SchedulerDriver, message: String) {
- inClassLoader() {
- logError("Mesos error: " + message)
- scheduler.error(message)
- }
- }
-
- override def stop() {
- if (mesosDriver != null) {
- mesosDriver.stop()
- }
- }
-
- override def reviveOffers() {
- mesosDriver.reviveOffers()
- }
-
- override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
-
- /**
- * Remove executor associated with slaveId in a thread safe manner.
- */
- private def removeExecutor(slaveId: String, reason: String) = {
- synchronized {
- listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason))
- slaveIdToExecutorInfo -= slaveId
- }
- }
-
- private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) {
- inClassLoader() {
- logInfo("Mesos slave lost: " + slaveId.getValue)
- removeExecutor(slaveId.getValue, reason.toString)
- scheduler.executorLost(slaveId.getValue, reason)
- }
- }
-
- override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
- recordSlaveLost(d, slaveId, SlaveLost())
- }
-
- override def executorLost(d: SchedulerDriver, executorId: ExecutorID,
- slaveId: SlaveID, status: Int) {
- logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue,
- slaveId.getValue))
- recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true))
- }
-
- override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {
- mesosDriver.killTask(
- TaskID.newBuilder()
- .setValue(taskId.toString).build()
- )
- }
-
- // TODO: query Mesos for number of cores
- override def defaultParallelism(): Int = sc.conf.getInt("spark.default.parallelism", 8)
-
- override def applicationId(): String =
- Option(appId).getOrElse {
- logWarning("Application ID is not initialized yet.")
- super.applicationId
- }
-
-}
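
For reference, the offer filter in resourceOffers above reduces to a single predicate: either the offer can host a fresh Mesos executor plus one task (constraints, memory, and cpus all satisfied), or an executor already runs on that slave and the offer still covers one task. Below is a hedged standalone sketch of that predicate; all names are assumptions for the sketch, only the arithmetic is taken from the diff.

// Illustrative predicate mirroring the fine-grained offer filter above.
object OfferFilterSketch {
  case class Offer(slaveId: String, mem: Double, cpus: Double, constraintsOk: Boolean)

  def meetsRequirements(
      offer: Offer,
      requiredMem: Double,          // calculateTotalMemory(sc) in the original
      mesosExecutorCores: Double,   // spark.mesos.mesosExecutor.cores
      cpusPerTask: Double,          // scheduler.CPUS_PER_TASK
      hasExecutorOnSlave: Boolean): Boolean = {
    val freshExecutorFits =
      offer.constraintsOk &&
        offer.mem >= requiredMem &&
        offer.cpus >= (mesosExecutorCores + cpusPerTask)
    val existingExecutorFits = hasExecutorOnSlave && offer.cpus >= cpusPerTask
    freshExecutorFits || existingExecutorFits
  }

  def main(args: Array[String]): Unit = {
    val o = Offer("slave-1", mem = 2048, cpus = 2, constraintsOk = true)
    println(meetsRequirements(o, requiredMem = 1408, mesosExecutorCores = 1,
      cpusPerTask = 1, hasExecutorOnSlave = false)) // true
  }
}
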
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala
deleted file mode 100644
index e79c543a9de2..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster.mesos
-
-import org.apache.mesos.Protos.{ContainerInfo, Volume}
-import org.apache.mesos.Protos.ContainerInfo.DockerInfo
-
-import org.apache.spark.{Logging, SparkConf}
-
-/**
- * A collection of utility functions which can be used by both the
- * MesosSchedulerBackend and the CoarseMesosSchedulerBackend.
- */
-private[mesos] object MesosSchedulerBackendUtil extends Logging {
- /**
- * Parse a comma-delimited list of volume specs, each of which
- * takes the form [host-dir:]container-dir[:rw|:ro].
- */
- def parseVolumesSpec(volumes: String): List[Volume] = {
- volumes.split(",").map(_.split(":")).flatMap { spec =>
- val vol: Volume.Builder = Volume
- .newBuilder()
- .setMode(Volume.Mode.RW)
- spec match {
- case Array(container_path) =>
- Some(vol.setContainerPath(container_path))
- case Array(container_path, "rw") =>
- Some(vol.setContainerPath(container_path))
- case Array(container_path, "ro") =>
- Some(vol.setContainerPath(container_path)
- .setMode(Volume.Mode.RO))
- case Array(host_path, container_path) =>
- Some(vol.setContainerPath(container_path)
- .setHostPath(host_path))
- case Array(host_path, container_path, "rw") =>
- Some(vol.setContainerPath(container_path)
- .setHostPath(host_path))
- case Array(host_path, container_path, "ro") =>
- Some(vol.setContainerPath(container_path)
- .setHostPath(host_path)
- .setMode(Volume.Mode.RO))
- case spec => {
- logWarning(s"Unable to parse volume specs: $volumes. "
- + "Expected form: \"[host-dir:]container-dir[:rw|:ro](, ...)\"")
- None
- }
- }
- }
- .map { _.build() }
- .toList
- }
-
- /**
- * Parse a comma-delimited list of port mapping specs, each of which
- * takes the form host_port:container_port[:udp|:tcp]
- *
- * Note:
- * the docker form is [ip:]host_port:container_port, but the DockerInfo
- * message has no field for 'ip', and instead has a 'protocol' field.
- * Docker itself only appears to support TCP, so this alternative form
- * anticipates the expansion of the docker form to allow for a protocol
- * and leaves open the chance for mesos to begin to accept an 'ip' field
- */
- def parsePortMappingsSpec(portmaps: String): List[DockerInfo.PortMapping] = {
- portmaps.split(",").map(_.split(":")).flatMap { spec: Array[String] =>
- val portmap: DockerInfo.PortMapping.Builder = DockerInfo.PortMapping
- .newBuilder()
- .setProtocol("tcp")
- spec match {
- case Array(host_port, container_port) =>
- Some(portmap.setHostPort(host_port.toInt)
- .setContainerPort(container_port.toInt))
- case Array(host_port, container_port, protocol) =>
- Some(portmap.setHostPort(host_port.toInt)
- .setContainerPort(container_port.toInt)
- .setProtocol(protocol))
- case spec => {
- logWarning(s"Unable to parse port mapping specs: $portmaps. "
- + "Expected form: \"host_port:container_port[:udp|:tcp](, ...)\"")
- None
- }
- }
- }
- .map { _.build() }
- .toList
- }
-
- /**
- * Construct a DockerInfo structure and insert it into a ContainerInfo
- */
- def addDockerInfo(
- container: ContainerInfo.Builder,
- image: String,
- volumes: Option[List[Volume]] = None,
- network: Option[ContainerInfo.DockerInfo.Network] = None,
- portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None): Unit = {
-
- val docker = ContainerInfo.DockerInfo.newBuilder().setImage(image)
-
- network.foreach(docker.setNetwork)
- portmaps.foreach(_.foreach(docker.addPortMappings))
- container.setType(ContainerInfo.Type.DOCKER)
- container.setDocker(docker.build())
- volumes.foreach(_.foreach(container.addVolumes))
- }
-
- /**
- * Setup a docker containerizer
- */
- def setupContainerBuilderDockerInfo(
- imageName: String,
- conf: SparkConf,
- builder: ContainerInfo.Builder): Unit = {
- val volumes = conf
- .getOption("spark.mesos.executor.docker.volumes")
- .map(parseVolumesSpec)
- val portmaps = conf
- .getOption("spark.mesos.executor.docker.portmaps")
- .map(parsePortMappingsSpec)
- addDockerInfo(
- builder,
- imageName,
- volumes = volumes,
- portmaps = portmaps)
- logDebug("setupContainerDockerInfo: using docker image: " + imageName)
- }
-}
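
The two parsers deleted above accept comma-delimited specs of the form [host-dir:]container-dir[:rw|:ro] and host_port:container_port[:udp|:tcp]. The sketch below re-implements the same parsing rules with plain case classes standing in for the Mesos protobuf builders; it is illustrative only, not the Spark utility itself.

// Standalone re-implementation of the spec parsing rules, for illustration.
// VolumeSpec and PortMapping are stand-ins for the Mesos protobuf types.
object DockerSpecSketch {
  case class VolumeSpec(containerPath: String, hostPath: Option[String], readOnly: Boolean)
  case class PortMapping(hostPort: Int, containerPort: Int, protocol: String)

  def parseVolumes(volumes: String): List[VolumeSpec] =
    volumes.split(",").toList.flatMap { spec =>
      spec.split(":") match {
        case Array(c)          => Some(VolumeSpec(c, None, readOnly = false))
        case Array(c, "rw")    => Some(VolumeSpec(c, None, readOnly = false))
        case Array(c, "ro")    => Some(VolumeSpec(c, None, readOnly = true))
        case Array(h, c)       => Some(VolumeSpec(c, Some(h), readOnly = false))
        case Array(h, c, "rw") => Some(VolumeSpec(c, Some(h), readOnly = false))
        case Array(h, c, "ro") => Some(VolumeSpec(c, Some(h), readOnly = true))
        case _                 => None // malformed entries are skipped, as in the original
      }
    }

  def parsePortMappings(portmaps: String): List[PortMapping] =
    portmaps.split(",").toList.flatMap { spec =>
      spec.split(":") match {
        case Array(h, c)        => Some(PortMapping(h.toInt, c.toInt, "tcp"))
        case Array(h, c, proto) => Some(PortMapping(h.toInt, c.toInt, proto))
        case _                  => None
      }
    }

  def main(args: Array[String]): Unit = {
    println(parseVolumes("/var/log:/logs:ro,/scratch"))
    // List(VolumeSpec(/logs,Some(/var/log),true), VolumeSpec(/scratch,None,false))
    println(parsePortMappings("8080:80,5005:5005:udp"))
    // List(PortMapping(8080,80,tcp), PortMapping(5005,5005,udp))
  }
}
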
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
deleted file mode 100644
index 860c8e097b3b..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ /dev/null
@@ -1,339 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster.mesos
-
-import java.util.{List => JList}
-import java.util.concurrent.CountDownLatch
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
-import scala.util.control.NonFatal
-
-import com.google.common.base.Splitter
-import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos}
-import org.apache.mesos.Protos._
-import org.apache.mesos.protobuf.{ByteString, GeneratedMessage}
-import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext}
-import org.apache.spark.util.Utils
-
-
-/**
- * Shared trait for implementing a Mesos Scheduler. This holds common state and helper
- * methods that Mesos schedulers will use.
- */
-private[mesos] trait MesosSchedulerUtils extends Logging {
- // Lock used to wait for scheduler to be registered
- private final val registerLatch = new CountDownLatch(1)
-
- // Driver for talking to Mesos
- protected var mesosDriver: SchedulerDriver = null
-
- /**
- * Creates a new MesosSchedulerDriver that communicates to the Mesos master.
- * @param masterUrl The url to connect to Mesos master
- * @param scheduler the scheduler class to receive scheduler callbacks
-   * @param sparkUser User to impersonate when running tasks
- * @param appName The framework name to display on the Mesos UI
- * @param conf Spark configuration
- * @param webuiUrl The WebUI url to link from Mesos UI
- * @param checkpoint Option to checkpoint tasks for failover
-   * @param failoverTimeout Duration the Mesos master waits for the scheduler to reconnect after a disconnect
- * @param frameworkId The id of the new framework
- */
- protected def createSchedulerDriver(
- masterUrl: String,
- scheduler: Scheduler,
- sparkUser: String,
- appName: String,
- conf: SparkConf,
- webuiUrl: Option[String] = None,
- checkpoint: Option[Boolean] = None,
- failoverTimeout: Option[Double] = None,
- frameworkId: Option[String] = None): SchedulerDriver = {
- val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName)
- val credBuilder = Credential.newBuilder()
- webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) }
- checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) }
- failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) }
- frameworkId.foreach { id =>
- fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build())
- }
- conf.getOption("spark.mesos.principal").foreach { principal =>
- fwInfoBuilder.setPrincipal(principal)
- credBuilder.setPrincipal(principal)
- }
- conf.getOption("spark.mesos.secret").foreach { secret =>
- credBuilder.setSecret(ByteString.copyFromUtf8(secret))
- }
- if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) {
- throw new SparkException(
- "spark.mesos.principal must be configured when spark.mesos.secret is set")
- }
- conf.getOption("spark.mesos.role").foreach { role =>
- fwInfoBuilder.setRole(role)
- }
- if (credBuilder.hasPrincipal) {
- new MesosSchedulerDriver(
- scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build())
- } else {
- new MesosSchedulerDriver(scheduler, fwInfoBuilder.build(), masterUrl)
- }
- }
-
- /**
-   * Starts the MesosSchedulerDriver and stores it as the currently running driver.
- * This driver is expected to not be running.
- * This method returns only after the scheduler has registered with Mesos.
- */
- def startScheduler(newDriver: SchedulerDriver): Unit = {
- synchronized {
- if (mesosDriver != null) {
- registerLatch.await()
- return
- }
-
- new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") {
- setDaemon(true)
-
- override def run() {
- mesosDriver = newDriver
- try {
- val ret = mesosDriver.run()
- logInfo("driver.run() returned with code " + ret)
- if (ret != null && ret.equals(Status.DRIVER_ABORTED)) {
- System.exit(1)
- }
- } catch {
- case e: Exception => {
- logError("driver.run() failed", e)
- System.exit(1)
- }
- }
- }
- }.start()
-
- registerLatch.await()
- }
- }
-
- /**
-   * Get the total amount of the named scalar resource from the given list of offered resources.
- */
- protected def getResource(res: JList[Resource], name: String): Double = {
- // A resource can have multiple values in the offer since it can either be from
- // a specific role or wildcard.
- res.asScala.filter(_.getName == name).map(_.getScalar.getValue).sum
- }
-
- protected def markRegistered(): Unit = {
- registerLatch.countDown()
- }
-
- def createResource(name: String, amount: Double, role: Option[String] = None): Resource = {
- val builder = Resource.newBuilder()
- .setName(name)
- .setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(amount).build())
-
- role.foreach { r => builder.setRole(r) }
-
- builder.build()
- }
-
- /**
- * Partition the existing set of resources into two groups, those remaining to be
- * scheduled and those requested to be used for a new task.
- * @param resources The full list of available resources
- * @param resourceName The name of the resource to take from the available resources
- * @param amountToUse The amount of resources to take from the available resources
- * @return The remaining resources list and the used resources list.
- */
- def partitionResources(
- resources: JList[Resource],
- resourceName: String,
- amountToUse: Double): (List[Resource], List[Resource]) = {
- var remain = amountToUse
- var requestedResources = new ArrayBuffer[Resource]
- val remainingResources = resources.asScala.map {
- case r => {
- if (remain > 0 &&
- r.getType == Value.Type.SCALAR &&
- r.getScalar.getValue > 0.0 &&
- r.getName == resourceName) {
- val usage = Math.min(remain, r.getScalar.getValue)
- requestedResources += createResource(resourceName, usage, Some(r.getRole))
- remain -= usage
- createResource(resourceName, r.getScalar.getValue - usage, Some(r.getRole))
- } else {
- r
- }
- }
- }
-
-    // Filter any resource that has been depleted.
- val filteredResources =
- remainingResources.filter(r => r.getType != Value.Type.SCALAR || r.getScalar.getValue > 0.0)
-
- (filteredResources.toList, requestedResources.toList)
- }
-
- /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */
- protected def getAttribute(attr: Attribute): (String, Set[String]) = {
- (attr.getName, attr.getText.getValue.split(',').toSet)
- }
-
-
- /** Build a Mesos resource protobuf object */
- protected def createResource(resourceName: String, quantity: Double): Protos.Resource = {
- Resource.newBuilder()
- .setName(resourceName)
- .setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(quantity).build())
- .build()
- }
-
- /**
-   * Converts the attributes from the resource offer into a Map of name -> Attribute Value.
-   * The attribute values are the Mesos attribute types: scalar, range, set, or text.
-   * @param offerAttributes the attributes from a Mesos resource offer
-   * @return a Map from attribute name to its typed value
- */
- protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = {
- offerAttributes.asScala.map(attr => {
- val attrValue = attr.getType match {
- case Value.Type.SCALAR => attr.getScalar
- case Value.Type.RANGES => attr.getRanges
- case Value.Type.SET => attr.getSet
- case Value.Type.TEXT => attr.getText
- }
- (attr.getName, attrValue)
- }).toMap
- }
-
-
- /**
- * Match the requirements (if any) to the offer attributes.
-   * If attribute requirements are not specified, return true.
-   * Else, if an attribute is required but no values are given, a simple presence check is performed.
-   * Else, if an attribute name and values are specified, a subset match is performed on the slave attributes.
- */
- def matchesAttributeRequirements(
- slaveOfferConstraints: Map[String, Set[String]],
- offerAttributes: Map[String, GeneratedMessage]): Boolean = {
- slaveOfferConstraints.forall {
- // offer has the required attribute and subsumes the required values for that attribute
- case (name, requiredValues) =>
- offerAttributes.get(name) match {
- case None => false
- case Some(_) if requiredValues.isEmpty => true // empty value matches presence
- case Some(scalarValue: Value.Scalar) =>
-            // check if any provided value is less than or equal to the offered value
- requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue)
- case Some(rangeValue: Value.Range) =>
- val offerRange = rangeValue.getBegin to rangeValue.getEnd
- // Check if there is some required value that is between the ranges specified
- // Note: We only support the ability to specify discrete values, in the future
- // we may expand it to subsume ranges specified with a XX..YY value or something
- // similar to that.
- requiredValues.map(_.toLong).exists(offerRange.contains(_))
- case Some(offeredValue: Value.Set) =>
- // check if the specified required values is a subset of offered set
- requiredValues.subsetOf(offeredValue.getItemList.asScala.toSet)
- case Some(textValue: Value.Text) =>
- // check if the specified value is equal, if multiple values are specified
- // we succeed if any of them match.
- requiredValues.contains(textValue.getValue)
- }
- }
- }
-
- /**
-   * Parses the attribute constraints provided to Spark and builds a matching data struct:
-   * Map[<attribute-name>, Set[values-to-match]]
- * The constraints are specified as ';' separated key-value pairs where keys and values
- * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for
- * multiple values (comma separated). For example:
- * {{{
- * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b")
- * // would result in
- *
- * Map(
- * "tachyon" -> Set("true"),
-   *     "zone" -> Set("us-east-1a", "us-east-1b")
- * )
- * }}}
- *
- * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/
- * https://github.com/apache/mesos/blob/master/src/common/values.cpp
- * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp
- *
-   * @param constraintsVal constraints string consisting of ';' separated key-value pairs (separated
-   *                       by ':')
-   * @return Map of constraints to match resource offers.
- */
- def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = {
- /*
- Based on mesos docs:
- attributes : attribute ( ";" attribute )*
- attribute : labelString ":" ( labelString | "," )+
- labelString : [a-zA-Z0-9_/.-]
- */
- val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':')
- // kv splitter
- if (constraintsVal.isEmpty) {
- Map()
- } else {
- try {
- splitter.split(constraintsVal).asScala.toMap.mapValues(v =>
- if (v == null || v.isEmpty) {
- Set[String]()
- } else {
- v.split(',').toSet
- }
- )
- } catch {
- case NonFatal(e) =>
- throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e)
- }
- }
- }
-
- // These defaults copied from YARN
- private val MEMORY_OVERHEAD_FRACTION = 0.10
- private val MEMORY_OVERHEAD_MINIMUM = 384
-
- /**
- * Return the amount of memory to allocate to each executor, taking into account
- * container overheads.
- * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value
-   * @return memory requirement as executorMemory + max(0.1 * executorMemory, MEMORY_OVERHEAD_MINIMUM)
-   *         (unless `spark.mesos.executor.memoryOverhead` explicitly sets the overhead)
- */
- def calculateTotalMemory(sc: SparkContext): Int = {
- sc.conf.getInt("spark.mesos.executor.memoryOverhead",
- math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) +
- sc.executorMemory
- }
-
- def setupUris(uris: String, builder: CommandInfo.Builder): Unit = {
- uris.split(",").foreach { uri =>
- builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim()))
- }
- }
-
-}
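
Two helpers removed above lend themselves to a small worked example: parseConstraintString turns "key:v1,v2;key2:v3" into a Map[String, Set[String]], and calculateTotalMemory adds the larger of 10% of executor memory and 384 MB as overhead unless an explicit overhead is configured. Below is a self-contained sketch of both rules, using standalone names rather than the Spark API.

// Standalone sketch of the constraint-string and memory-overhead rules shown above.
object MesosUtilsSketch {
  private val MemoryOverheadFraction = 0.10
  private val MemoryOverheadMinimum = 384

  // "tachyon:true;zone:us-east-1a,us-east-1b" ->
  //   Map("tachyon" -> Set("true"), "zone" -> Set("us-east-1a", "us-east-1b"))
  def parseConstraints(constraints: String): Map[String, Set[String]] =
    if (constraints.isEmpty) {
      Map.empty
    } else {
      constraints.split(";").map { attr =>
        attr.split(":", 2) match {
          case Array(key) =>
            key.trim -> Set.empty[String] // bare attribute name: presence check only
          case Array(key, values) =>
            val vs = if (values.isEmpty) Set.empty[String] else values.split(",").map(_.trim).toSet
            key.trim -> vs
        }
      }.toMap
    }

  // Executor memory (MB) plus overhead: an explicit override, or max(10%, 384 MB).
  def totalMemory(executorMemoryMb: Int, overheadOverrideMb: Option[Int]): Int = {
    val overhead = overheadOverrideMb.getOrElse(
      math.max(MemoryOverheadFraction * executorMemoryMb, MemoryOverheadMinimum).toInt)
    executorMemoryMb + overhead
  }

  def main(args: Array[String]): Unit = {
    println(parseConstraints("tachyon:true;zone:us-east-1a,us-east-1b"))
    println(totalMemory(1024, None)) // 1024 + 384 = 1408
    println(totalMemory(8192, None)) // 8192 + 819 = 9011
  }
}
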
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
deleted file mode 100644
index 5e7e6567a3e0..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster.mesos
-
-import java.nio.ByteBuffer
-
-import org.apache.mesos.protobuf.ByteString
-
-import org.apache.spark.Logging
-
-/**
- * Wrapper for serializing the data sent when launching Mesos tasks.
- */
-private[spark] case class MesosTaskLaunchData(
- serializedTask: ByteBuffer,
- attemptNumber: Int) extends Logging {
-
- def toByteString: ByteString = {
- val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit)
- dataBuffer.putInt(attemptNumber)
- dataBuffer.put(serializedTask)
- dataBuffer.rewind
- logDebug(s"ByteBuffer size: [${dataBuffer.remaining}]")
- ByteString.copyFrom(dataBuffer)
- }
-}
-
-private[spark] object MesosTaskLaunchData extends Logging {
- def fromByteString(byteString: ByteString): MesosTaskLaunchData = {
- val byteBuffer = byteString.asReadOnlyByteBuffer()
- logDebug(s"ByteBuffer size: [${byteBuffer.remaining}]")
- val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes
- val serializedTask = byteBuffer.slice() // subsequence starting at the current position
- MesosTaskLaunchData(serializedTask, attemptNumber)
- }
-}
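
The wrapper deleted above encodes a task as a 4-byte attempt number followed by the serialized task bytes, and decodes it with getInt plus slice. Here is a minimal round-trip sketch of that layout using only java.nio; the names are illustrative, not the Spark classes.

import java.nio.ByteBuffer

// Sketch of the [attemptNumber: Int][task bytes...] layout used above.
object TaskLaunchDataSketch {
  def encode(serializedTask: ByteBuffer, attemptNumber: Int): ByteBuffer = {
    val out = ByteBuffer.allocate(4 + serializedTask.limit())
    out.putInt(attemptNumber) // 4-byte prefix
    out.put(serializedTask)   // payload follows immediately
    out.rewind()
    out
  }

  def decode(buf: ByteBuffer): (ByteBuffer, Int) = {
    val attemptNumber = buf.getInt() // advances the position by 4 bytes
    (buf.slice(), attemptNumber)     // remainder is the task payload
  }

  def main(args: Array[String]): Unit = {
    val payload = ByteBuffer.wrap("task-bytes".getBytes("UTF-8"))
    val (task, attempt) = decode(encode(payload, attemptNumber = 3))
    println(attempt)          // 3
    println(task.remaining()) // 10
  }
}
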
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
deleted file mode 100644
index c633d860ae6e..000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ /dev/null
@@ -1,165 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.io.File
-import java.net.URL
-import java.nio.ByteBuffer
-
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState}
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.{Executor, ExecutorBackend}
-import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
-import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.ExecutorInfo
-
-private case class ReviveOffers()
-
-private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
-
-private case class KillTask(taskId: Long, interruptThread: Boolean)
-
-private case class StopExecutor()
-
-/**
- * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the
- * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
- * and the TaskSchedulerImpl.
- */
-private[spark] class LocalEndpoint(
- override val rpcEnv: RpcEnv,
- userClassPath: Seq[URL],
- scheduler: TaskSchedulerImpl,
- executorBackend: LocalBackend,
- private val totalCores: Int)
- extends ThreadSafeRpcEndpoint with Logging {
-
- private var freeCores = totalCores
-
- val localExecutorId = SparkContext.DRIVER_IDENTIFIER
- val localExecutorHostname = "localhost"
-
- private val executor = new Executor(
- localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true)
-
- override def receive: PartialFunction[Any, Unit] = {
- case ReviveOffers =>
- reviveOffers()
-
- case StatusUpdate(taskId, state, serializedData) =>
- scheduler.statusUpdate(taskId, state, serializedData)
- if (TaskState.isFinished(state)) {
- freeCores += scheduler.CPUS_PER_TASK
- reviveOffers()
- }
-
- case KillTask(taskId, interruptThread) =>
- executor.killTask(taskId, interruptThread)
- }
-
- override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case StopExecutor =>
- executor.stop()
- context.reply(true)
- }
-
- def reviveOffers() {
- val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
- for (task <- scheduler.resourceOffers(offers).flatten) {
- freeCores -= scheduler.CPUS_PER_TASK
- executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber,
- task.name, task.serializedTask)
- }
- }
-}
-
-/**
- * LocalBackend is used when running a local version of Spark where the executor, backend, and
- * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks
- * on a single Executor (created by the LocalBackend) running locally.
- */
-private[spark] class LocalBackend(
- conf: SparkConf,
- scheduler: TaskSchedulerImpl,
- val totalCores: Int)
- extends SchedulerBackend with ExecutorBackend with Logging {
-
- private val appId = "local-" + System.currentTimeMillis
- private var localEndpoint: RpcEndpointRef = null
- private val userClassPath = getUserClasspath(conf)
- private val listenerBus = scheduler.sc.listenerBus
- private val launcherBackend = new LauncherBackend() {
- override def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED)
- }
-
- /**
- * Returns a list of URLs representing the user classpath.
- *
- * @param conf Spark configuration.
- */
- def getUserClasspath(conf: SparkConf): Seq[URL] = {
- val userClassPathStr = conf.getOption("spark.executor.extraClassPath")
- userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL)
- }
-
- launcherBackend.connect()
-
- override def start() {
- val rpcEnv = SparkEnv.get.rpcEnv
- val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores)
- localEndpoint = rpcEnv.setupEndpoint("LocalBackendEndpoint", executorEndpoint)
- listenerBus.post(SparkListenerExecutorAdded(
- System.currentTimeMillis,
- executorEndpoint.localExecutorId,
- new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty)))
- launcherBackend.setAppId(appId)
- launcherBackend.setState(SparkAppHandle.State.RUNNING)
- }
-
- override def stop() {
- stop(SparkAppHandle.State.FINISHED)
- }
-
- override def reviveOffers() {
- localEndpoint.send(ReviveOffers)
- }
-
- override def defaultParallelism(): Int =
- scheduler.conf.getInt("spark.default.parallelism", totalCores)
-
- override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
- localEndpoint.send(KillTask(taskId, interruptThread))
- }
-
- override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
- localEndpoint.send(StatusUpdate(taskId, state, serializedData))
- }
-
- override def applicationId(): String = appId
-
- private def stop(finalState: SparkAppHandle.State): Unit = {
- localEndpoint.ask(StopExecutor)
- try {
- launcherBackend.setState(finalState)
- } finally {
- launcherBackend.close()
- }
- }
-
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
new file mode 100644
index 000000000000..35509bc2f85b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.local
+
+import java.io.File
+import java.net.URL
+import java.nio.ByteBuffer
+
+import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskState}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.executor.{Executor, ExecutorBackend}
+import org.apache.spark.internal.Logging
+import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster.ExecutorInfo
+
+private case class ReviveOffers()
+
+private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+
+private case class KillTask(taskId: Long, interruptThread: Boolean, reason: String)
+
+private case class StopExecutor()
+
+/**
+ * Calls to [[LocalSchedulerBackend]] are all serialized through LocalEndpoint. Using an
+ * RpcEndpoint makes the calls on [[LocalSchedulerBackend]] asynchronous, which is necessary
+ * to prevent deadlock between [[LocalSchedulerBackend]] and the [[TaskSchedulerImpl]].
+ */
+private[spark] class LocalEndpoint(
+ override val rpcEnv: RpcEnv,
+ userClassPath: Seq[URL],
+ scheduler: TaskSchedulerImpl,
+ executorBackend: LocalSchedulerBackend,
+ private val totalCores: Int)
+ extends ThreadSafeRpcEndpoint with Logging {
+
+ private var freeCores = totalCores
+
+ val localExecutorId = SparkContext.DRIVER_IDENTIFIER
+ val localExecutorHostname = "localhost"
+
+ private val executor = new Executor(
+ localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true)
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case ReviveOffers =>
+ reviveOffers()
+
+ case StatusUpdate(taskId, state, serializedData) =>
+ scheduler.statusUpdate(taskId, state, serializedData)
+ if (TaskState.isFinished(state)) {
+ freeCores += scheduler.CPUS_PER_TASK
+ reviveOffers()
+ }
+
+ case KillTask(taskId, interruptThread, reason) =>
+ executor.killTask(taskId, interruptThread, reason)
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case StopExecutor =>
+ executor.stop()
+ context.reply(true)
+ }
+
+ def reviveOffers() {
+ val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
+ for (task <- scheduler.resourceOffers(offers).flatten) {
+ freeCores -= scheduler.CPUS_PER_TASK
+ executor.launchTask(executorBackend, task)
+ }
+ }
+}
+
+/**
+ * Used when running a local version of Spark where the executor, backend, and master all run in
+ * the same JVM. It sits behind a [[TaskSchedulerImpl]] and handles launching tasks on a single
+ * Executor (created by the [[LocalSchedulerBackend]]) running locally.
+ */
+private[spark] class LocalSchedulerBackend(
+ conf: SparkConf,
+ scheduler: TaskSchedulerImpl,
+ val totalCores: Int)
+ extends SchedulerBackend with ExecutorBackend with Logging {
+
+ private val appId = "local-" + System.currentTimeMillis
+ private var localEndpoint: RpcEndpointRef = null
+ private val userClassPath = getUserClasspath(conf)
+ private val listenerBus = scheduler.sc.listenerBus
+ private val launcherBackend = new LauncherBackend() {
+ override def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED)
+ }
+
+ /**
+ * Returns a list of URLs representing the user classpath.
+ *
+ * @param conf Spark configuration.
+ */
+ def getUserClasspath(conf: SparkConf): Seq[URL] = {
+ val userClassPathStr = conf.getOption("spark.executor.extraClassPath")
+ userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL)
+ }
+
+ launcherBackend.connect()
+
+ override def start() {
+ val rpcEnv = SparkEnv.get.rpcEnv
+ val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores)
+ localEndpoint = rpcEnv.setupEndpoint("LocalSchedulerBackendEndpoint", executorEndpoint)
+ listenerBus.post(SparkListenerExecutorAdded(
+ System.currentTimeMillis,
+ executorEndpoint.localExecutorId,
+ new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty)))
+ launcherBackend.setAppId(appId)
+ launcherBackend.setState(SparkAppHandle.State.RUNNING)
+ }
+
+ override def stop() {
+ stop(SparkAppHandle.State.FINISHED)
+ }
+
+ override def reviveOffers() {
+ localEndpoint.send(ReviveOffers)
+ }
+
+ override def defaultParallelism(): Int =
+ scheduler.conf.getInt("spark.default.parallelism", totalCores)
+
+ override def killTask(
+ taskId: Long, executorId: String, interruptThread: Boolean, reason: String) {
+ localEndpoint.send(KillTask(taskId, interruptThread, reason))
+ }
+
+ override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+ localEndpoint.send(StatusUpdate(taskId, state, serializedData))
+ }
+
+ override def applicationId(): String = appId
+
+ private def stop(finalState: SparkAppHandle.State): Unit = {
+ localEndpoint.ask(StopExecutor)
+ try {
+ launcherBackend.setState(finalState)
+ } finally {
+ launcherBackend.close()
+ }
+ }
+
+}
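
For context, a minimal usage sketch (not part of the patch): with a `local[N]` master, SparkContext builds a TaskSchedulerImpl and a LocalSchedulerBackend internally, and every task runs on the single in-JVM Executor owned by LocalEndpoint. Everything below is ordinary user-facing API; nothing here is new in this change.

import org.apache.spark.{SparkConf, SparkContext}

// Standard local-mode run; the backend and endpoint are never instantiated directly.
val conf = new SparkConf().setMaster("local[2]").setAppName("local-backend-sketch")
val sc = new SparkContext(conf)
sc.parallelize(1 to 100).map(_ * 2).count()  // tasks flow through ReviveOffers / StatusUpdate
sc.stop()                                    // sends StopExecutor via stop(FINISHED)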
diff --git a/core/src/main/scala/org/apache/spark/scheduler/package-info.java b/core/src/main/scala/org/apache/spark/scheduler/package-info.java
index 5b4a628d3cee..90fc65251eae 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/package-info.java
+++ b/core/src/main/scala/org/apache/spark/scheduler/package-info.java
@@ -18,4 +18,4 @@
/**
* Spark's DAG scheduler.
*/
-package org.apache.spark.scheduler;
\ No newline at end of file
+package org.apache.spark.scheduler;
diff --git a/core/src/main/scala/org/apache/spark/scheduler/package.scala b/core/src/main/scala/org/apache/spark/scheduler/package.scala
index f0dbfc2ac5f4..4847c41710b2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/package.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/package.scala
@@ -18,7 +18,7 @@
package org.apache.spark
/**
- * Spark's scheduling components. This includes the [[org.apache.spark.scheduler.DAGScheduler]] and
- * lower level [[org.apache.spark.scheduler.TaskScheduler]].
+ * Spark's scheduling components. This includes the `org.apache.spark.scheduler.DAGScheduler` and
+ * lower level `org.apache.spark.scheduler.TaskScheduler`.
*/
package object scheduler
diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
new file mode 100644
index 000000000000..78dabb42ac9d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.{ReadableByteChannel, WritableByteChannel}
+import java.util.Properties
+import javax.crypto.KeyGenerator
+import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
+
+import scala.collection.JavaConverters._
+
+import com.google.common.io.ByteStreams
+import org.apache.commons.crypto.random._
+import org.apache.commons.crypto.stream._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.network.util.{CryptoUtils, JavaUtils}
+
+/**
+ * A util class for manipulating IO encryption and decryption streams.
+ */
+private[spark] object CryptoStreamUtils extends Logging {
+
+ // The initialization vector length in bytes.
+ val IV_LENGTH_IN_BYTES = 16
+ // The prefix of IO encryption related configurations in Spark configuration.
+ val SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX = "spark.io.encryption.commons.config."
+
+ /**
+ * Helper method to wrap `OutputStream` with `CryptoOutputStream` for encryption.
+ */
+ def createCryptoOutputStream(
+ os: OutputStream,
+ sparkConf: SparkConf,
+ key: Array[Byte]): OutputStream = {
+ val params = new CryptoParams(key, sparkConf)
+ val iv = createInitializationVector(params.conf)
+ os.write(iv)
+ new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec,
+ new IvParameterSpec(iv))
+ }
+
+ /**
+ * Wrap a `WritableByteChannel` for encryption.
+ */
+ def createWritableChannel(
+ channel: WritableByteChannel,
+ sparkConf: SparkConf,
+ key: Array[Byte]): WritableByteChannel = {
+ val params = new CryptoParams(key, sparkConf)
+ val iv = createInitializationVector(params.conf)
+ val helper = new CryptoHelperChannel(channel)
+
+ helper.write(ByteBuffer.wrap(iv))
+ new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec,
+ new IvParameterSpec(iv))
+ }
+
+ /**
+ * Helper method to wrap `InputStream` with `CryptoInputStream` for decryption.
+ */
+ def createCryptoInputStream(
+ is: InputStream,
+ sparkConf: SparkConf,
+ key: Array[Byte]): InputStream = {
+ val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
+ ByteStreams.readFully(is, iv)
+ val params = new CryptoParams(key, sparkConf)
+ new CryptoInputStream(params.transformation, params.conf, is, params.keySpec,
+ new IvParameterSpec(iv))
+ }
+
+ /**
+ * Wrap a `ReadableByteChannel` for decryption.
+ */
+ def createReadableChannel(
+ channel: ReadableByteChannel,
+ sparkConf: SparkConf,
+ key: Array[Byte]): ReadableByteChannel = {
+ val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
+ val buf = ByteBuffer.wrap(iv)
+ JavaUtils.readFully(channel, buf)
+
+ val params = new CryptoParams(key, sparkConf)
+ new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec,
+ new IvParameterSpec(iv))
+ }
+
+ def toCryptoConf(conf: SparkConf): Properties = {
+ CryptoUtils.toCryptoConf(SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX,
+ conf.getAll.toMap.asJava.entrySet())
+ }
+
+ /**
+ * Creates a new encryption key.
+ */
+ def createKey(conf: SparkConf): Array[Byte] = {
+ val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
+ val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
+ val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
+ keyGen.init(keyLen)
+ keyGen.generateKey().getEncoded()
+ }
+
+ /**
+ * Generates an IV (Initialization Vector) using a secure random source.
+ */
+ private[this] def createInitializationVector(properties: Properties): Array[Byte] = {
+ val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
+ val initialIVStart = System.currentTimeMillis()
+ CryptoRandomFactory.getCryptoRandom(properties).nextBytes(iv)
+ val initialIVFinish = System.currentTimeMillis()
+ val initialIVTime = initialIVFinish - initialIVStart
+ if (initialIVTime > 2000) {
+ logWarning(s"It costs ${initialIVTime} milliseconds to create the Initialization Vector " +
+ s"used by CryptoStream")
+ }
+ iv
+ }
+
+ /**
+ * This class is a workaround for CRYPTO-125, that forces all bytes to be written to the
+ * underlying channel. Since the callers of this API are using blocking I/O, there are no
+ * concerns with regards to CPU usage here.
+ */
+ private class CryptoHelperChannel(sink: WritableByteChannel) extends WritableByteChannel {
+
+ override def write(src: ByteBuffer): Int = {
+ val count = src.remaining()
+ while (src.hasRemaining()) {
+ sink.write(src)
+ }
+ count
+ }
+
+ override def isOpen(): Boolean = sink.isOpen()
+
+ override def close(): Unit = sink.close()
+
+ }
+
+ private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) {
+
+ val keySpec = new SecretKeySpec(key, "AES")
+ val transformation = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
+ val conf = toCryptoConf(sparkConf)
+
+ }
+
+}
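
A hedged round-trip sketch of the helpers above. CryptoStreamUtils is private[spark], so this assumes code living under the org.apache.spark package and default values for the IO encryption configs.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets
import com.google.common.io.ByteStreams
import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils

val conf = new SparkConf()
val key = CryptoStreamUtils.createKey(conf)              // AES key from the configured keygen

val bos = new ByteArrayOutputStream()
val out = CryptoStreamUtils.createCryptoOutputStream(bos, conf, key)
out.write("secret payload".getBytes(StandardCharsets.UTF_8))
out.close()                                              // bos now holds IV prefix + ciphertext

val in = CryptoStreamUtils.createCryptoInputStream(
  new ByteArrayInputStream(bos.toByteArray), conf, key)
val plain = new String(ByteStreams.toByteArray(in), StandardCharsets.UTF_8)
in.close()                                               // plain == "secret payload"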
diff --git a/core/src/main/scala/org/apache/spark/security/GroupMappingServiceProvider.scala b/core/src/main/scala/org/apache/spark/security/GroupMappingServiceProvider.scala
new file mode 100644
index 000000000000..ea047a4f75d5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/security/GroupMappingServiceProvider.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.security
+
+/**
+ * This Spark trait maps a given userName to the set of groups that user belongs to.
+ * This is useful for specifying a common group of admins/developers and granting them admin,
+ * modify and/or view access rights. When access control checks are enabled via
+ * spark.acls.enable, every time a user tries to access or modify the application, the
+ * SecurityManager obtains the user's groups from the instance of the groups mapping
+ * provider specified by the entry spark.user.groups.mapping.
+ */
+
+trait GroupMappingServiceProvider {
+
+ /**
+ * Get the groups the user belongs to.
+ * @param userName User's Name
+ * @return set of groups that the user belongs to. Empty in case of an invalid user.
+ */
+ def getGroups(userName : String) : Set[String]
+
+}
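
A hypothetical implementation sketch of the trait; the class name and the static table are invented for illustration. Such a class would be selected through the spark.user.groups.mapping entry mentioned in the scaladoc.

import org.apache.spark.security.GroupMappingServiceProvider

class StaticGroupsMappingProvider extends GroupMappingServiceProvider {
  // Invented example data: a fixed user-to-groups table.
  private val mapping = Map(
    "alice" -> Set("admins", "devs"),
    "bob"   -> Set("devs"))

  override def getGroups(userName: String): Set[String] =
    mapping.getOrElse(userName, Set.empty)
}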
diff --git a/core/src/main/scala/org/apache/spark/security/ShellBasedGroupsMappingProvider.scala b/core/src/main/scala/org/apache/spark/security/ShellBasedGroupsMappingProvider.scala
new file mode 100644
index 000000000000..f71dd08246b2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/security/ShellBasedGroupsMappingProvider.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.security
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * This class is responsible for getting the groups for a particular user in Unix-based
+ * environments. It shells out to the Unix `id` command to fetch the groups for the specified
+ * user. It does not cache the user groups, as the invocations are expected to be infrequent.
+ */
+
+private[spark] class ShellBasedGroupsMappingProvider extends GroupMappingServiceProvider
+ with Logging {
+
+ override def getGroups(username: String): Set[String] = {
+ val userGroups = getUnixGroups(username)
+ logDebug("User: " + username + " Groups: " + userGroups.mkString(","))
+ userGroups
+ }
+
+ // shells out a "bash -c id -Gn username" to get user groups
+ private def getUnixGroups(username: String): Set[String] = {
+ val cmdSeq = Seq("bash", "-c", "id -Gn " + username)
+ // we need to get rid of the trailing "\n" from the result of command execution
+ Utils.executeAndGetOutput(cmdSeq).stripLineEnd.split(" ").toSet
+ }
+}
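
Configuration sketch (values are examples, not defaults): wiring the shell-based provider in when ACL checks are enabled, using the two config entries named in the GroupMappingServiceProvider scaladoc.

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.acls.enable", "true")
  .set("spark.user.groups.mapping",
    "org.apache.spark.security.ShellBasedGroupsMappingProvider")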
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
new file mode 100644
index 000000000000..d15e7937b052
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.security
+
+import java.io.{DataInputStream, DataOutputStream, InputStream}
+import java.net.Socket
+import java.nio.charset.StandardCharsets.UTF_8
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
+
+/**
+ * A class that can be used to add a simple authentication protocol to socket-based communication.
+ *
+ * The protocol is simple: an auth secret is written to the socket, and the other side checks the
+ * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is
+ * not expected to be valid anymore.
+ *
+ * There's no secrecy, so this relies on the sockets being either local or somehow encrypted.
+ */
+private[spark] class SocketAuthHelper(conf: SparkConf) {
+
+ val secret = Utils.createSecret(conf)
+
+ /**
+ * Read the auth secret from the socket and compare to the expected value. Write the reply back
+ * to the socket.
+ *
+ * If authentication fails, this method will close the socket.
+ *
+ * @param s The client socket.
+ * @throws IllegalArgumentException If authentication fails.
+ */
+ def authClient(s: Socket): Unit = {
+ // Set the socket timeout while checking the auth secret. Reset it before returning.
+ val currentTimeout = s.getSoTimeout()
+ try {
+ s.setSoTimeout(10000)
+ val clientSecret = readUtf8(s)
+ if (secret == clientSecret) {
+ writeUtf8("ok", s)
+ } else {
+ writeUtf8("err", s)
+ JavaUtils.closeQuietly(s)
+ }
+ } finally {
+ s.setSoTimeout(currentTimeout)
+ }
+ }
+
+ /**
+ * Authenticate with a server by writing the auth secret and checking the server's reply.
+ *
+ * If authentication fails, this method will close the socket.
+ *
+ * @param s The socket connected to the server.
+ * @throws IllegalArgumentException If authentication fails.
+ */
+ def authToServer(s: Socket): Unit = {
+ writeUtf8(secret, s)
+
+ val reply = readUtf8(s)
+ if (reply != "ok") {
+ JavaUtils.closeQuietly(s)
+ throw new IllegalArgumentException("Authentication failed.")
+ }
+ }
+
+ protected def readUtf8(s: Socket): String = {
+ val din = new DataInputStream(s.getInputStream())
+ val len = din.readInt()
+ val bytes = new Array[Byte](len)
+ din.readFully(bytes)
+ new String(bytes, UTF_8)
+ }
+
+ protected def writeUtf8(str: String, s: Socket): Unit = {
+ val bytes = str.getBytes(UTF_8)
+ val dout = new DataOutputStream(s.getOutputStream())
+ dout.writeInt(bytes.length)
+ dout.write(bytes, 0, bytes.length)
+ dout.flush()
+ }
+
+}
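
A minimal handshake sketch, assuming both ends share one SocketAuthHelper instance (and therefore one secret) and that the code sits under the org.apache.spark package, since the class is private[spark].

import java.net.{ServerSocket, Socket}
import org.apache.spark.SparkConf
import org.apache.spark.security.SocketAuthHelper

val helper = new SocketAuthHelper(new SparkConf())
val server = new ServerSocket(0)

new Thread(new Runnable {
  override def run(): Unit = {
    val s = server.accept()
    helper.authClient(s)          // reads the secret, replies "ok", or closes on mismatch
  }
}).start()

val client = new Socket("localhost", server.getLocalPort)
helper.authToServer(client)       // writes the secret, throws unless the reply is "ok"
client.close()
server.close()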
diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
index 62f8aae7f212..f0ed41f6903f 100644
--- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.serializer
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
import scala.collection.mutable
@@ -29,8 +30,9 @@ import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.io._
import org.apache.commons.io.IOUtils
-import org.apache.spark.{SparkException, SparkEnv}
+import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.util.Utils
/**
* Custom serializer used for generic Avro records. If the user registers the schemas
@@ -71,8 +73,11 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, {
val bos = new ByteArrayOutputStream()
val out = codec.compressedOutputStream(bos)
- out.write(schema.toString.getBytes("UTF-8"))
- out.close()
+ Utils.tryWithSafeFinally {
+ out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
+ } {
+ out.close()
+ }
bos.toByteArray
})
@@ -81,9 +86,17 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
* seen values so to limit the number of times that decompression has to be done.
*/
def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, {
- val bis = new ByteArrayInputStream(schemaBytes.array())
- val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
- new Schema.Parser().parse(new String(bytes, "UTF-8"))
+ val bis = new ByteArrayInputStream(
+ schemaBytes.array(),
+ schemaBytes.arrayOffset() + schemaBytes.position(),
+ schemaBytes.remaining())
+ val in = codec.compressedInputStream(bis)
+ val bytes = Utils.tryWithSafeFinally {
+ IOUtils.toByteArray(in)
+ } {
+ in.close()
+ }
+ new Schema.Parser().parse(new String(bytes, StandardCharsets.UTF_8))
})
/**
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index b463a71d5bd7..f60dcfddfdc2 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -24,8 +24,7 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.ByteBufferInputStream
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}
private[spark] class JavaSerializationStream(
out: OutputStream, counterReset: Int, extraDebugInfo: Boolean)
@@ -69,7 +68,7 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa
// scalastyle:on classforname
} catch {
case e: ClassNotFoundException =>
- JavaDeserializationStream.primitiveMappings.get(desc.getName).getOrElse(throw e)
+ JavaDeserializationStream.primitiveMappings.getOrElse(desc.getName, throw e)
}
}
@@ -96,11 +95,11 @@ private[spark] class JavaSerializerInstance(
extends SerializerInstance {
override def serialize[T: ClassTag](t: T): ByteBuffer = {
- val bos = new ByteArrayOutputStream()
+ val bos = new ByteBufferOutputStream()
val out = serializeStream(bos)
out.writeObject(t)
out.close()
- ByteBuffer.wrap(bos.toByteArray)
+ bos.toByteBuffer
}
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
@@ -132,7 +131,7 @@ private[spark] class JavaSerializerInstance(
* :: DeveloperApi ::
* A Spark serializer that uses Java's built-in serialization.
*
- * Note that this serializer is not guaranteed to be wire-compatible across different versions of
+ * @note This serializer is not guaranteed to be wire-compatible across different versions of
* Spark. It is intended to be used to serialize/de-serialize data within a single
* Spark application.
*/
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index bc51d4f2820c..4f03e54e304f 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -17,33 +17,37 @@
package org.apache.spark.serializer
-import java.io.{EOFException, IOException, InputStream, OutputStream}
+import java.io._
import java.nio.ByteBuffer
+import java.util.Locale
import javax.annotation.Nullable
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
-import com.esotericsoftware.kryo.{Kryo, KryoException}
+import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer}
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
+import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput}
import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.avro.generic.{GenericData, GenericRecord}
+import org.roaringbitmap.RoaringBitmap
import org.apache.spark._
import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.HttpBroadcast
+import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
-import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf}
-import org.apache.spark.util.collection.{BitSet, CompactBuffer}
+import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils}
+import org.apache.spark.util.collection.CompactBuffer
/**
- * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
+ * A Spark serializer that uses the
+ * Kryo serialization library.
*
- * Note that this serializer is not guaranteed to be wire-compatible across different versions of
+ * @note This serializer is not guaranteed to be wire-compatible across different versions of
* Spark. It is intended to be used to serialize/de-serialize data within a single
* Spark application.
*/
@@ -69,14 +73,23 @@ class KryoSerializer(conf: SparkConf)
private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true)
private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false)
- private val userRegistrator = conf.getOption("spark.kryo.registrator")
+ private val userRegistrators = conf.get("spark.kryo.registrator", "")
+ .split(',').map(_.trim)
+ .filter(!_.isEmpty)
private val classesToRegister = conf.get("spark.kryo.classesToRegister", "")
- .split(',')
+ .split(',').map(_.trim)
.filter(!_.isEmpty)
private val avroSchemas = conf.getAvroSchema
+ // whether to use unsafe based IO for serialization
+ private val useUnsafe = conf.getBoolean("spark.kryo.unsafe", false)
- def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
+ def newKryoOutput(): KryoOutput =
+ if (useUnsafe) {
+ new KryoUnsafeOutput(bufferSize, math.max(bufferSize, maxBufferSize))
+ } else {
+ new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
+ }
def newKryo(): Kryo = {
val instantiator = new EmptyScalaKryoInstantiator
@@ -93,6 +106,9 @@ class KryoSerializer(conf: SparkConf)
for (cls <- KryoSerializer.toRegister) {
kryo.register(cls)
}
+ for ((cls, ser) <- KryoSerializer.toRegisterSerializer) {
+ kryo.register(cls, ser)
+ }
// For results returned by asJavaIterable. See JavaIterableWrapperSerializer.
kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer)
@@ -101,7 +117,6 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer())
kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer())
- kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
@@ -115,7 +130,7 @@ class KryoSerializer(conf: SparkConf)
classesToRegister
.foreach { className => kryo.register(Class.forName(className, true, classLoader)) }
// Allow the user to register their own classes by setting spark.kryo.registrator.
- userRegistrator
+ userRegistrators
.map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
.foreach { reg => reg.registerClasses(kryo) }
// scalastyle:on classforname
@@ -160,6 +175,7 @@ class KryoSerializer(conf: SparkConf)
kryo.register(None.getClass)
kryo.register(Nil.getClass)
kryo.register(Utils.classForName("scala.collection.immutable.$colon$colon"))
+ kryo.register(Utils.classForName("scala.collection.immutable.Map$EmptyMap$"))
kryo.register(classOf[ArrayBuffer[Any]])
kryo.setClassLoader(classLoader)
@@ -167,7 +183,7 @@ class KryoSerializer(conf: SparkConf)
}
override def newInstance(): SerializerInstance = {
- new KryoSerializerInstance(this)
+ new KryoSerializerInstance(this, useUnsafe)
}
private[spark] override lazy val supportsRelocationOfSerializedObjects: Boolean = {
@@ -181,9 +197,12 @@ class KryoSerializer(conf: SparkConf)
private[spark]
class KryoSerializationStream(
serInstance: KryoSerializerInstance,
- outStream: OutputStream) extends SerializationStream {
+ outStream: OutputStream,
+ useUnsafe: Boolean) extends SerializationStream {
+
+ private[this] var output: KryoOutput =
+ if (useUnsafe) new KryoUnsafeOutput(outStream) else new KryoOutput(outStream)
- private[this] var output: KryoOutput = new KryoOutput(outStream)
private[this] var kryo: Kryo = serInstance.borrowKryo()
override def writeObject[T: ClassTag](t: T): SerializationStream = {
@@ -214,9 +233,12 @@ class KryoSerializationStream(
private[spark]
class KryoDeserializationStream(
serInstance: KryoSerializerInstance,
- inStream: InputStream) extends DeserializationStream {
+ inStream: InputStream,
+ useUnsafe: Boolean) extends DeserializationStream {
+
+ private[this] var input: KryoInput =
+ if (useUnsafe) new KryoUnsafeInput(inStream) else new KryoInput(inStream)
- private[this] var input: KryoInput = new KryoInput(inStream)
private[this] var kryo: Kryo = serInstance.borrowKryo()
override def readObject[T: ClassTag](): T = {
@@ -224,7 +246,8 @@ class KryoDeserializationStream(
kryo.readClassAndObject(input).asInstanceOf[T]
} catch {
// DeserializationStream uses the EOF exception to indicate stopping condition.
- case e: KryoException if e.getMessage.toLowerCase.contains("buffer underflow") =>
+ case e: KryoException
+ if e.getMessage.toLowerCase(Locale.ROOT).contains("buffer underflow") =>
throw new EOFException
}
}
@@ -243,8 +266,8 @@ class KryoDeserializationStream(
}
}
-private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
-
+private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boolean)
+ extends SerializerInstance {
/**
* A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do
* their work, then release the instance by calling `releaseKryo()`. Logically, this is a caching
@@ -283,7 +306,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
// Make these lazy vals to avoid creating a buffer unless we use them.
private lazy val output = ks.newKryoOutput()
- private lazy val input = new KryoInput()
+ private lazy val input = if (useUnsafe) new KryoUnsafeInput() else new KryoInput()
override def serialize[T: ClassTag](t: T): ByteBuffer = {
output.clear()
@@ -293,7 +316,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
} catch {
case e: KryoException if e.getMessage.startsWith("Buffer overflow") =>
throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " +
- "increase spark.kryoserializer.buffer.max value.")
+ "increase spark.kryoserializer.buffer.max value.", e)
} finally {
releaseKryo(kryo)
}
@@ -303,7 +326,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
val kryo = borrowKryo()
try {
- input.setBuffer(bytes.array)
+ input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining())
kryo.readClassAndObject(input).asInstanceOf[T]
} finally {
releaseKryo(kryo)
@@ -315,7 +338,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
val oldClassLoader = kryo.getClassLoader
try {
kryo.setClassLoader(loader)
- input.setBuffer(bytes.array)
+ input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining())
kryo.readClassAndObject(input).asInstanceOf[T]
} finally {
kryo.setClassLoader(oldClassLoader)
@@ -324,11 +347,11 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
}
override def serializeStream(s: OutputStream): SerializationStream = {
- new KryoSerializationStream(this, s)
+ new KryoSerializationStream(this, s, useUnsafe)
}
override def deserializeStream(s: InputStream): DeserializationStream = {
- new KryoDeserializationStream(this, s)
+ new KryoDeserializationStream(this, s, useUnsafe)
}
/**
@@ -352,7 +375,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
* serialization.
*/
trait KryoRegistrator {
- def registerClasses(kryo: Kryo)
+ def registerClasses(kryo: Kryo): Unit
}
private[serializer] object KryoSerializer {
@@ -362,15 +385,87 @@ private[serializer] object KryoSerializer {
classOf[StorageLevel],
classOf[CompressedMapStatus],
classOf[HighlyCompressedMapStatus],
- classOf[BitSet],
classOf[CompactBuffer[_]],
classOf[BlockManagerId],
+ classOf[Array[Boolean]],
classOf[Array[Byte]],
classOf[Array[Short]],
+ classOf[Array[Int]],
classOf[Array[Long]],
+ classOf[Array[Float]],
+ classOf[Array[Double]],
+ classOf[Array[Char]],
+ classOf[Array[String]],
+ classOf[Array[Array[String]]],
classOf[BoundedPriorityQueue[_]],
classOf[SparkConf]
)
+
+ private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]](
+ classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() {
+ override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = {
+ bitmap.serialize(new KryoOutputObjectOutputBridge(kryo, output))
+ }
+ override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = {
+ val ret = new RoaringBitmap
+ ret.deserialize(new KryoInputObjectInputBridge(kryo, input))
+ ret
+ }
+ }
+ )
+}
+
+/**
+ * This is a bridge class to wrap KryoInput as an InputStream and ObjectInput. It forwards all
+ * methods of InputStream and ObjectInput to KryoInput. It's usually helpful when an API expects
+ * an InputStream or ObjectInput but you want to use Kryo.
+ */
+private[spark] class KryoInputObjectInputBridge(
+ kryo: Kryo, input: KryoInput) extends FilterInputStream(input) with ObjectInput {
+ override def readLong(): Long = input.readLong()
+ override def readChar(): Char = input.readChar()
+ override def readFloat(): Float = input.readFloat()
+ override def readByte(): Byte = input.readByte()
+ override def readShort(): Short = input.readShort()
+ override def readUTF(): String = input.readString() // readString in kryo does utf8
+ override def readInt(): Int = input.readInt()
+ override def readUnsignedShort(): Int = input.readShortUnsigned()
+ override def skipBytes(n: Int): Int = {
+ input.skip(n)
+ n
+ }
+ override def readFully(b: Array[Byte]): Unit = input.read(b)
+ override def readFully(b: Array[Byte], off: Int, len: Int): Unit = input.read(b, off, len)
+ override def readLine(): String = throw new UnsupportedOperationException("readLine")
+ override def readBoolean(): Boolean = input.readBoolean()
+ override def readUnsignedByte(): Int = input.readByteUnsigned()
+ override def readDouble(): Double = input.readDouble()
+ override def readObject(): AnyRef = kryo.readClassAndObject(input)
+}
+
+/**
+ * This is a bridge class to wrap KryoOutput as an OutputStream and ObjectOutput. It forwards all
+ * methods of OutputStream and ObjectOutput to KryoOutput. It's usually helpful when an API expects
+ * an OutputStream or ObjectOutput but you want to use Kryo.
+ */
+private[spark] class KryoOutputObjectOutputBridge(
+ kryo: Kryo, output: KryoOutput) extends FilterOutputStream(output) with ObjectOutput {
+ override def writeFloat(v: Float): Unit = output.writeFloat(v)
+ // There is no "readChars" counterpart, except maybe "readLine", which is not supported
+ override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars")
+ override def writeDouble(v: Double): Unit = output.writeDouble(v)
+ override def writeUTF(s: String): Unit = output.writeString(s) // writeString in kryo does UTF8
+ override def writeShort(v: Int): Unit = output.writeShort(v)
+ override def writeInt(v: Int): Unit = output.writeInt(v)
+ override def writeBoolean(v: Boolean): Unit = output.writeBoolean(v)
+ override def write(b: Int): Unit = output.write(b)
+ override def write(b: Array[Byte]): Unit = output.write(b)
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = output.write(b, off, len)
+ override def writeBytes(s: String): Unit = output.writeString(s)
+ override def writeChar(v: Int): Unit = output.writeChar(v.toChar)
+ override def writeLong(v: Long): Unit = output.writeLong(v)
+ override def writeByte(v: Int): Unit = output.writeByte(v)
+ override def writeObject(obj: AnyRef): Unit = kryo.writeClassAndObject(output, obj)
}
/**
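
Sketch of the two knobs touched above: spark.kryo.unsafe switches to the Unsafe-based Kryo streams, and spark.kryo.registrator now accepts a trimmed, comma-separated list. The registrator shown is invented and assumed to be a top-level class with a no-arg constructor.

import com.esotericsoftware.kryo.Kryo
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer}

class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    kryo.register(classOf[Array[Array[Double]]])   // whatever classes the job actually ships
  }
}

val conf = new SparkConf()
  .set("spark.kryo.unsafe", "true")
  .set("spark.kryo.registrator", classOf[MyRegistrator].getName)

val ser = new KryoSerializer(conf).newInstance()
val bytes = ser.serialize(Array(1, 2, 3))          // a ByteBuffer
val back  = ser.deserialize[Array[Int]](bytes)     // Array(1, 2, 3)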
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
index a1b1e1631eaf..5e7a98c8aa89 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -25,7 +25,7 @@ import scala.annotation.tailrec
import scala.collection.mutable
import scala.util.control.NonFatal
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
private[spark] object SerializationDebugger extends Logging {
@@ -53,12 +53,13 @@ private[spark] object SerializationDebugger extends Logging {
/**
* Find the path leading to a not serializable object. This method is modeled after OpenJDK's
* serialization mechanism, and handles the following cases:
- * - primitives
- * - arrays of primitives
- * - arrays of non-primitive objects
- * - Serializable objects
- * - Externalizable objects
- * - writeReplace
+ *
+ * - primitives
+ * - arrays of primitives
+ * - arrays of non-primitive objects
+ * - Serializable objects
+ * - Externalizable objects
+ * - writeReplace
*
* It does not yet handle writeObject override, but that shouldn't be too hard to do either.
*/
@@ -154,7 +155,7 @@ private[spark] object SerializationDebugger extends Logging {
// If the object has been replaced using writeReplace(),
// then call visit() on it again to test its type again.
- if (!finalObj.eq(o)) {
+ if (finalObj.getClass != o.getClass) {
return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack)
}
@@ -264,8 +265,13 @@ private[spark] object SerializationDebugger extends Logging {
if (!desc.hasWriteReplaceMethod) {
(o, desc)
} else {
- // write place
- findObjectAndDescriptor(desc.invokeWriteReplace(o))
+ val replaced = desc.invokeWriteReplace(o)
+ // `writeReplace` recursion stops when the returned object has the same class.
+ if (replaced.getClass == o.getClass) {
+ (replaced, desc)
+ } else {
+ findObjectAndDescriptor(replaced)
+ }
}
}
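
An illustration of the case the new check handles (class name invented): a writeReplace that returns an instance of its own class previously sent findObjectAndDescriptor into unbounded recursion; it is now treated as the final replacement after one step.

class SelfReplacing(val id: Int) extends Serializable {
  // Returns an object of the same class, which the debugger now accepts as-is.
  private def writeReplace(): Object = new SelfReplacing(id)
}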
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index bd2704dc8187..cb8b1cc07763 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -23,9 +23,8 @@ import javax.annotation.concurrent.NotThreadSafe
import scala.reflect.ClassTag
-import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.annotation.{DeveloperApi, Private}
-import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator}
+import org.apache.spark.util.NextIterator
/**
* :: DeveloperApi ::
@@ -40,7 +39,7 @@ import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator}
*
* 2. Java serialization interface.
*
- * Note that serializers are not required to be wire-compatible across different versions of Spark.
+ * @note Serializers are not required to be wire-compatible across different versions of Spark.
* They are intended to be used to serialize/de-serialize data within a single Spark application.
*/
@DeveloperApi
@@ -78,7 +77,7 @@ abstract class Serializer {
* position = 0
* serOut.write(obj1)
* serOut.flush()
- * position = # of bytes writen to stream so far
+ * position = # of bytes written to stream so far
* obj1Bytes = output[0:position-1]
* serOut.write(obj2)
* serOut.flush()
@@ -100,18 +99,6 @@ abstract class Serializer {
}
-@DeveloperApi
-object Serializer {
- def getSerializer(serializer: Serializer): Serializer = {
- if (serializer == null) SparkEnv.get.serializer else serializer
- }
-
- def getSerializer(serializer: Option[Serializer]): Serializer = {
- serializer.getOrElse(SparkEnv.get.serializer)
- }
-}
-
-
/**
* :: DeveloperApi ::
* An instance of a serializer, for use by one thread at a time.
@@ -138,7 +125,7 @@ abstract class SerializerInstance {
* A stream for writing serialized objects.
*/
@DeveloperApi
-abstract class SerializationStream {
+abstract class SerializationStream extends Closeable {
/** The most general-purpose method to write an object. */
def writeObject[T: ClassTag](t: T): SerializationStream
/** Writes the object representing the key of a key-value pair. */
@@ -146,7 +133,7 @@ abstract class SerializationStream {
/** Writes the object representing the value of a key-value pair. */
def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value)
def flush(): Unit
- def close(): Unit
+ override def close(): Unit
def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
while (iter.hasNext) {
@@ -162,14 +149,14 @@ abstract class SerializationStream {
* A stream for reading serialized objects.
*/
@DeveloperApi
-abstract class DeserializationStream {
+abstract class DeserializationStream extends Closeable {
/** The most general-purpose method to read an object. */
def readObject[T: ClassTag](): T
/** Reads the object representing the key of a key-value pair. */
def readKey[T: ClassTag](): T = readObject[T]()
/** Reads the object representing the value of a key-value pair. */
def readValue[T: ClassTag](): T = readObject[T]()
- def close(): Unit
+ override def close(): Unit
/**
* Read the elements of this stream through an iterator. This can only be called once, as
@@ -200,10 +187,9 @@ abstract class DeserializationStream {
try {
(readKey[Any](), readValue[Any]())
} catch {
- case eof: EOFException => {
+ case eof: EOFException =>
finished = true
null
- }
}
}
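
Since SerializationStream and DeserializationStream now extend java.io.Closeable, they can be handed to generic resource helpers; a minimal sketch with the built-in Java serializer:

import java.io.ByteArrayOutputStream
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

val ser = new JavaSerializer(new SparkConf()).newInstance()
val bos = new ByteArrayOutputStream()
val out = ser.serializeStream(bos)                 // a Closeable SerializationStream
try {
  out.writeObject("hello").writeObject("world")    // writeObject chains
} finally {
  out.close()                                      // java.io.Closeable contract
}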
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
new file mode 100644
index 000000000000..1d4b05caaa14
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.io.{BufferedInputStream, BufferedOutputStream, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkConf
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.security.CryptoStreamUtils
+import org.apache.spark.storage._
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+/**
+ * Component which configures serialization, compression and encryption for various Spark
+ * components, including automatic selection of which [[Serializer]] to use for shuffles.
+ */
+private[spark] class SerializerManager(
+ defaultSerializer: Serializer,
+ conf: SparkConf,
+ encryptionKey: Option[Array[Byte]]) {
+
+ def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None)
+
+ private[this] val kryoSerializer = new KryoSerializer(conf)
+
+ def setDefaultClassLoader(classLoader: ClassLoader): Unit = {
+ kryoSerializer.setDefaultClassLoader(classLoader)
+ }
+
+ private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]]
+ private[this] val primitiveAndPrimitiveArrayClassTags: Set[ClassTag[_]] = {
+ val primitiveClassTags = Set[ClassTag[_]](
+ ClassTag.Boolean,
+ ClassTag.Byte,
+ ClassTag.Char,
+ ClassTag.Double,
+ ClassTag.Float,
+ ClassTag.Int,
+ ClassTag.Long,
+ ClassTag.Null,
+ ClassTag.Short
+ )
+ val arrayClassTags = primitiveClassTags.map(_.wrap)
+ primitiveClassTags ++ arrayClassTags
+ }
+
+ // Whether to compress broadcast variables that are stored
+ private[this] val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
+ // Whether to compress shuffle output that is stored
+ private[this] val compressShuffle = conf.getBoolean("spark.shuffle.compress", true)
+ // Whether to compress RDD partitions that are stored serialized
+ private[this] val compressRdds = conf.getBoolean("spark.rdd.compress", false)
+ // Whether to compress shuffle output temporarily spilled to disk
+ private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
+
+ /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
+ * the initialization of the compression codec until it is first used. The reason is that a Spark
+ * program could be using a user-defined codec in a third party jar, which is loaded in
+ * Executor.updateDependencies. When the BlockManager is initialized, user-level jars haven't
+ * been loaded yet. */
+ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
+
+ def encryptionEnabled: Boolean = encryptionKey.isDefined
+
+ def canUseKryo(ct: ClassTag[_]): Boolean = {
+ primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
+ }
+
+ // SPARK-18617: The feature from SPARK-13990 cannot be applied to Spark Streaming yet; at worst,
+ // a streaming job based on `Receiver` mode cannot run properly on Spark 2.x. Disabling the
+ // `kryo auto pick` feature for streaming blocks is a reasonable first step.
+ def getSerializer(ct: ClassTag[_], autoPick: Boolean): Serializer = {
+ if (autoPick && canUseKryo(ct)) {
+ kryoSerializer
+ } else {
+ defaultSerializer
+ }
+ }
+
+ /**
+ * Pick the best serializer for shuffling an RDD of key-value pairs.
+ */
+ def getSerializer(keyClassTag: ClassTag[_], valueClassTag: ClassTag[_]): Serializer = {
+ if (canUseKryo(keyClassTag) && canUseKryo(valueClassTag)) {
+ kryoSerializer
+ } else {
+ defaultSerializer
+ }
+ }
+
+ private def shouldCompress(blockId: BlockId): Boolean = {
+ blockId match {
+ case _: ShuffleBlockId => compressShuffle
+ case _: BroadcastBlockId => compressBroadcast
+ case _: RDDBlockId => compressRdds
+ case _: TempLocalBlockId => compressShuffleSpill
+ case _: TempShuffleBlockId => compressShuffle
+ case _ => false
+ }
+ }
+
+ /**
+ * Wrap an input stream for encryption and compression
+ */
+ def wrapStream(blockId: BlockId, s: InputStream): InputStream = {
+ wrapForCompression(blockId, wrapForEncryption(s))
+ }
+
+ /**
+ * Wrap an output stream for encryption and compression
+ */
+ def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = {
+ wrapForCompression(blockId, wrapForEncryption(s))
+ }
+
+ /**
+ * Wrap an input stream for encryption if shuffle encryption is enabled
+ */
+ def wrapForEncryption(s: InputStream): InputStream = {
+ encryptionKey
+ .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
+ .getOrElse(s)
+ }
+
+ /**
+ * Wrap an output stream for encryption if shuffle encryption is enabled
+ */
+ def wrapForEncryption(s: OutputStream): OutputStream = {
+ encryptionKey
+ .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
+ .getOrElse(s)
+ }
+
+ /**
+ * Wrap an output stream for compression if block compression is enabled for its block type
+ */
+ def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
+ if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
+ }
+
+ /**
+ * Wrap an input stream for compression if block compression is enabled for its block type
+ */
+ def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
+ if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
+ }
+
+ /** Serializes into a stream. */
+ def dataSerializeStream[T: ClassTag](
+ blockId: BlockId,
+ outputStream: OutputStream,
+ values: Iterator[T]): Unit = {
+ val byteStream = new BufferedOutputStream(outputStream)
+ val autoPick = !blockId.isInstanceOf[StreamBlockId]
+ val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ }
+
+ /** Serializes into a chunked byte buffer. */
+ def dataSerialize[T: ClassTag](
+ blockId: BlockId,
+ values: Iterator[T]): ChunkedByteBuffer = {
+ dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
+ }
+
+ /** Serializes into a chunked byte buffer. */
+ def dataSerializeWithExplicitClassTag(
+ blockId: BlockId,
+ values: Iterator[_],
+ classTag: ClassTag[_]): ChunkedByteBuffer = {
+ val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
+ val byteStream = new BufferedOutputStream(bbos)
+ val autoPick = !blockId.isInstanceOf[StreamBlockId]
+ val ser = getSerializer(classTag, autoPick).newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ bbos.toChunkedByteBuffer
+ }
+
+ /**
+ * Deserializes an InputStream into an iterator of values and disposes of it when the end of
+ * the iterator is reached.
+ */
+ def dataDeserializeStream[T](
+ blockId: BlockId,
+ inputStream: InputStream)
+ (classTag: ClassTag[T]): Iterator[T] = {
+ val stream = new BufferedInputStream(inputStream)
+ val autoPick = !blockId.isInstanceOf[StreamBlockId]
+ getSerializer(classTag, autoPick)
+ .newInstance()
+ .deserializeStream(wrapForCompression(blockId, stream))
+ .asIterator.asInstanceOf[Iterator[T]]
+ }
+}
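
Sketch of the auto-pick rule encoded in canUseKryo/getSerializer. SerializerManager is private[spark], so this assumes code under the org.apache.spark package; the comments state the expected pick.

import scala.reflect.classTag
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}

val conf = new SparkConf()
val mgr  = new SerializerManager(new JavaSerializer(conf), conf)

mgr.getSerializer(classTag[Long], autoPick = true)           // Kryo: primitive class tag
mgr.getSerializer(classTag[Array[Int]], autoPick = true)     // Kryo: primitive array
mgr.getSerializer(classTag[Map[Int, Int]], autoPick = true)  // default (Java) serializer
mgr.getSerializer(classTag[Long], autoPick = false)          // default, e.g. for StreamBlockIds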
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
index b36c457d6d51..04e4cf88d706 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
@@ -17,8 +17,7 @@
package org.apache.spark.shuffle
-import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner}
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.ShuffleDependency
/**
* A basic ShuffleHandle implementation that just captures registerShuffle's parameters.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index b0abda4a81b8..c8d146030093 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -18,7 +18,8 @@
package org.apache.spark.shuffle
import org.apache.spark._
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -32,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
startPartition: Int,
endPartition: Int,
context: TaskContext,
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C] with Logging {
@@ -40,24 +42,23 @@ private[spark] class BlockStoreShuffleReader[K, C](
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
- val blockFetcherItr = new ShuffleBlockFetcherIterator(
+ val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+ serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
- SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
+ SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
+ SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
+ SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
- // Wrap the streams for compression based on configuration
- val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
- blockManager.wrapForCompression(blockId, inputStream)
- }
-
- val ser = Serializer.getSerializer(dep.serializer)
- val serializerInstance = ser.newInstance()
+ val serializerInstance = dep.serializer.newInstance()
// Create a key/value iterator for each stream
- val recordIter = wrappedStreams.flatMap { wrappedStream =>
+ val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
@@ -65,13 +66,13 @@ private[spark] class BlockStoreShuffleReader[K, C](
}
// Update the context task metrics for each record read.
- val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+ val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
- recordIter.map(record => {
+ recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
- }),
- context.taskMetrics().updateShuffleReadMetrics())
+ },
+ context.taskMetrics().mergeShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
@@ -96,15 +97,13 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
- // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
- // the ExternalSorter won't spill to disk.
+ // Create an ExternalSorter to sort the data.
val sorter =
- new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
+ new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
- context.internalMetricsToAccumulators(
- InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
+ context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index be184464e0ae..265a8acfa8d6 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -17,8 +17,8 @@
package org.apache.spark.shuffle
+import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason}
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.{FetchFailed, TaskEndReason}
import org.apache.spark.util.Utils
/**
@@ -26,6 +26,11 @@ import org.apache.spark.util.Utils
* back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage.
*
* Note that bmAddress can be null.
+ *
+ * To prevent user code from hiding this fetch failure, in the constructor we call
+ * [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately
+ * after creating it -- you cannot create it, check some condition, and then decide to ignore it
+ * (or risk triggering any other exceptions). See SPARK-19276.
*/
private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
@@ -45,7 +50,13 @@ private[spark] class FetchFailedException(
this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
}
- def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
+ // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code
+ // which intercepts this exception (possibly wrapping it), the Executor can still tell there was
+ // a fetch failure, and send the correct error msg back to the driver. We wrap with an Option
+ // because the TaskContext is not defined in some test cases.
+ Option(TaskContext.get()).map(_.setFetchFailed(this))
+
+ def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
Utils.exceptionString(this))
}
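
A comment-only sketch of the contract documented above (identifiers are placeholders), since the constructor now records the failure as a side effect:

// Do this: construct and throw in one step, so the TaskContext flag and the thrown
// exception always agree.
//   throw new FetchFailedException(bmAddress, shuffleId, mapId, reduceId, cause)
//
// Not this: the constructor has already called TaskContext.setFetchFailed(), so deciding
// afterwards to swallow the exception still leaves the task marked as a fetch failure.
//   val e = new FetchFailedException(bmAddress, shuffleId, mapId, reduceId, cause)
//   if (shouldIgnore) { /* too late */ }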
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
deleted file mode 100644
index cd253a78c2b1..000000000000
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle
-
-import java.util.concurrent.ConcurrentLinkedQueue
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark.{Logging, SparkConf, SparkEnv}
-import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.netty.SparkTransportConf
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
-
-/** A group of writers for a ShuffleMapTask, one writer per reducer. */
-private[spark] trait ShuffleWriterGroup {
- val writers: Array[DiskBlockObjectWriter]
-
- /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
- def releaseWriters(success: Boolean)
-}
-
-/**
- * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file
- * per reducer.
- */
-// Note: Changes to the format in this file should be kept in sync with
-// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getHashBasedShuffleBlockData().
-private[spark] class FileShuffleBlockResolver(conf: SparkConf)
- extends ShuffleBlockResolver with Logging {
-
- private val transportConf = SparkTransportConf.fromSparkConf(conf)
-
- private lazy val blockManager = SparkEnv.get.blockManager
-
- // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
- private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
-
- /**
- * Contains all the state related to a particular shuffle.
- */
- private class ShuffleState(val numReducers: Int) {
- /**
- * The mapIds of all map tasks completed on this Executor for this shuffle.
- */
- val completedMapTasks = new ConcurrentLinkedQueue[Int]()
- }
-
- private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
-
- private val metadataCleaner =
- new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf)
-
- /**
- * Get a ShuffleWriterGroup for the given map task, which will register it as complete
- * when the writers are closed successfully
- */
- def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer,
- writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = {
- new ShuffleWriterGroup {
- shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers))
- private val shuffleState = shuffleStates(shuffleId)
-
- val openStartTime = System.nanoTime
- val serializerInstance = serializer.newInstance()
- val writers: Array[DiskBlockObjectWriter] = {
- Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId =>
- val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
- val blockFile = blockManager.diskBlockManager.getFile(blockId)
- // Because of previous failures, the shuffle file may already exist on this machine.
- // If so, remove it.
- if (blockFile.exists) {
- if (blockFile.delete()) {
- logInfo(s"Removed existing shuffle file $blockFile")
- } else {
- logWarning(s"Failed to remove existing shuffle file $blockFile")
- }
- }
- blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize,
- writeMetrics)
- }
- }
- // Creating the file to write to and creating a disk writer both involve interacting with
- // the disk, so should be included in the shuffle write time.
- writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
-
- override def releaseWriters(success: Boolean) {
- shuffleState.completedMapTasks.add(mapId)
- }
- }
- }
-
- override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
- val file = blockManager.diskBlockManager.getFile(blockId)
- new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
- }
-
- /** Remove all the blocks / files and metadata related to a particular shuffle. */
- def removeShuffle(shuffleId: ShuffleId): Boolean = {
- // Do not change the ordering of this, if shuffleStates should be removed only
- // after the corresponding shuffle blocks have been removed
- val cleaned = removeShuffleBlocks(shuffleId)
- shuffleStates.remove(shuffleId)
- cleaned
- }
-
- /** Remove all the blocks / files related to a particular shuffle. */
- private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
- shuffleStates.get(shuffleId) match {
- case Some(state) =>
- for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) {
- val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
- val file = blockManager.diskBlockManager.getFile(blockId)
- if (!file.delete()) {
- logWarning(s"Error deleting ${file.getPath()}")
- }
- }
- logInfo("Deleted all files for shuffle " + shuffleId)
- true
- case None =>
- logInfo("Could not find files for shuffle " + shuffleId + " for deleting")
- false
- }
- }
-
- private def cleanup(cleanupTime: Long) {
- shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
- }
-
- override def stop() {
- metadataCleaner.cancel()
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 5e4c2b5d0a5c..449f60273b42 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -18,17 +18,18 @@
package org.apache.spark.shuffle
import java.io._
+import java.nio.channels.Channels
+import java.nio.file.Files
-import com.google.common.io.ByteStreams
-
-import org.apache.spark.{SparkConf, SparkEnv, Logging}
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.io.NioBufferedFileInputStream
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
import org.apache.spark.storage._
import org.apache.spark.util.Utils
-import IndexShuffleBlockResolver.NOOP_REDUCE_ID
-
/**
* Create and maintain the shuffle blocks' mapping between logic block and physical file location.
* Data of shuffle blocks from the same map task are stored in a single consolidated data file.
@@ -40,12 +41,15 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID
*/
// Note: Changes to the format in this file should be kept in sync with
// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
-private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver
+private[spark] class IndexShuffleBlockResolver(
+ conf: SparkConf,
+ _blockManager: BlockManager = null)
+ extends ShuffleBlockResolver
with Logging {
- private lazy val blockManager = SparkEnv.get.blockManager
+ private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)
- private val transportConf = SparkTransportConf.fromSparkConf(conf)
+ private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
def getDataFile(shuffleId: Int, mapId: Int): File = {
blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
@@ -57,7 +61,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
/**
* Remove data file and index file that contain the output data from one map.
- * */
+ */
def removeDataByMap(shuffleId: Int, mapId: Int): Unit = {
var file = getDataFile(shuffleId, mapId)
if (file.exists()) {
@@ -74,24 +78,116 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
}
}
+ /**
+ * Check whether the given index and data files match each other.
+ * If so, return the partition lengths in the data file. Otherwise return null.
+ */
+ private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = {
+ // the index file should have `blocks + 1` longs as offsets.
+ if (index.length() != (blocks + 1) * 8L) {
+ return null
+ }
+ val lengths = new Array[Long](blocks)
+ // Read the lengths of blocks
+ val in = try {
+ new DataInputStream(new NioBufferedFileInputStream(index))
+ } catch {
+ case e: IOException =>
+ return null
+ }
+ try {
+ // Convert the offsets into lengths of each block
+ var offset = in.readLong()
+ if (offset != 0L) {
+ return null
+ }
+ var i = 0
+ while (i < blocks) {
+ val off = in.readLong()
+ lengths(i) = off - offset
+ offset = off
+ i += 1
+ }
+ } catch {
+ case e: IOException =>
+ return null
+ } finally {
+ in.close()
+ }
+
+ // the size of the data file should match the index file
+ if (data.length() == lengths.sum) {
+ lengths
+ } else {
+ null
+ }
+ }
+
/**
* Write an index file with the offsets of each block, plus a final offset at the end for the
* end of the output file. This will be used by getBlockData to figure out where each block
* begins and ends.
- * */
- def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
+ *
+ * It will commit the data and index file as an atomic operation, use the existing ones, or
+ * replace them with new ones.
+ *
+ * Note: the `lengths` will be updated to match the existing index file if we use the existing ones.
+ */
+ def writeIndexFileAndCommit(
+ shuffleId: Int,
+ mapId: Int,
+ lengths: Array[Long],
+ dataTmp: File): Unit = {
val indexFile = getIndexFile(shuffleId, mapId)
- val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
- Utils.tryWithSafeFinally {
- // We take in lengths of each block, need to convert it to offsets.
- var offset = 0L
- out.writeLong(offset)
- for (length <- lengths) {
- offset += length
+ val indexTmp = Utils.tempFileWith(indexFile)
+ try {
+ val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
+ Utils.tryWithSafeFinally {
+ // We take in the lengths of each block and need to convert them to offsets.
+ var offset = 0L
out.writeLong(offset)
+ for (length <- lengths) {
+ offset += length
+ out.writeLong(offset)
+ }
+ } {
+ out.close()
+ }
+
+ val dataFile = getDataFile(shuffleId, mapId)
+ // There is only one IndexShuffleBlockResolver per executor; this synchronization makes sure
+ // the following check and rename are atomic.
+ synchronized {
+ val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
+ if (existingLengths != null) {
+ // Another attempt for the same task has already written our map outputs successfully,
+ // so just use the existing partition lengths and delete our temporary map outputs.
+ System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+ if (dataTmp != null && dataTmp.exists()) {
+ dataTmp.delete()
+ }
+ indexTmp.delete()
+ } else {
+ // This is the first successful attempt in writing the map outputs for this task,
+ // so override any existing index and data files with the ones we wrote.
+ if (indexFile.exists()) {
+ indexFile.delete()
+ }
+ if (dataFile.exists()) {
+ dataFile.delete()
+ }
+ if (!indexTmp.renameTo(indexFile)) {
+ throw new IOException("failed to rename file " + indexTmp + " to " + indexFile)
+ }
+ if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
+ throw new IOException("failed to rename file " + dataTmp + " to " + dataFile)
+ }
+ }
+ }
+ } finally {
+ if (indexTmp.exists() && !indexTmp.delete()) {
+ logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}")
}
- } {
- out.close()
}
}
@@ -100,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
// find out the consolidated file, then the offset within that from our index
val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
- val in = new DataInputStream(new FileInputStream(indexFile))
+ // SPARK-22982: if this FileInputStream's position is moved forward by another piece of code
+ // which is incorrectly using our file descriptor, then this code will fetch the wrong offsets
+ // (which may cause a reducer to be sent a different reducer's data). The explicit position
+ // checks added here were a useful debugging aid during SPARK-22982 and may help prevent this
+ // class of issue from recurring in the future, which is why they are left here even though
+ // SPARK-22982 is fixed.
+ val channel = Files.newByteChannel(indexFile.toPath)
+ channel.position(blockId.reduceId * 8L)
+ val in = new DataInputStream(Channels.newInputStream(channel))
try {
- ByteStreams.skipFully(in, blockId.reduceId * 8)
val offset = in.readLong()
val nextOffset = in.readLong()
+ val actualPosition = channel.position()
+ val expectedPosition = blockId.reduceId * 8L + 16
+ if (actualPosition != expectedPosition) {
+ throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
+ s"expected $expectedPosition but actual position was $actualPosition.")
+ }
new FileSegmentManagedBuffer(
transportConf,
getDataFile(blockId.shuffleId, blockId.mapId),
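
The index file that these two hunks write and read is simply numPartitions + 1 longs: a leading 0, then the cumulative offset after each partition, so partition i occupies [offsets(i), offsets(i + 1)) in the data file. A standalone sketch of reading one partition's extent with plain NIO, independent of the resolver code above (the file handle and reduce id are assumed inputs):

import java.io.{DataInputStream, File}
import java.nio.channels.Channels
import java.nio.file.Files

// Sketch: recover the (offset, length) of one reduce partition's bytes in the data file
// from the companion index file of (numPartitions + 1) big-endian longs.
def readBlockExtent(indexFile: File, reduceId: Int): (Long, Long) = {
  val channel = Files.newByteChannel(indexFile.toPath)
  try {
    channel.position(reduceId * 8L)      // each index entry is one 8-byte long
    val in = new DataInputStream(Channels.newInputStream(channel))
    val offset = in.readLong()           // where this partition starts
    val nextOffset = in.readLong()       // where the next partition starts
    (offset, nextOffset - offset)
  } finally {
    channel.close()
  }
}
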
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
index 4342b0d598b1..d1ecbc1bf017 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
@@ -17,7 +17,6 @@
package org.apache.spark.shuffle
-import java.nio.ByteBuffer
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.ShuffleBlockId
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index 978366d1a1d1..4ea8a7120a9c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -17,7 +17,7 @@
package org.apache.spark.shuffle
-import org.apache.spark.{TaskContext, ShuffleDependency}
+import org.apache.spark.{ShuffleDependency, TaskContext}
/**
* Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the driver
@@ -28,6 +28,7 @@ import org.apache.spark.{TaskContext, ShuffleDependency}
* boolean isDriver as parameters.
*/
private[spark] trait ShuffleManager {
+
/**
* Register a shuffle with the manager and obtain a handle for it to pass to tasks.
*/
@@ -50,9 +51,9 @@ private[spark] trait ShuffleManager {
context: TaskContext): ShuffleReader[K, C]
/**
- * Remove a shuffle's metadata from the ShuffleManager.
- * @return true if the metadata removed successfully, otherwise false.
- */
+ * Remove a shuffle's metadata from the ShuffleManager.
+ * @return true if the metadata removed successfully, otherwise false.
+ */
def unregisterShuffle(shuffleId: Int): Boolean
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
deleted file mode 100644
index d2e2fc4c110a..000000000000
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.hash
-
-import org.apache.spark._
-import org.apache.spark.shuffle._
-
-/**
- * A ShuffleManager using hashing, that creates one output file per reduce partition on each
- * mapper (possibly reusing these across waves of tasks).
- */
-private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
-
- if (!conf.getBoolean("spark.shuffle.spill", true)) {
- logWarning(
- "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." +
- " Shuffle will continue to spill to disk when necessary.")
- }
-
- private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf)
-
- /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
- override def registerShuffle[K, V, C](
- shuffleId: Int,
- numMaps: Int,
- dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
- new BaseShuffleHandle(shuffleId, numMaps, dependency)
- }
-
- /**
- * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
- * Called on executors by reduce tasks.
- */
- override def getReader[K, C](
- handle: ShuffleHandle,
- startPartition: Int,
- endPartition: Int,
- context: TaskContext): ShuffleReader[K, C] = {
- new BlockStoreShuffleReader(
- handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
- }
-
- /** Get a writer for a given partition. Called on executors by map tasks. */
- override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
- : ShuffleWriter[K, V] = {
- new HashShuffleWriter(
- shuffleBlockResolver, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
- }
-
- /** Remove a shuffle's metadata from the ShuffleManager. */
- override def unregisterShuffle(shuffleId: Int): Boolean = {
- shuffleBlockResolver.removeShuffle(shuffleId)
- }
-
- override def shuffleBlockResolver: FileShuffleBlockResolver = {
- fileShuffleBlockResolver
- }
-
- /** Shut down this ShuffleManager. */
- override def stop(): Unit = {
- shuffleBlockResolver.stop()
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
deleted file mode 100644
index 41df70c602c3..000000000000
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.hash
-
-import org.apache.spark._
-import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle._
-import org.apache.spark.storage.DiskBlockObjectWriter
-
-private[spark] class HashShuffleWriter[K, V](
- shuffleBlockResolver: FileShuffleBlockResolver,
- handle: BaseShuffleHandle[K, V, _],
- mapId: Int,
- context: TaskContext)
- extends ShuffleWriter[K, V] with Logging {
-
- private val dep = handle.dependency
- private val numOutputSplits = dep.partitioner.numPartitions
- private val metrics = context.taskMetrics
-
- // Are we in the process of stopping? Because map tasks can call stop() with success = true
- // and then call stop() with success = false if they get an exception, we want to make sure
- // we don't try deleting files, etc twice.
- private var stopping = false
-
- private val writeMetrics = new ShuffleWriteMetrics()
- metrics.shuffleWriteMetrics = Some(writeMetrics)
-
- private val blockManager = SparkEnv.get.blockManager
- private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
- private val shuffle = shuffleBlockResolver.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser,
- writeMetrics)
-
- /** Write a bunch of records to this task's output */
- override def write(records: Iterator[Product2[K, V]]): Unit = {
- val iter = if (dep.aggregator.isDefined) {
- if (dep.mapSideCombine) {
- dep.aggregator.get.combineValuesByKey(records, context)
- } else {
- records
- }
- } else {
- require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
- records
- }
-
- for (elem <- iter) {
- val bucketId = dep.partitioner.getPartition(elem._1)
- shuffle.writers(bucketId).write(elem._1, elem._2)
- }
- }
-
- /** Close this writer, passing along whether the map completed */
- override def stop(initiallySuccess: Boolean): Option[MapStatus] = {
- var success = initiallySuccess
- try {
- if (stopping) {
- return None
- }
- stopping = true
- if (success) {
- try {
- Some(commitWritesAndBuildStatus())
- } catch {
- case e: Exception =>
- success = false
- revertWrites()
- throw e
- }
- } else {
- revertWrites()
- None
- }
- } finally {
- // Release the writers back to the shuffle block manager.
- if (shuffle != null && shuffle.writers != null) {
- try {
- shuffle.releaseWriters(success)
- } catch {
- case e: Exception => logError("Failed to release shuffle writers", e)
- }
- }
- }
- }
-
- private def commitWritesAndBuildStatus(): MapStatus = {
- // Commit the writes. Get the size of each bucket block (total block size).
- val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter =>
- writer.commitAndClose()
- writer.fileSegment().length
- }
- MapStatus(blockManager.shuffleServerId, sizes)
- }
-
- private def revertWrites(): Unit = {
- if (shuffle != null && shuffle.writers != null) {
- for (writer <- shuffle.writers) {
- writer.revertPartialWritesAndClose()
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 66b6bbc61fe8..bfb4dc698e32 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -20,7 +20,7 @@ package org.apache.spark.shuffle.sort
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark._
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.internal.Logging
import org.apache.spark.shuffle._
/**
@@ -82,13 +82,13 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
/**
- * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+ * Obtains a [[ShuffleHandle]] to pass to tasks.
*/
override def registerShuffle[K, V, C](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
- if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
+ if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
// need map-side aggregation, then write numPartitions files directly and just concatenate
// them at the end. This avoids doing serialization and deserialization twice to merge
@@ -184,10 +184,9 @@ private[spark] object SortShuffleManager extends Logging {
def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
val shufId = dependency.shuffleId
val numPartitions = dependency.partitioner.numPartitions
- val serializer = Serializer.getSerializer(dependency.serializer)
- if (!serializer.supportsRelocationOfSerializedObjects) {
+ if (!dependency.serializer.supportsRelocationOfSerializedObjects) {
log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
- s"${serializer.getClass.getName}, does not support object relocation")
+ s"${dependency.serializer.getClass.getName}, does not support object relocation")
false
} else if (dependency.aggregator.isDefined) {
log.debug(
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 808317b017a0..636b88e792bf 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -18,10 +18,11 @@
package org.apache.spark.shuffle.sort
import org.apache.spark._
-import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.ExternalSorter
private[spark] class SortShuffleWriter[K, V, C](
@@ -44,8 +45,7 @@ private[spark] class SortShuffleWriter[K, V, C](
private var mapStatus: MapStatus = null
- private val writeMetrics = new ShuffleWriteMetrics()
- context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)
+ private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
@@ -65,12 +65,18 @@ private[spark] class SortShuffleWriter[K, V, C](
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
- val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
- val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
- val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
- shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
-
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
+ val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+ val tmp = Utils.tempFileWith(output)
+ try {
+ val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
+ val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
+ shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
+ mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
+ } finally {
+ if (tmp.exists() && !tmp.delete()) {
+ logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
+ }
+ }
}
/** Close this writer, passing along whether the map completed */
@@ -83,8 +89,6 @@ private[spark] class SortShuffleWriter[K, V, C](
if (success) {
return Option(mapStatus)
} else {
- // The map task failed, so delete our output data.
- shuffleBlockResolver.removeDataByMap(dep.shuffleId, mapId)
return None
}
} finally {
@@ -92,8 +96,7 @@ private[spark] class SortShuffleWriter[K, V, C](
if (sorter != null) {
val startTime = System.nanoTime()
sorter.stop()
- context.taskMetrics.shuffleWriteMetrics.foreach(
- _.incShuffleWriteTime(System.nanoTime - startTime))
+ writeMetrics.incWriteTime(System.nanoTime - startTime)
sorter = null
}
}
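
The writer change above mirrors the resolver's commit protocol: write the partitioned output under a temporary name, let writeIndexFileAndCommit either rename it into place or discard it in favor of an earlier attempt's output, and delete the temp file in a finally block if it was not consumed. Reduced to a generic helper, the pattern looks roughly like the sketch below (names and error messages are illustrative, not Spark's):

import java.io.{File, IOException}
import java.util.UUID

// Sketch of the write-to-temp-then-commit pattern: produce output under a uniquely
// named sibling file, then rename it into place only after the write succeeded.
def writeAtomically(target: File)(write: File => Unit): Unit = {
  val tmp = new File(target.getAbsolutePath + "." + UUID.randomUUID())
  try {
    write(tmp)
    if (target.exists() && !target.delete()) {
      throw new IOException(s"failed to delete existing file $target")
    }
    if (!tmp.renameTo(target)) {
      throw new IOException(s"failed to rename $tmp to $target")
    }
  } finally {
    // Best-effort cleanup: if the rename consumed the temp file, this is a no-op.
    if (tmp.exists()) tmp.delete()
  }
}
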
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala
new file mode 100644
index 000000000000..01f2a18122e6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala
@@ -0,0 +1,41 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+package org.apache.spark.status.api.v1
+
+import javax.ws.rs.{GET, Produces}
+import javax.ws.rs.core.MediaType
+
+import org.apache.spark.ui.SparkUI
+import org.apache.spark.ui.exec.ExecutorsPage
+
+@Produces(Array(MediaType.APPLICATION_JSON))
+private[v1] class AllExecutorListResource(ui: SparkUI) {
+
+ @GET
+ def executorList(): Seq[ExecutorSummary] = {
+ val listener = ui.executorsListener
+ listener.synchronized {
+ // The following code should be protected by `listener` to make sure no executors will be
+ // removed before we query their status. See SPARK-12784.
+ (0 until listener.activeStorageStatusList.size).map { statusId =>
+ ExecutorsPage.getExecInfo(listener, statusId, isActive = true)
+ } ++ (0 until listener.deadStorageStatusList.size).map { statusId =>
+ ExecutorsPage.getExecInfo(listener, statusId, isActive = false)
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala
index 5783df5d8220..d0d9ef1165e8 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala
@@ -68,7 +68,12 @@ private[v1] object AllJobsResource {
listener: JobProgressListener,
includeStageDetails: Boolean): JobData = {
listener.synchronized {
- val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max)
+ val lastStageInfo =
+ if (job.stageIds.isEmpty) {
+ None
+ } else {
+ listener.stageIdToInfo.get(job.stageIds.max)
+ }
val lastStageData = lastStageInfo.flatMap { s =>
listener.stageIdToData.get((s.stageId, s.attemptId))
}
@@ -86,7 +91,7 @@ private[v1] object AllJobsResource {
numTasks = job.numTasks,
numActiveTasks = job.numActiveTasks,
numCompletedTasks = job.numCompletedTasks,
- numSkippedTasks = job.numCompletedTasks,
+ numSkippedTasks = job.numSkippedTasks,
numFailedTasks = job.numFailedTasks,
numActiveStages = job.numActiveStages,
numCompletedStages = job.completedStageIndices.size,
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala
index 645ede26a087..1279b281ad8d 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala
@@ -28,7 +28,7 @@ private[v1] class AllRDDResource(ui: SparkUI) {
@GET
def rddList(): Seq[RDDStorageInfo] = {
- val storageStatusList = ui.storageListener.storageStatusList
+ val storageStatusList = ui.storageListener.activeStorageStatusList
val rddInfos = ui.storageListener.rddInfoList
rddInfos.map{rddInfo =>
AllRDDResource.getRDDStorageInfo(rddInfo.id, rddInfo, storageStatusList,
@@ -44,7 +44,7 @@ private[spark] object AllRDDResource {
rddId: Int,
listener: StorageListener,
includeDetails: Boolean): Option[RDDStorageInfo] = {
- val storageStatusList = listener.storageStatusList
+ val storageStatusList = listener.activeStorageStatusList
listener.rddInfoList.find { _.id == rddId }.map { rddInfo =>
getRDDStorageInfo(rddId, rddInfo, storageStatusList, includeDetails)
}
@@ -61,7 +61,7 @@ private[spark] object AllRDDResource {
.flatMap { _.rddBlocksById(rddId) }
.sortWith { _._1.name < _._1.name }
.map { case (blockId, status) =>
- (blockId, status, blockLocations.get(blockId).getOrElse(Seq[String]("Unknown")))
+ (blockId, status, blockLocations.getOrElse(blockId, Seq[String]("Unknown")))
}
val dataDistribution = if (includeDetails) {
@@ -70,7 +70,13 @@ private[spark] object AllRDDResource {
address = status.blockManagerId.hostPort,
memoryUsed = status.memUsedByRdd(rddId),
memoryRemaining = status.memRemaining,
- diskUsed = status.diskUsedByRdd(rddId)
+ diskUsed = status.diskUsedByRdd(rddId),
+ onHeapMemoryUsed = Some(
+ if (!rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L),
+ offHeapMemoryUsed = Some(
+ if (rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L),
+ onHeapMemoryRemaining = status.onHeapMemRemaining,
+ offHeapMemoryRemaining = status.offHeapMemRemaining
) } )
} else {
None
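
The data-distribution fields added here attribute an RDD's memory to exactly one side of the heap: if the RDD's storage level is off-heap, its used bytes are reported as off-heap, otherwise as on-heap. A one-line sketch of that split, with the inputs as plain values:

// Sketch: split an RDD's memory usage into (onHeapUsed, offHeapUsed) by storage level.
def splitMemoryUsage(usedBytes: Long, useOffHeap: Boolean): (Long, Long) =
  if (useOffHeap) (0L, usedBytes) else (usedBytes, 0L)
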
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
index 24a0b5220695..1818935392eb 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
@@ -17,13 +17,13 @@
package org.apache.spark.status.api.v1
import java.util.{Arrays, Date, List => JList}
-import javax.ws.rs.{GET, PathParam, Produces, QueryParam}
+import javax.ws.rs.{GET, Produces, QueryParam}
import javax.ws.rs.core.MediaType
-import org.apache.spark.executor.{InputMetrics => InternalInputMetrics, OutputMetrics => InternalOutputMetrics, ShuffleReadMetrics => InternalShuffleReadMetrics, ShuffleWriteMetrics => InternalShuffleWriteMetrics, TaskMetrics => InternalTaskMetrics}
import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo}
import org.apache.spark.ui.SparkUI
import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData}
+import org.apache.spark.ui.jobs.UIData.{InputMetricsUIData => InternalInputMetrics, OutputMetricsUIData => InternalOutputMetrics, ShuffleReadMetricsUIData => InternalShuffleReadMetrics, ShuffleWriteMetricsUIData => InternalShuffleWriteMetrics, TaskMetricsUIData => InternalTaskMetrics}
import org.apache.spark.util.Distribution
@Produces(Array(MediaType.APPLICATION_JSON))
@@ -59,6 +59,15 @@ private[v1] object AllStagesResource {
stageUiData: StageUIData,
includeDetails: Boolean): StageData = {
+ val taskLaunchTimes = stageUiData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0)
+
+ val firstTaskLaunchedTime: Option[Date] =
+ if (taskLaunchTimes.nonEmpty) {
+ Some(new Date(taskLaunchTimes.min))
+ } else {
+ None
+ }
+
val taskData = if (includeDetails) {
Some(stageUiData.taskData.map { case (k, v) => k -> convertTaskData(v) } )
} else {
@@ -92,6 +101,10 @@ private[v1] object AllStagesResource {
numCompleteTasks = stageUiData.numCompleteTasks,
numFailedTasks = stageUiData.numFailedTasks,
executorRunTime = stageUiData.executorRunTime,
+ executorCpuTime = stageUiData.executorCpuTime,
+ submissionTime = stageInfo.submissionTime.map(new Date(_)),
+ firstTaskLaunchedTime,
+ completionTime = stageInfo.completionTime.map(new Date(_)),
inputBytes = stageUiData.inputBytes,
inputRecords = stageUiData.inputRecords,
outputBytes = stageUiData.outputBytes,
@@ -129,13 +142,15 @@ private[v1] object AllStagesResource {
index = uiData.taskInfo.index,
attempt = uiData.taskInfo.attemptNumber,
launchTime = new Date(uiData.taskInfo.launchTime),
+ duration = uiData.taskDuration,
executorId = uiData.taskInfo.executorId,
host = uiData.taskInfo.host,
+ status = uiData.taskInfo.status,
taskLocality = uiData.taskInfo.taskLocality.toString(),
speculative = uiData.taskInfo.speculative,
accumulatorUpdates = uiData.taskInfo.accumulables.map { convertAccumulableInfo },
errorMessage = uiData.errorMessage,
- taskMetrics = uiData.taskMetrics.map { convertUiTaskMetrics }
+ taskMetrics = uiData.metrics.map { convertUiTaskMetrics }
)
}
@@ -143,7 +158,7 @@ private[v1] object AllStagesResource {
allTaskData: Iterable[TaskUIData],
quantiles: Array[Double]): TaskMetricDistributions = {
- val rawMetrics = allTaskData.flatMap{_.taskMetrics}.toSeq
+ val rawMetrics = allTaskData.flatMap{_.metrics}.toSeq
def metricQuantiles(f: InternalTaskMetrics => Double): IndexedSeq[Double] =
Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles)
@@ -155,35 +170,32 @@ private[v1] object AllStagesResource {
// to make it a little easier to deal w/ all of the nested options. Mostly it lets us just
// implement one "build" method, which just builds the quantiles for each field.
- val inputMetrics: Option[InputMetricDistributions] =
+ val inputMetrics: InputMetricDistributions =
new MetricHelper[InternalInputMetrics, InputMetricDistributions](rawMetrics, quantiles) {
- def getSubmetrics(raw: InternalTaskMetrics): Option[InternalInputMetrics] = {
- raw.inputMetrics
- }
+ def getSubmetrics(raw: InternalTaskMetrics): InternalInputMetrics = raw.inputMetrics
def build: InputMetricDistributions = new InputMetricDistributions(
bytesRead = submetricQuantiles(_.bytesRead),
recordsRead = submetricQuantiles(_.recordsRead)
)
- }.metricOption
+ }.build
- val outputMetrics: Option[OutputMetricDistributions] =
+ val outputMetrics: OutputMetricDistributions =
new MetricHelper[InternalOutputMetrics, OutputMetricDistributions](rawMetrics, quantiles) {
- def getSubmetrics(raw: InternalTaskMetrics): Option[InternalOutputMetrics] = {
- raw.outputMetrics
- }
+ def getSubmetrics(raw: InternalTaskMetrics): InternalOutputMetrics = raw.outputMetrics
+
def build: OutputMetricDistributions = new OutputMetricDistributions(
bytesWritten = submetricQuantiles(_.bytesWritten),
recordsWritten = submetricQuantiles(_.recordsWritten)
)
- }.metricOption
+ }.build
- val shuffleReadMetrics: Option[ShuffleReadMetricDistributions] =
+ val shuffleReadMetrics: ShuffleReadMetricDistributions =
new MetricHelper[InternalShuffleReadMetrics, ShuffleReadMetricDistributions](rawMetrics,
quantiles) {
- def getSubmetrics(raw: InternalTaskMetrics): Option[InternalShuffleReadMetrics] = {
+ def getSubmetrics(raw: InternalTaskMetrics): InternalShuffleReadMetrics =
raw.shuffleReadMetrics
- }
+
def build: ShuffleReadMetricDistributions = new ShuffleReadMetricDistributions(
readBytes = submetricQuantiles(_.totalBytesRead),
readRecords = submetricQuantiles(_.recordsRead),
@@ -193,25 +205,27 @@ private[v1] object AllStagesResource {
totalBlocksFetched = submetricQuantiles(_.totalBlocksFetched),
fetchWaitTime = submetricQuantiles(_.fetchWaitTime)
)
- }.metricOption
+ }.build
- val shuffleWriteMetrics: Option[ShuffleWriteMetricDistributions] =
+ val shuffleWriteMetrics: ShuffleWriteMetricDistributions =
new MetricHelper[InternalShuffleWriteMetrics, ShuffleWriteMetricDistributions](rawMetrics,
quantiles) {
- def getSubmetrics(raw: InternalTaskMetrics): Option[InternalShuffleWriteMetrics] = {
+ def getSubmetrics(raw: InternalTaskMetrics): InternalShuffleWriteMetrics =
raw.shuffleWriteMetrics
- }
+
def build: ShuffleWriteMetricDistributions = new ShuffleWriteMetricDistributions(
- writeBytes = submetricQuantiles(_.shuffleBytesWritten),
- writeRecords = submetricQuantiles(_.shuffleRecordsWritten),
- writeTime = submetricQuantiles(_.shuffleWriteTime)
+ writeBytes = submetricQuantiles(_.bytesWritten),
+ writeRecords = submetricQuantiles(_.recordsWritten),
+ writeTime = submetricQuantiles(_.writeTime)
)
- }.metricOption
+ }.build
new TaskMetricDistributions(
quantiles = quantiles,
executorDeserializeTime = metricQuantiles(_.executorDeserializeTime),
+ executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime),
executorRunTime = metricQuantiles(_.executorRunTime),
+ executorCpuTime = metricQuantiles(_.executorCpuTime),
resultSize = metricQuantiles(_.resultSize),
jvmGcTime = metricQuantiles(_.jvmGCTime),
resultSerializationTime = metricQuantiles(_.resultSerializationTime),
@@ -225,22 +239,25 @@ private[v1] object AllStagesResource {
}
def convertAccumulableInfo(acc: InternalAccumulableInfo): AccumulableInfo = {
- new AccumulableInfo(acc.id, acc.name, acc.update, acc.value)
+ new AccumulableInfo(
+ acc.id, acc.name.orNull, acc.update.map(_.toString), acc.value.map(_.toString).orNull)
}
def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = {
new TaskMetrics(
executorDeserializeTime = internal.executorDeserializeTime,
+ executorDeserializeCpuTime = internal.executorDeserializeCpuTime,
executorRunTime = internal.executorRunTime,
+ executorCpuTime = internal.executorCpuTime,
resultSize = internal.resultSize,
jvmGcTime = internal.jvmGCTime,
resultSerializationTime = internal.resultSerializationTime,
memoryBytesSpilled = internal.memoryBytesSpilled,
diskBytesSpilled = internal.diskBytesSpilled,
- inputMetrics = internal.inputMetrics.map { convertInputMetrics },
- outputMetrics = Option(internal.outputMetrics).flatten.map { convertOutputMetrics },
- shuffleReadMetrics = internal.shuffleReadMetrics.map { convertShuffleReadMetrics },
- shuffleWriteMetrics = internal.shuffleWriteMetrics.map { convertShuffleWriteMetrics }
+ inputMetrics = convertInputMetrics(internal.inputMetrics),
+ outputMetrics = convertOutputMetrics(internal.outputMetrics),
+ shuffleReadMetrics = convertShuffleReadMetrics(internal.shuffleReadMetrics),
+ shuffleWriteMetrics = convertShuffleWriteMetrics(internal.shuffleWriteMetrics)
)
}
@@ -264,46 +281,35 @@ private[v1] object AllStagesResource {
localBlocksFetched = internal.localBlocksFetched,
fetchWaitTime = internal.fetchWaitTime,
remoteBytesRead = internal.remoteBytesRead,
- totalBlocksFetched = internal.totalBlocksFetched,
+ localBytesRead = internal.localBytesRead,
recordsRead = internal.recordsRead
)
}
def convertShuffleWriteMetrics(internal: InternalShuffleWriteMetrics): ShuffleWriteMetrics = {
new ShuffleWriteMetrics(
- bytesWritten = internal.shuffleBytesWritten,
- writeTime = internal.shuffleWriteTime,
- recordsWritten = internal.shuffleRecordsWritten
+ bytesWritten = internal.bytesWritten,
+ writeTime = internal.writeTime,
+ recordsWritten = internal.recordsWritten
)
}
}
/**
- * Helper for getting distributions from nested metric types. Many of the metrics we want are
- * contained in options inside TaskMetrics (eg., ShuffleWriteMetrics). This makes it easy to handle
- * the options (returning None if the metrics are all empty), and extract the quantiles for each
- * metric. After creating an instance, call metricOption to get the result type.
+ * Helper for getting distributions from nested metric types.
*/
private[v1] abstract class MetricHelper[I, O](
rawMetrics: Seq[InternalTaskMetrics],
quantiles: Array[Double]) {
- def getSubmetrics(raw: InternalTaskMetrics): Option[I]
+ def getSubmetrics(raw: InternalTaskMetrics): I
def build: O
- val data: Seq[I] = rawMetrics.flatMap(getSubmetrics)
+ val data: Seq[I] = rawMetrics.map(getSubmetrics)
/** applies the given function to all input metrics, and returns the quantiles */
def submetricQuantiles(f: I => Double): IndexedSeq[Double] = {
Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles)
}
-
- def metricOption: Option[O] = {
- if (data.isEmpty) {
- None
- } else {
- Some(build)
- }
- }
}
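
With the UI-side metrics now non-optional, every MetricHelper unconditionally builds its distribution, and each field is summarized by taking quantiles of that field across all tasks. The idea behind submetricQuantiles is a simple index-based quantile over the collected values; a self-contained sketch (not Spark's Distribution class) looks like this:

// Sketch: nearest-rank quantiles over one metric field collected from all tasks.
def quantiles(values: Seq[Double], probabilities: Seq[Double]): IndexedSeq[Double] = {
  require(values.nonEmpty, "need at least one task's metrics")
  val sorted = values.sorted.toIndexedSeq
  probabilities.map { p =>
    // Clamp the rank so p = 1.0 maps to the last element rather than running off the end.
    val idx = math.min(sorted.length - 1, math.max(0, (p * sorted.length).toInt))
    sorted(idx)
  }.toIndexedSeq
}

// Hypothetical usage: quantiles(runTimes, Seq(0.0, 0.25, 0.5, 0.75, 1.0))
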
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
index 50b6ba67e993..f17b63775482 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
@@ -18,13 +18,14 @@ package org.apache.spark.status.api.v1
import java.util.zip.ZipOutputStream
import javax.servlet.ServletContext
+import javax.servlet.http.HttpServletRequest
import javax.ws.rs._
import javax.ws.rs.core.{Context, Response}
-import com.sun.jersey.api.core.ResourceConfig
-import com.sun.jersey.spi.container.servlet.ServletContainer
import org.eclipse.jetty.server.handler.ContextHandler
import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
+import org.glassfish.jersey.server.ServerProperties
+import org.glassfish.jersey.servlet.ServletContainer
import org.apache.spark.SecurityManager
import org.apache.spark.ui.SparkUI
@@ -40,7 +41,7 @@ import org.apache.spark.ui.SparkUI
* HistoryServerSuite.
*/
@Path("/v1")
-private[v1] class ApiRootResource extends UIRootFromServletContext {
+private[v1] class ApiRootResource extends ApiRequestContext {
@Path("applications")
def getApplicationList(): ApplicationListResource = {
@@ -56,21 +57,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getJobs(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllJobsResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new AllJobsResource(ui)
}
}
@Path("applications/{appId}/jobs")
def getJobs(@PathParam("appId") appId: String): AllJobsResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new AllJobsResource(ui)
}
}
@Path("applications/{appId}/jobs/{jobId: \\d+}")
def getJob(@PathParam("appId") appId: String): OneJobResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new OneJobResource(ui)
}
}
@@ -79,31 +80,46 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getJob(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneJobResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new OneJobResource(ui)
}
}
@Path("applications/{appId}/executors")
def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new ExecutorListResource(ui)
}
}
+ @Path("applications/{appId}/allexecutors")
+ def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = {
+ withSparkUI(appId, None) { ui =>
+ new AllExecutorListResource(ui)
+ }
+ }
+
@Path("applications/{appId}/{attemptId}/executors")
def getExecutors(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): ExecutorListResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new ExecutorListResource(ui)
}
}
+ @Path("applications/{appId}/{attemptId}/allexecutors")
+ def getAllExecutors(
+ @PathParam("appId") appId: String,
+ @PathParam("attemptId") attemptId: String): AllExecutorListResource = {
+ withSparkUI(appId, Some(attemptId)) { ui =>
+ new AllExecutorListResource(ui)
+ }
+ }
@Path("applications/{appId}/stages")
def getStages(@PathParam("appId") appId: String): AllStagesResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new AllStagesResource(ui)
}
}
@@ -112,14 +128,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getStages(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllStagesResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new AllStagesResource(ui)
}
}
@Path("applications/{appId}/stages/{stageId: \\d+}")
def getStage(@PathParam("appId") appId: String): OneStageResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new OneStageResource(ui)
}
}
@@ -128,14 +144,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getStage(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneStageResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new OneStageResource(ui)
}
}
@Path("applications/{appId}/storage/rdd")
def getRdds(@PathParam("appId") appId: String): AllRDDResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new AllRDDResource(ui)
}
}
@@ -144,14 +160,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getRdds(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllRDDResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new AllRDDResource(ui)
}
}
@Path("applications/{appId}/storage/rdd/{rddId: \\d+}")
def getRdd(@PathParam("appId") appId: String): OneRDDResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new OneRDDResource(ui)
}
}
@@ -160,7 +176,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getRdd(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneRDDResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new OneRDDResource(ui)
}
}
@@ -168,14 +184,48 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
@Path("applications/{appId}/logs")
def getEventLogs(
@PathParam("appId") appId: String): EventLogDownloadResource = {
- new EventLogDownloadResource(uiRoot, appId, None)
+ try {
+ // withSparkUI will throw NotFoundException if an attemptId exists for this application,
+ // because there is then no UI keyed by the appId alone. In that case, retry with attempt id "1".
+ withSparkUI(appId, None) { _ =>
+ new EventLogDownloadResource(uiRoot, appId, None)
+ }
+ } catch {
+ case _: NotFoundException =>
+ withSparkUI(appId, Some("1")) { _ =>
+ new EventLogDownloadResource(uiRoot, appId, None)
+ }
+ }
}
@Path("applications/{appId}/{attemptId}/logs")
def getEventLogs(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): EventLogDownloadResource = {
- new EventLogDownloadResource(uiRoot, appId, Some(attemptId))
+ withSparkUI(appId, Some(attemptId)) { _ =>
+ new EventLogDownloadResource(uiRoot, appId, Some(attemptId))
+ }
+ }
+
+ @Path("version")
+ def getVersion(): VersionResource = {
+ new VersionResource(uiRoot)
+ }
+
+ @Path("applications/{appId}/environment")
+ def getEnvironment(@PathParam("appId") appId: String): ApplicationEnvironmentResource = {
+ withSparkUI(appId, None) { ui =>
+ new ApplicationEnvironmentResource(ui)
+ }
+ }
+
+ @Path("applications/{appId}/{attemptId}/environment")
+ def getEnvironment(
+ @PathParam("appId") appId: String,
+ @PathParam("attemptId") attemptId: String): ApplicationEnvironmentResource = {
+ withSparkUI(appId, Some(attemptId)) { ui =>
+ new ApplicationEnvironmentResource(ui)
+ }
}
}
@@ -185,12 +235,7 @@ private[spark] object ApiRootResource {
val jerseyContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS)
jerseyContext.setContextPath("/api")
val holder: ServletHolder = new ServletHolder(classOf[ServletContainer])
- holder.setInitParameter("com.sun.jersey.config.property.resourceConfigClass",
- "com.sun.jersey.api.core.PackagesResourceConfig")
- holder.setInitParameter("com.sun.jersey.config.property.packages",
- "org.apache.spark.status.api.v1")
- holder.setInitParameter(ResourceConfig.PROPERTY_CONTAINER_REQUEST_FILTERS,
- classOf[SecurityFilter].getCanonicalName)
+ holder.setInitParameter(ServerProperties.PROVIDER_PACKAGES, "org.apache.spark.status.api.v1")
UIRootFromServletContext.setUiRoot(jerseyContext, uiRoot)
jerseyContext.addServlet(holder, "/*")
jerseyContext
@@ -199,12 +244,13 @@ private[spark] object ApiRootResource {
/**
* This trait is shared by the all the root containers for application UI information --
- * the HistoryServer, the Master UI, and the application UI. This provides the common
+ * the HistoryServer and the application UI. This provides the common
* interface needed for them all to expose application info as json.
*/
private[spark] trait UIRoot {
def getSparkUI(appKey: String): Option[SparkUI]
def getApplicationInfoList: Iterator[ApplicationInfo]
+ def getApplicationInfo(appId: String): Option[ApplicationInfo]
/**
* Write the event logs for the given app to the [[ZipOutputStream]] instance. If attemptId is
@@ -216,19 +262,6 @@ private[spark] trait UIRoot {
.status(Response.Status.SERVICE_UNAVAILABLE)
.build()
}
-
- /**
- * Get the spark UI with the given appID, and apply a function
- * to it. If there is no such app, throw an appropriate exception
- */
- def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = {
- val appKey = attemptId.map(appId + "/" + _).getOrElse(appId)
- getSparkUI(appKey) match {
- case Some(ui) =>
- f(ui)
- case None => throw new NotFoundException("no such app: " + appId)
- }
- }
def securityManager: SecurityManager
}
@@ -245,13 +278,37 @@ private[v1] object UIRootFromServletContext {
}
}
-private[v1] trait UIRootFromServletContext {
+private[v1] trait ApiRequestContext {
@Context
- var servletContext: ServletContext = _
+ protected var servletContext: ServletContext = _
+
+ @Context
+ protected var httpRequest: HttpServletRequest = _
def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext)
+
+
+ /**
+ * Get the Spark UI with the given appId, and apply a function
+ * to it. If there is no such app, throw an appropriate exception.
+ */
+ def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = {
+ val appKey = attemptId.map(appId + "/" + _).getOrElse(appId)
+ uiRoot.getSparkUI(appKey) match {
+ case Some(ui) =>
+ val user = httpRequest.getRemoteUser()
+ if (!ui.securityManager.checkUIViewPermissions(user)) {
+ throw new ForbiddenException(raw"""user "$user" is not authorized""")
+ }
+ f(ui)
+ case None => throw new NotFoundException("no such app: " + appId)
+ }
+ }
}
+private[v1] class ForbiddenException(msg: String) extends WebApplicationException(
+ Response.status(Response.Status.FORBIDDEN).entity(msg).build())
+
private[v1] class NotFoundException(msg: String) extends WebApplicationException(
new NoSuchElementException(msg),
Response
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala
new file mode 100644
index 000000000000..739a8aceae86
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.status.api.v1
+
+import javax.ws.rs._
+import javax.ws.rs.core.MediaType
+
+import org.apache.spark.ui.SparkUI
+
+@Produces(Array(MediaType.APPLICATION_JSON))
+private[v1] class ApplicationEnvironmentResource(ui: SparkUI) {
+
+ @GET
+ def getEnvironmentInfo(): ApplicationEnvironmentInfo = {
+ val listener = ui.environmentListener
+ listener.synchronized {
+ val jvmInfo = Map(listener.jvmInformation: _*)
+ val runtime = new RuntimeInfo(
+ jvmInfo("Java Version"),
+ jvmInfo("Java Home"),
+ jvmInfo("Scala Version"))
+
+ new ApplicationEnvironmentInfo(
+ runtime,
+ listener.sparkProperties,
+ listener.systemProperties,
+ listener.classpathEntries)
+ }
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala
index 17b521f3e1d4..a0239266d875 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala
@@ -16,12 +16,11 @@
*/
package org.apache.spark.status.api.v1
-import java.util.{Arrays, Date, List => JList}
+import java.util.{Date, List => JList}
import javax.ws.rs.{DefaultValue, GET, Produces, QueryParam}
import javax.ws.rs.core.MediaType
import org.apache.spark.deploy.history.ApplicationHistoryInfo
-import org.apache.spark.deploy.master.{ApplicationInfo => InternalApplicationInfo}
@Produces(Array(MediaType.APPLICATION_JSON))
private[v1] class ApplicationListResource(uiRoot: UIRoot) {
@@ -30,30 +29,42 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) {
def appList(
@QueryParam("status") status: JList[ApplicationStatus],
@DefaultValue("2010-01-01") @QueryParam("minDate") minDate: SimpleDateParam,
- @DefaultValue("3000-01-01") @QueryParam("maxDate") maxDate: SimpleDateParam)
+ @DefaultValue("3000-01-01") @QueryParam("maxDate") maxDate: SimpleDateParam,
+ @DefaultValue("2010-01-01") @QueryParam("minEndDate") minEndDate: SimpleDateParam,
+ @DefaultValue("3000-01-01") @QueryParam("maxEndDate") maxEndDate: SimpleDateParam,
+ @QueryParam("limit") limit: Integer)
: Iterator[ApplicationInfo] = {
- val allApps = uiRoot.getApplicationInfoList
- val adjStatus = {
- if (status.isEmpty) {
- Arrays.asList(ApplicationStatus.values(): _*)
- } else {
- status
- }
- }
- val includeCompleted = adjStatus.contains(ApplicationStatus.COMPLETED)
- val includeRunning = adjStatus.contains(ApplicationStatus.RUNNING)
- allApps.filter { app =>
+
+ val numApps = Option(limit).map(_.toInt).getOrElse(Integer.MAX_VALUE)
+ val includeCompleted = status.isEmpty || status.contains(ApplicationStatus.COMPLETED)
+ val includeRunning = status.isEmpty || status.contains(ApplicationStatus.RUNNING)
+
+ uiRoot.getApplicationInfoList.filter { app =>
val anyRunning = app.attempts.exists(!_.completed)
- // if any attempt is still running, we consider the app to also still be running
- val statusOk = (!anyRunning && includeCompleted) ||
- (anyRunning && includeRunning)
+ // if any attempt is still running, we consider the app to also still be running;
// keep the app if *any* attempts fall in the right time window
- val dateOk = app.attempts.exists { attempt =>
- attempt.startTime.getTime >= minDate.timestamp &&
- attempt.startTime.getTime <= maxDate.timestamp
+ ((!anyRunning && includeCompleted) || (anyRunning && includeRunning)) &&
+ app.attempts.exists { attempt =>
+ isAttemptInRange(attempt, minDate, maxDate, minEndDate, maxEndDate, anyRunning)
}
- statusOk && dateOk
- }
+ }.take(numApps)
+ }
+
+ private def isAttemptInRange(
+ attempt: ApplicationAttemptInfo,
+ minStartDate: SimpleDateParam,
+ maxStartDate: SimpleDateParam,
+ minEndDate: SimpleDateParam,
+ maxEndDate: SimpleDateParam,
+ anyRunning: Boolean): Boolean = {
+ val startTimeOk = attempt.startTime.getTime >= minStartDate.timestamp &&
+ attempt.startTime.getTime <= maxStartDate.timestamp
+ // If the maxEndDate is in the past, exclude all running apps.
+ val endTimeOkForRunning = anyRunning && (maxEndDate.timestamp > System.currentTimeMillis())
+ val endTimeOkForCompleted = !anyRunning && (attempt.endTime.getTime >= minEndDate.timestamp &&
+ attempt.endTime.getTime <= maxEndDate.timestamp)
+ val endTimeOk = endTimeOkForRunning || endTimeOkForCompleted
+ startTimeOk && endTimeOk
}
}
@@ -62,33 +73,26 @@ private[spark] object ApplicationsListResource {
new ApplicationInfo(
id = app.id,
name = app.name,
+ coresGranted = None,
+ maxCores = None,
+ coresPerExecutor = None,
+ memoryPerExecutorMB = None,
attempts = app.attempts.map { internalAttemptInfo =>
new ApplicationAttemptInfo(
attemptId = internalAttemptInfo.attemptId,
startTime = new Date(internalAttemptInfo.startTime),
endTime = new Date(internalAttemptInfo.endTime),
+ duration =
+ if (internalAttemptInfo.endTime > 0) {
+ internalAttemptInfo.endTime - internalAttemptInfo.startTime
+ } else {
+ 0
+ },
+ lastUpdated = new Date(internalAttemptInfo.lastUpdated),
sparkUser = internalAttemptInfo.sparkUser,
completed = internalAttemptInfo.completed
)
}
)
}
-
- def convertApplicationInfo(
- internal: InternalApplicationInfo,
- completed: Boolean): ApplicationInfo = {
- // standalone application info always has just one attempt
- new ApplicationInfo(
- id = internal.id,
- name = internal.desc.name,
- attempts = Seq(new ApplicationAttemptInfo(
- attemptId = None,
- startTime = new Date(internal.startTime),
- endTime = new Date(internal.endTime),
- sparkUser = internal.desc.user,
- completed = completed
- ))
- )
- }
-
}
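
The rewritten filter keeps an application when any attempt passes both the status check and the start/end-time windows, then truncates the result with take(numApps). Below is a self-contained sketch of the end-time rule only, using plain case classes instead of Spark's ApplicationAttemptInfo and SimpleDateParam; names and values are illustrative.

object AttemptRangeCheck {
  // Simplified stand-ins for the real ApplicationAttemptInfo / SimpleDateParam.
  case class Attempt(startTime: Long, endTime: Long, completed: Boolean)

  def inRange(a: Attempt, minStart: Long, maxStart: Long,
              minEnd: Long, maxEnd: Long, now: Long): Boolean = {
    val startOk = a.startTime >= minStart && a.startTime <= maxStart
    // Running attempts are only kept while maxEnd is still in the future;
    // completed attempts must have finished inside [minEnd, maxEnd].
    val endOk =
      if (!a.completed) maxEnd > now
      else a.endTime >= minEnd && a.endTime <= maxEnd
    startOk && endOk
  }

  def main(args: Array[String]): Unit = {
    val now = 1000L
    println(inRange(Attempt(100, -1, completed = false), 0, 500, 0, 900, now)) // false
    println(inRange(Attempt(100, 800, completed = true), 0, 500, 0, 900, now)) // true
  }
}
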
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala
index 22e21f0c62a2..c84022ddfeef 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala
@@ -23,8 +23,9 @@ import javax.ws.rs.core.{MediaType, Response, StreamingOutput}
import scala.util.control.NonFatal
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
@Produces(Array(MediaType.APPLICATION_OCTET_STREAM))
private[v1] class EventLogDownloadResource(
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala
index 8ad4656b4dad..ab5388159418 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.status.api.v1
-import javax.ws.rs.{GET, PathParam, Produces}
+import javax.ws.rs.{GET, Produces}
import javax.ws.rs.core.MediaType
import org.apache.spark.ui.SparkUI
@@ -28,9 +28,13 @@ private[v1] class ExecutorListResource(ui: SparkUI) {
@GET
def executorList(): Seq[ExecutorSummary] = {
val listener = ui.executorsListener
- val storageStatusList = listener.storageStatusList
- (0 until storageStatusList.size).map { statusId =>
- ExecutorsPage.getExecInfo(listener, statusId)
+ listener.synchronized {
+ // The following code should be protected by `listener` to make sure no executors are
+ // removed before we query their status. See SPARK-12784.
+ val storageStatusList = listener.activeStorageStatusList
+ (0 until storageStatusList.size).map { statusId =>
+ ExecutorsPage.getExecInfo(listener, statusId, isActive = true)
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala
index 202a5191ad57..76af33c1a18d 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala
@@ -19,8 +19,9 @@ package org.apache.spark.status.api.v1
import java.io.OutputStream
import java.lang.annotation.Annotation
import java.lang.reflect.Type
+import java.nio.charset.StandardCharsets
import java.text.SimpleDateFormat
-import java.util.{Calendar, SimpleTimeZone}
+import java.util.{Calendar, Locale, SimpleTimeZone}
import javax.ws.rs.Produces
import javax.ws.rs.core.{MediaType, MultivaluedMap}
import javax.ws.rs.ext.{MessageBodyWriter, Provider}
@@ -68,7 +69,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{
multivaluedMap: MultivaluedMap[String, AnyRef],
outputStream: OutputStream): Unit = {
t match {
- case ErrorWrapper(err) => outputStream.write(err.getBytes("utf-8"))
+ case ErrorWrapper(err) => outputStream.write(err.getBytes(StandardCharsets.UTF_8))
case _ => mapper.writeValue(outputStream, t)
}
}
@@ -85,7 +86,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{
private[spark] object JacksonMessageWriter {
def makeISODateFormat: SimpleDateFormat = {
- val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'")
+ val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US)
val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT"))
iso8601.setCalendar(cal)
iso8601
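
makeISODateFormat pins the formatter to Locale.US and a zero-offset calendar, so rendered timestamps are locale-independent and always carry the literal GMT suffix. A quick standalone illustration of the output shape:

import java.text.SimpleDateFormat
import java.util.{Calendar, Date, Locale, SimpleTimeZone}

object IsoDateFormatExample {
  def main(args: Array[String]): Unit = {
    val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US)
    iso8601.setCalendar(Calendar.getInstance(new SimpleTimeZone(0, "GMT")))
    // Prints 1970-01-01T00:00:00.000GMT for the epoch.
    println(iso8601.format(new Date(0L)))
  }
}
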
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala
index b5ef72649e29..18c3e2f40736 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala
@@ -16,15 +16,15 @@
*/
package org.apache.spark.status.api.v1
+import javax.ws.rs.{GET, PathParam, Produces}
import javax.ws.rs.core.MediaType
-import javax.ws.rs.{Produces, PathParam, GET}
@Produces(Array(MediaType.APPLICATION_JSON))
private[v1] class OneApplicationResource(uiRoot: UIRoot) {
@GET
def getApp(@PathParam("appId") appId: String): ApplicationInfo = {
- val apps = uiRoot.getApplicationInfoList.find { _.id == appId }
+ val apps = uiRoot.getApplicationInfo(appId)
apps.getOrElse(throw new NotFoundException("unknown app: " + appId))
}
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala
index 6d8a60d480ae..653150385c73 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.status.api.v1
-import javax.ws.rs.{PathParam, GET, Produces}
+import javax.ws.rs.{GET, PathParam, Produces}
import javax.ws.rs.core.MediaType
import org.apache.spark.JobExecutionStatus
@@ -30,7 +30,7 @@ private[v1] class OneJobResource(ui: SparkUI) {
def oneJob(@PathParam("jobId") jobId: Int): JobData = {
val statusToJobs: Seq[(JobExecutionStatus, Seq[JobUIData])] =
AllJobsResource.getStatusToJobs(ui)
- val jobOpt = statusToJobs.map {_._2} .flatten.find { jobInfo => jobInfo.jobId == jobId}
+ val jobOpt = statusToJobs.flatMap(_._2).find { jobInfo => jobInfo.jobId == jobId}
jobOpt.map { job =>
AllJobsResource.convertJobData(job, ui.jobProgressListener, false)
}.getOrElse {
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala
index dfdc09c6caf3..237aeac18587 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.status.api.v1
-import javax.ws.rs.{PathParam, GET, Produces}
+import javax.ws.rs.{GET, PathParam, Produces}
import javax.ws.rs.core.MediaType
import org.apache.spark.ui.SparkUI
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala
index f9812f06cf52..3e6d2942d0fb 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala
@@ -33,7 +33,7 @@ private[v1] class OneStageResource(ui: SparkUI) {
@GET
@Path("")
def stageData(@PathParam("stageId") stageId: Int): Seq[StageData] = {
- withStage(stageId){ stageAttempts =>
+ withStage(stageId) { stageAttempts =>
stageAttempts.map { stage =>
AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui,
includeDetails = true)
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala
index 95fbd96ade5a..1cd37185d660 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala
@@ -16,21 +16,19 @@
*/
package org.apache.spark.status.api.v1
-import javax.ws.rs.WebApplicationException
+import javax.ws.rs.container.{ContainerRequestContext, ContainerRequestFilter}
import javax.ws.rs.core.Response
+import javax.ws.rs.ext.Provider
-import com.sun.jersey.spi.container.{ContainerRequest, ContainerRequestFilter}
-
-private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext {
- def filter(req: ContainerRequest): ContainerRequest = {
- val user = Option(req.getUserPrincipal).map { _.getName }.orNull
- if (uiRoot.securityManager.checkUIViewPermissions(user)) {
- req
- } else {
- throw new WebApplicationException(
+@Provider
+private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext {
+ override def filter(req: ContainerRequestContext): Unit = {
+ val user = httpRequest.getRemoteUser()
+ if (!uiRoot.securityManager.checkUIViewPermissions(user)) {
+ req.abortWith(
Response
.status(Response.Status.FORBIDDEN)
- .entity(raw"""user "$user"is not authorized""")
+ .entity(raw"""user "$user" is not authorized""")
.build()
)
}
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala
index 0c71cd238222..d8d5e8958b23 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala
@@ -17,7 +17,7 @@
package org.apache.spark.status.api.v1
import java.text.{ParseException, SimpleDateFormat}
-import java.util.TimeZone
+import java.util.{Locale, TimeZone}
import javax.ws.rs.WebApplicationException
import javax.ws.rs.core.Response
import javax.ws.rs.core.Response.Status
@@ -25,12 +25,12 @@ import javax.ws.rs.core.Response.Status
private[v1] class SimpleDateParam(val originalValue: String) {
val timestamp: Long = {
- val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz")
+ val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US)
try {
format.parse(originalValue).getTime()
} catch {
case _: ParseException =>
- val gmtDay = new SimpleDateFormat("yyyy-MM-dd")
+ val gmtDay = new SimpleDateFormat("yyyy-MM-dd", Locale.US)
gmtDay.setTimeZone(TimeZone.getTimeZone("GMT"))
try {
gmtDay.parse(originalValue).getTime()
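
SimpleDateParam first tries the full timestamp-with-zone pattern and, on a ParseException, falls back to a bare date interpreted in GMT. A minimal standalone illustration of the two accepted shapes (the example values are arbitrary):

import java.text.SimpleDateFormat
import java.util.{Locale, TimeZone}

object DateParamFormats {
  def main(args: Array[String]): Unit = {
    // Full form: date, time, and an explicit time zone.
    val full = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US)
    println(full.parse("2015-02-10T00:00:00.000GMT").getTime)

    // Fallback form: a bare date, always interpreted as midnight GMT.
    val dayOnly = new SimpleDateFormat("yyyy-MM-dd", Locale.US)
    dayOnly.setTimeZone(TimeZone.getTimeZone("GMT"))
    println(dayOnly.parse("2015-02-10").getTime)
  }
}
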
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala
new file mode 100644
index 000000000000..673da1ce36b5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.status.api.v1
+
+import javax.ws.rs._
+import javax.ws.rs.core.MediaType
+
+@Produces(Array(MediaType.APPLICATION_JSON))
+private[v1] class VersionResource(ui: UIRoot) {
+
+ @GET
+ def getVersionInfo(): VersionInfo = new VersionInfo(
+ org.apache.spark.SPARK_VERSION
+ )
+
+}
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
index 2bec64f2ef02..56d8e51732ff 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
@@ -25,14 +25,24 @@ import org.apache.spark.JobExecutionStatus
class ApplicationInfo private[spark](
val id: String,
val name: String,
+ val coresGranted: Option[Int],
+ val maxCores: Option[Int],
+ val coresPerExecutor: Option[Int],
+ val memoryPerExecutorMB: Option[Int],
val attempts: Seq[ApplicationAttemptInfo])
class ApplicationAttemptInfo private[spark](
val attemptId: Option[String],
val startTime: Date,
val endTime: Date,
+ val lastUpdated: Date,
+ val duration: Long,
val sparkUser: String,
- val completed: Boolean = false)
+ val completed: Boolean = false) {
+ def getStartTimeEpoch: Long = startTime.getTime
+ def getEndTimeEpoch: Long = endTime.getTime
+ def getLastUpdatedEpoch: Long = lastUpdated.getTime
+}
class ExecutorStageSummary private[spark](
val taskTime : Long,
@@ -48,19 +58,31 @@ class ExecutorStageSummary private[spark](
class ExecutorSummary private[spark](
val id: String,
val hostPort: String,
+ val isActive: Boolean,
val rddBlocks: Int,
val memoryUsed: Long,
val diskUsed: Long,
+ val totalCores: Int,
+ val maxTasks: Int,
val activeTasks: Int,
val failedTasks: Int,
val completedTasks: Int,
val totalTasks: Int,
val totalDuration: Long,
+ val totalGCTime: Long,
val totalInputBytes: Long,
val totalShuffleRead: Long,
val totalShuffleWrite: Long,
+ val isBlacklisted: Boolean,
val maxMemory: Long,
- val executorLogs: Map[String, String])
+ val executorLogs: Map[String, String],
+ val memoryMetrics: Option[MemoryMetrics])
+
+class MemoryMetrics private[spark](
+ val usedOnHeapStorageMemory: Long,
+ val usedOffHeapStorageMemory: Long,
+ val totalOnHeapStorageMemory: Long,
+ val totalOffHeapStorageMemory: Long)
class JobData private[spark](
val jobId: Int,
@@ -81,8 +103,6 @@ class JobData private[spark](
val numSkippedStages: Int,
val numFailedStages: Int)
-// Q: should Tachyon size go in here as well? currently the UI only shows it on the overall storage
-// page ... does anybody pay attention to it?
class RDDStorageInfo private[spark](
val id: Int,
val name: String,
@@ -98,7 +118,11 @@ class RDDDataDistribution private[spark](
val address: String,
val memoryUsed: Long,
val memoryRemaining: Long,
- val diskUsed: Long)
+ val diskUsed: Long,
+ val onHeapMemoryUsed: Option[Long],
+ val offHeapMemoryUsed: Option[Long],
+ val onHeapMemoryRemaining: Option[Long],
+ val offHeapMemoryRemaining: Option[Long])
class RDDPartitionInfo private[spark](
val blockName: String,
@@ -111,11 +135,15 @@ class StageData private[spark](
val status: StageStatus,
val stageId: Int,
val attemptId: Int,
- val numActiveTasks: Int ,
+ val numActiveTasks: Int,
val numCompleteTasks: Int,
val numFailedTasks: Int,
val executorRunTime: Long,
+ val executorCpuTime: Long,
+ val submissionTime: Option[Date],
+ val firstTaskLaunchedTime: Option[Date],
+ val completionTime: Option[Date],
val inputBytes: Long,
val inputRecords: Long,
@@ -141,8 +169,10 @@ class TaskData private[spark](
val index: Int,
val attempt: Int,
val launchTime: Date,
+ val duration: Option[Long] = None,
val executorId: String,
val host: String,
+ val status: String,
val taskLocality: String,
val speculative: Boolean,
val accumulatorUpdates: Seq[AccumulableInfo],
@@ -151,16 +181,18 @@ class TaskData private[spark](
class TaskMetrics private[spark](
val executorDeserializeTime: Long,
+ val executorDeserializeCpuTime: Long,
val executorRunTime: Long,
+ val executorCpuTime: Long,
val resultSize: Long,
val jvmGcTime: Long,
val resultSerializationTime: Long,
val memoryBytesSpilled: Long,
val diskBytesSpilled: Long,
- val inputMetrics: Option[InputMetrics],
- val outputMetrics: Option[OutputMetrics],
- val shuffleReadMetrics: Option[ShuffleReadMetrics],
- val shuffleWriteMetrics: Option[ShuffleWriteMetrics])
+ val inputMetrics: InputMetrics,
+ val outputMetrics: OutputMetrics,
+ val shuffleReadMetrics: ShuffleReadMetrics,
+ val shuffleWriteMetrics: ShuffleWriteMetrics)
class InputMetrics private[spark](
val bytesRead: Long,
@@ -171,11 +203,11 @@ class OutputMetrics private[spark](
val recordsWritten: Long)
class ShuffleReadMetrics private[spark](
- val remoteBlocksFetched: Int,
- val localBlocksFetched: Int,
+ val remoteBlocksFetched: Long,
+ val localBlocksFetched: Long,
val fetchWaitTime: Long,
val remoteBytesRead: Long,
- val totalBlocksFetched: Int,
+ val localBytesRead: Long,
val recordsRead: Long)
class ShuffleWriteMetrics private[spark](
@@ -187,17 +219,19 @@ class TaskMetricDistributions private[spark](
val quantiles: IndexedSeq[Double],
val executorDeserializeTime: IndexedSeq[Double],
+ val executorDeserializeCpuTime: IndexedSeq[Double],
val executorRunTime: IndexedSeq[Double],
+ val executorCpuTime: IndexedSeq[Double],
val resultSize: IndexedSeq[Double],
val jvmGcTime: IndexedSeq[Double],
val resultSerializationTime: IndexedSeq[Double],
val memoryBytesSpilled: IndexedSeq[Double],
val diskBytesSpilled: IndexedSeq[Double],
- val inputMetrics: Option[InputMetricDistributions],
- val outputMetrics: Option[OutputMetricDistributions],
- val shuffleReadMetrics: Option[ShuffleReadMetricDistributions],
- val shuffleWriteMetrics: Option[ShuffleWriteMetricDistributions])
+ val inputMetrics: InputMetricDistributions,
+ val outputMetrics: OutputMetricDistributions,
+ val shuffleReadMetrics: ShuffleReadMetricDistributions,
+ val shuffleWriteMetrics: ShuffleWriteMetricDistributions)
class InputMetricDistributions private[spark](
val bytesRead: IndexedSeq[Double],
@@ -226,3 +260,17 @@ class AccumulableInfo private[spark](
val name: String,
val update: Option[String],
val value: String)
+
+class VersionInfo private[spark](
+ val spark: String)
+
+class ApplicationEnvironmentInfo private[spark] (
+ val runtime: RuntimeInfo,
+ val sparkProperties: Seq[(String, String)],
+ val systemProperties: Seq[(String, String)],
+ val classpathEntries: Seq[(String, String)])
+
+class RuntimeInfo private[spark](
+ val javaVersion: String,
+ val javaHome: String,
+ val scalaVersion: String)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala
deleted file mode 100644
index f6e46ae9a481..000000000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala
+++ /dev/null
@@ -1,24 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import org.apache.spark.SparkException
-
-private[spark]
-case class BlockFetchException(messages: String, throwable: Throwable)
- extends SparkException(messages, throwable)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 524f6970992a..8c1e657ecc8e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
import java.util.UUID
+import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
/**
@@ -100,6 +101,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
override def name: String = "test_" + id
}
+@DeveloperApi
+class UnrecognizedBlockId(name: String)
+ extends SparkException(s"Failed to parse $name into a block ID")
+
@DeveloperApi
object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
@@ -109,10 +114,11 @@ object BlockId {
val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
+ val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r
+ val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r
val TEST = "test_(.*)".r
- /** Converts a BlockId "name" String back into a BlockId. */
- def apply(id: String): BlockId = id match {
+ def apply(name: String): BlockId = name match {
case RDD(rddId, splitIndex) =>
RDDBlockId(rddId.toInt, splitIndex.toInt)
case SHUFFLE(shuffleId, mapId, reduceId) =>
@@ -127,9 +133,13 @@ object BlockId {
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>
StreamBlockId(streamId.toInt, uniqueId.toLong)
+ case TEMP_LOCAL(uuid) =>
+ TempLocalBlockId(UUID.fromString(uuid))
+ case TEMP_SHUFFLE(uuid) =>
+ TempShuffleBlockId(UUID.fromString(uuid))
case TEST(value) =>
TestBlockId(value)
case _ =>
- throw new IllegalStateException("Unrecognized BlockId: " + id)
+ throw new UnrecognizedBlockId(name)
}
}
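
With the TEMP_LOCAL and TEMP_SHUFFLE patterns added, BlockId.apply can round-trip temporary block names, and unparsable names now surface as the dedicated UnrecognizedBlockId. A quick usage sketch, assuming the Spark classes are on the classpath (block names are illustrative):

import java.util.UUID
import org.apache.spark.storage._

object BlockIdExamples {
  def main(args: Array[String]): Unit = {
    val rdd = BlockId("rdd_1_2")                            // parses to RDDBlockId(1, 2)
    val tmp = BlockId(s"temp_shuffle_${UUID.randomUUID()}") // parses to TempShuffleBlockId
    println(rdd.isRDD)                                      // true
    try {
      BlockId("not_a_block")
    } catch {
      case e: UnrecognizedBlockId => println(e.getMessage)  // "Failed to parse ... block ID"
    }
  }
}
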
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
deleted file mode 100644
index 22fdf73e9d1f..000000000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.util.concurrent.ConcurrentHashMap
-
-private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
- // To save space, 'pending' and 'failed' are encoded as special sizes:
- @volatile var size: Long = BlockInfo.BLOCK_PENDING
- private def pending: Boolean = size == BlockInfo.BLOCK_PENDING
- private def failed: Boolean = size == BlockInfo.BLOCK_FAILED
- private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this)
-
- setInitThread()
-
- private def setInitThread() {
- /* Set current thread as init thread - waitForReady will not block this thread
- * (in case there is non trivial initialization which ends up calling waitForReady
- * as part of initialization itself) */
- BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread())
- }
-
- /**
- * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
- * Return true if the block is available, false otherwise.
- */
- def waitForReady(): Boolean = {
- if (pending && initThread != Thread.currentThread()) {
- synchronized {
- while (pending) {
- this.wait()
- }
- }
- }
- !failed
- }
-
- /** Mark this BlockInfo as ready (i.e. block is finished writing) */
- def markReady(sizeInBytes: Long) {
- require(sizeInBytes >= 0, s"sizeInBytes was negative: $sizeInBytes")
- assert(pending)
- size = sizeInBytes
- BlockInfo.blockInfoInitThreads.remove(this)
- synchronized {
- this.notifyAll()
- }
- }
-
- /** Mark this BlockInfo as ready but failed */
- def markFailure() {
- assert(pending)
- size = BlockInfo.BLOCK_FAILED
- BlockInfo.blockInfoInitThreads.remove(this)
- synchronized {
- this.notifyAll()
- }
- }
-}
-
-private object BlockInfo {
- /* initThread is logically a BlockInfo field, but we store it here because
- * it's only needed while this block is in the 'pending' state and we want
- * to minimize BlockInfo's memory footprint. */
- private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread]
-
- private val BLOCK_PENDING: Long = -1L
- private val BLOCK_FAILED: Long = -2L
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
new file mode 100644
index 000000000000..7064872ec1c7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import com.google.common.collect.{ConcurrentHashMultiset, ImmutableMultiset}
+
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+
+
+/**
+ * Tracks metadata for an individual block.
+ *
+ * Instances of this class are _not_ thread-safe and are protected by locks in the
+ * [[BlockInfoManager]].
+ *
+ * @param level the block's storage level. This is the requested persistence level, not the
+ * effective storage level of the block (i.e. if this is MEMORY_AND_DISK, then this
+ * does not imply that the block is actually resident in memory).
+ * @param classTag the block's [[ClassTag]], used to select the serializer
+ * @param tellMaster whether state changes for this block should be reported to the master. This
+ * is true for most blocks, but is false for broadcast blocks.
+ */
+private[storage] class BlockInfo(
+ val level: StorageLevel,
+ val classTag: ClassTag[_],
+ val tellMaster: Boolean) {
+
+ /**
+ * The size of the block (in bytes)
+ */
+ def size: Long = _size
+ def size_=(s: Long): Unit = {
+ _size = s
+ checkInvariants()
+ }
+ private[this] var _size: Long = 0
+
+ /**
+ * The number of times that this block has been locked for reading.
+ */
+ def readerCount: Int = _readerCount
+ def readerCount_=(c: Int): Unit = {
+ _readerCount = c
+ checkInvariants()
+ }
+ private[this] var _readerCount: Int = 0
+
+ /**
+ * The task attempt id of the task which currently holds the write lock for this block, or
+ * [[BlockInfo.NON_TASK_WRITER]] if the write lock is held by non-task code, or
+ * [[BlockInfo.NO_WRITER]] if this block is not locked for writing.
+ */
+ def writerTask: Long = _writerTask
+ def writerTask_=(t: Long): Unit = {
+ _writerTask = t
+ checkInvariants()
+ }
+ private[this] var _writerTask: Long = BlockInfo.NO_WRITER
+
+ private def checkInvariants(): Unit = {
+ // A block's reader count must be non-negative:
+ assert(_readerCount >= 0)
+ // A block is either locked for reading or for writing, but not for both at the same time:
+ assert(_readerCount == 0 || _writerTask == BlockInfo.NO_WRITER)
+ }
+
+ checkInvariants()
+}
+
+private[storage] object BlockInfo {
+
+ /**
+ * Special task attempt id constant used to mark a block's write lock as being unlocked.
+ */
+ val NO_WRITER: Long = -1
+
+ /**
+ * Special task attempt id constant used to mark a block's write lock as being held by
+ * a non-task thread (e.g. by a driver thread or by unit test code).
+ */
+ val NON_TASK_WRITER: Long = -1024
+}
+
+/**
+ * Component of the [[BlockManager]] which tracks metadata for blocks and manages block locking.
+ *
+ * The locking interface exposed by this class is a readers-writer lock. Every lock acquisition is
+ * automatically associated with a running task and locks are automatically released upon task
+ * completion or failure.
+ *
+ * This class is thread-safe.
+ */
+private[storage] class BlockInfoManager extends Logging {
+
+ private type TaskAttemptId = Long
+
+ /**
+ * Used to look up metadata for individual blocks. Entries are added to this map via an atomic
+ * set-if-not-exists operation ([[lockNewBlockForWriting()]]) and are removed
+ * by [[removeBlock()]].
+ */
+ @GuardedBy("this")
+ private[this] val infos = new mutable.HashMap[BlockId, BlockInfo]
+
+ /**
+ * Tracks the set of blocks that each task has locked for writing.
+ */
+ @GuardedBy("this")
+ private[this] val writeLocksByTask =
+ new mutable.HashMap[TaskAttemptId, mutable.Set[BlockId]]
+ with mutable.MultiMap[TaskAttemptId, BlockId]
+
+ /**
+ * Tracks the set of blocks that each task has locked for reading, along with the number of times
+ * that a block has been locked (since our read locks are re-entrant).
+ */
+ @GuardedBy("this")
+ private[this] val readLocksByTask =
+ new mutable.HashMap[TaskAttemptId, ConcurrentHashMultiset[BlockId]]
+
+ // ----------------------------------------------------------------------------------------------
+
+ // Initialization for special task attempt ids:
+ registerTask(BlockInfo.NON_TASK_WRITER)
+
+ // ----------------------------------------------------------------------------------------------
+
+ /**
+ * Called at the start of a task in order to register that task with this [[BlockInfoManager]].
+ * This must be called prior to calling any other BlockInfoManager methods from that task.
+ */
+ def registerTask(taskAttemptId: TaskAttemptId): Unit = synchronized {
+ require(!readLocksByTask.contains(taskAttemptId),
+ s"Task attempt $taskAttemptId is already registered")
+ readLocksByTask(taskAttemptId) = ConcurrentHashMultiset.create()
+ }
+
+ /**
+ * Returns the current task's task attempt id (which uniquely identifies the task), or
+ * [[BlockInfo.NON_TASK_WRITER]] if called by a non-task thread.
+ */
+ private def currentTaskAttemptId: TaskAttemptId = {
+ Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(BlockInfo.NON_TASK_WRITER)
+ }
+
+ /**
+ * Lock a block for reading and return its metadata.
+ *
+ * If another task has already locked this block for reading, then the read lock will be
+ * immediately granted to the calling task and its lock count will be incremented.
+ *
+ * If another task has locked this block for writing, then this call will block until the write
+ * lock is released or will return immediately if `blocking = false`.
+ *
+ * A single task can lock a block multiple times for reading, in which case each lock will need
+ * to be released separately.
+ *
+ * @param blockId the block to lock.
+ * @param blocking if true (default), this call will block until the lock is acquired. If false,
+ * this call will return immediately if the lock acquisition fails.
+ * @return None if the block did not exist or was removed (in which case no lock is held), or
+ * Some(BlockInfo) (in which case the block is locked for reading).
+ */
+ def lockForReading(
+ blockId: BlockId,
+ blocking: Boolean = true): Option[BlockInfo] = synchronized {
+ logTrace(s"Task $currentTaskAttemptId trying to acquire read lock for $blockId")
+ do {
+ infos.get(blockId) match {
+ case None => return None
+ case Some(info) =>
+ if (info.writerTask == BlockInfo.NO_WRITER) {
+ info.readerCount += 1
+ readLocksByTask(currentTaskAttemptId).add(blockId)
+ logTrace(s"Task $currentTaskAttemptId acquired read lock for $blockId")
+ return Some(info)
+ }
+ }
+ if (blocking) {
+ wait()
+ }
+ } while (blocking)
+ None
+ }
+
+ /**
+ * Lock a block for writing and return its metadata.
+ *
+ * If another task has already locked this block for either reading or writing, then this call
+ * will block until the other locks are released or will return immediately if `blocking = false`.
+ *
+ * @param blockId the block to lock.
+ * @param blocking if true (default), this call will block until the lock is acquired. If false,
+ * this call will return immediately if the lock acquisition fails.
+ * @return None if the block did not exist or was removed (in which case no lock is held), or
+ * Some(BlockInfo) (in which case the block is locked for writing).
+ */
+ def lockForWriting(
+ blockId: BlockId,
+ blocking: Boolean = true): Option[BlockInfo] = synchronized {
+ logTrace(s"Task $currentTaskAttemptId trying to acquire write lock for $blockId")
+ do {
+ infos.get(blockId) match {
+ case None => return None
+ case Some(info) =>
+ if (info.writerTask == BlockInfo.NO_WRITER && info.readerCount == 0) {
+ info.writerTask = currentTaskAttemptId
+ writeLocksByTask.addBinding(currentTaskAttemptId, blockId)
+ logTrace(s"Task $currentTaskAttemptId acquired write lock for $blockId")
+ return Some(info)
+ }
+ }
+ if (blocking) {
+ wait()
+ }
+ } while (blocking)
+ None
+ }
+
+ /**
+ * Throws an exception if the current task does not hold a write lock on the given block.
+ * Otherwise, returns the block's BlockInfo.
+ */
+ def assertBlockIsLockedForWriting(blockId: BlockId): BlockInfo = synchronized {
+ infos.get(blockId) match {
+ case Some(info) =>
+ if (info.writerTask != currentTaskAttemptId) {
+ throw new SparkException(
+ s"Task $currentTaskAttemptId has not locked block $blockId for writing")
+ } else {
+ info
+ }
+ case None =>
+ throw new SparkException(s"Block $blockId does not exist")
+ }
+ }
+
+ /**
+ * Get a block's metadata without acquiring any locks. This method is only exposed for use by
+ * [[BlockManager.getStatus()]] and should not be called by other code outside of this class.
+ */
+ private[storage] def get(blockId: BlockId): Option[BlockInfo] = synchronized {
+ infos.get(blockId)
+ }
+
+ /**
+ * Downgrades an exclusive write lock to a shared read lock.
+ */
+ def downgradeLock(blockId: BlockId): Unit = synchronized {
+ logTrace(s"Task $currentTaskAttemptId downgrading write lock for $blockId")
+ val info = get(blockId).get
+ require(info.writerTask == currentTaskAttemptId,
+ s"Task $currentTaskAttemptId tried to downgrade a write lock that it does not hold on" +
+ s" block $blockId")
+ unlock(blockId)
+ val lockOutcome = lockForReading(blockId, blocking = false)
+ assert(lockOutcome.isDefined)
+ }
+
+ /**
+ * Release a lock on the given block.
+ * In case a TaskContext is not propagated properly to all child threads for the task, we fail to
+ * get the TID from TaskContext, so we have to explicitly pass the TID value to release the lock.
+ *
+ * See SPARK-18406 for more discussion of this issue.
+ */
+ def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized {
+ val taskId = taskAttemptId.getOrElse(currentTaskAttemptId)
+ logTrace(s"Task $taskId releasing lock for $blockId")
+ val info = get(blockId).getOrElse {
+ throw new IllegalStateException(s"Block $blockId not found")
+ }
+ if (info.writerTask != BlockInfo.NO_WRITER) {
+ info.writerTask = BlockInfo.NO_WRITER
+ writeLocksByTask.removeBinding(taskId, blockId)
+ } else {
+ assert(info.readerCount > 0, s"Block $blockId is not locked for reading")
+ info.readerCount -= 1
+ val countsForTask = readLocksByTask(taskId)
+ val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1
+ assert(newPinCountForTask >= 0,
+ s"Task $taskId release lock on block $blockId more times than it acquired it")
+ }
+ notifyAll()
+ }
+
+ /**
+ * Attempt to acquire the appropriate lock for writing a new block.
+ *
+ * This enforces the first-writer-wins semantics. If we are the first to write the block,
+ * then just go ahead and acquire the write lock. Otherwise, if another thread is already
+ * writing the block, then we wait for the write to finish before acquiring the read lock.
+ *
+ * @return true if the block did not already exist, false otherwise. If this returns false, then
+ * a read lock on the existing block will be held. If this returns true, a write lock on
+ * the new block will be held.
+ */
+ def lockNewBlockForWriting(
+ blockId: BlockId,
+ newBlockInfo: BlockInfo): Boolean = synchronized {
+ logTrace(s"Task $currentTaskAttemptId trying to put $blockId")
+ lockForReading(blockId) match {
+ case Some(info) =>
+ // Block already exists. This could happen if another thread races with us to compute
+ // the same block. In this case, just keep the read lock and return.
+ false
+ case None =>
+ // Block does not yet exist or is removed, so we are free to acquire the write lock
+ infos(blockId) = newBlockInfo
+ lockForWriting(blockId)
+ true
+ }
+ }
+
+ /**
+ * Release all locks held by the given task, clearing that task's pin bookkeeping
+ * structures and updating the global pin counts. This method should be called at the
+ * end of a task (either by a task completion handler or in `TaskRunner.run()`).
+ *
+ * @return the ids of blocks whose pins were released
+ */
+ def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = {
+ val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]()
+
+ val readLocks = synchronized {
+ readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]())
+ }
+ val writeLocks = synchronized {
+ writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty)
+ }
+
+ for (blockId <- writeLocks) {
+ infos.get(blockId).foreach { info =>
+ assert(info.writerTask == taskAttemptId)
+ info.writerTask = BlockInfo.NO_WRITER
+ }
+ blocksWithReleasedLocks += blockId
+ }
+ readLocks.entrySet().iterator().asScala.foreach { entry =>
+ val blockId = entry.getElement
+ val lockCount = entry.getCount
+ blocksWithReleasedLocks += blockId
+ synchronized {
+ get(blockId).foreach { info =>
+ info.readerCount -= lockCount
+ assert(info.readerCount >= 0)
+ }
+ }
+ }
+
+ synchronized {
+ notifyAll()
+ }
+ blocksWithReleasedLocks
+ }
+
+ /** Returns the number of locks held by the given task. Used only for testing. */
+ private[storage] def getTaskLockCount(taskAttemptId: TaskAttemptId): Int = {
+ readLocksByTask.get(taskAttemptId).map(_.size()).getOrElse(0) +
+ writeLocksByTask.get(taskAttemptId).map(_.size).getOrElse(0)
+ }
+
+ /**
+ * Returns the number of blocks tracked.
+ */
+ def size: Int = synchronized {
+ infos.size
+ }
+
+ /**
+ * Return the number of map entries in this pin counter's internal data structures.
+ * This is used in unit tests in order to detect memory leaks.
+ */
+ private[storage] def getNumberOfMapEntries: Long = synchronized {
+ size +
+ readLocksByTask.size +
+ readLocksByTask.map(_._2.size()).sum +
+ writeLocksByTask.size +
+ writeLocksByTask.map(_._2.size).sum
+ }
+
+ /**
+ * Returns an iterator over a snapshot of all blocks' metadata. Note that the individual entries
+ * in this iterator are mutable and thus may reflect blocks that are deleted while the iterator
+ * is being traversed.
+ */
+ def entries: Iterator[(BlockId, BlockInfo)] = synchronized {
+ infos.toArray.toIterator
+ }
+
+ /**
+ * Removes the given block and releases the write lock on it.
+ *
+ * This can only be called while holding a write lock on the given block.
+ */
+ def removeBlock(blockId: BlockId): Unit = synchronized {
+ logTrace(s"Task $currentTaskAttemptId trying to remove block $blockId")
+ infos.get(blockId) match {
+ case Some(blockInfo) =>
+ if (blockInfo.writerTask != currentTaskAttemptId) {
+ throw new IllegalStateException(
+ s"Task $currentTaskAttemptId called remove() on block $blockId without a write lock")
+ } else {
+ infos.remove(blockId)
+ blockInfo.readerCount = 0
+ blockInfo.writerTask = BlockInfo.NO_WRITER
+ writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
+ }
+ case None =>
+ throw new IllegalArgumentException(
+ s"Task $currentTaskAttemptId called remove() on non-existent block $blockId")
+ }
+ notifyAll()
+ }
+
+ /**
+ * Delete all state. Called during shutdown.
+ */
+ def clear(): Unit = synchronized {
+ infos.valuesIterator.foreach { blockInfo =>
+ blockInfo.readerCount = 0
+ blockInfo.writerTask = BlockInfo.NO_WRITER
+ }
+ infos.clear()
+ readLocksByTask.clear()
+ writeLocksByTask.clear()
+ notifyAll()
+ }
+
+}
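
The class comments above describe a per-block readers-writer lock keyed by task attempt id, with locks released on task completion. Below is a minimal sketch of the intended lifecycle for a non-task (driver or test) thread, based only on the methods defined above; it is written as if it lived in the org.apache.spark.storage package, since BlockInfoManager and BlockInfo are package-private.

package org.apache.spark.storage

import scala.reflect.ClassTag

object BlockLockingSketch {
  def main(args: Array[String]): Unit = {
    val manager = new BlockInfoManager
    val blockId = TestBlockId("demo")
    val info = new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false)

    // First writer wins: returns true and leaves us holding the write lock.
    assert(manager.lockNewBlockForWriting(blockId, info))
    info.size = 128L

    // Swap the exclusive write lock for a shared read lock, then release it.
    manager.downgradeLock(blockId)
    manager.unlock(blockId)

    // Later readers can take the (re-entrant) read lock; non-blocking here.
    assert(manager.lockForReading(blockId, blocking = false).isDefined)
    manager.unlock(blockId)
  }
}
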
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c374b9376622..5f067191070e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -18,35 +18,33 @@
package org.apache.spark.storage
import java.io._
-import java.nio.{ByteBuffer, MappedByteBuffer}
+import java.nio.ByteBuffer
+import java.nio.channels.Channels
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.concurrent.{ExecutionContext, Await, Future}
+import scala.collection.mutable
+import scala.collection.mutable.HashMap
+import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
-import scala.util.control.NonFatal
+import scala.reflect.ClassTag
import scala.util.Random
-
-import sun.nio.ch.DirectBuffer
+import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
-import org.apache.spark.io.CompressionCodec
-import org.apache.spark.memory.MemoryManager
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.network._
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
+import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
-import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.storage.memory._
+import org.apache.spark.unsafe.Platform
import org.apache.spark.util._
-
-private[spark] sealed trait BlockValues
-private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
-private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues
-private[spark] case class ArrayValues(buffer: Array[Any]) extends BlockValues
+import org.apache.spark.util.io.ChunkedByteBuffer
/* Class for returning a fetched block and associated metrics. */
private[spark] class BlockResult(
@@ -54,51 +52,103 @@ private[spark] class BlockResult(
val readMethod: DataReadMethod.Value,
val bytes: Long)
+/**
+ * Abstracts away how blocks are stored and provides different ways to read the underlying block
+ * data. Callers should call [[dispose()]] when they're done with the block.
+ */
+private[spark] trait BlockData {
+
+ def toInputStream(): InputStream
+
+ /**
+ * Returns a Netty-friendly wrapper for the block's data.
+ *
+ * Please see `ManagedBuffer.convertToNetty()` for more details.
+ */
+ def toNetty(): Object
+
+ def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer
+
+ def toByteBuffer(): ByteBuffer
+
+ def size: Long
+
+ def dispose(): Unit
+
+}
+
+private[spark] class ByteBufferBlockData(
+ val buffer: ChunkedByteBuffer,
+ val shouldDispose: Boolean) extends BlockData {
+
+ override def toInputStream(): InputStream = buffer.toInputStream(dispose = false)
+
+ override def toNetty(): Object = buffer.toNetty
+
+ override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
+ buffer.copy(allocator)
+ }
+
+ override def toByteBuffer(): ByteBuffer = buffer.toByteBuffer
+
+ override def size: Long = buffer.size
+
+ override def dispose(): Unit = {
+ if (shouldDispose) {
+ buffer.dispose()
+ }
+ }
+
+}
+
/**
* Manager running on every node (driver and executors) which provides interfaces for putting and
* retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
*
- * Note that #initialize() must be called before the BlockManager is usable.
+ * Note that [[initialize()]] must be called before the BlockManager is usable.
*/
private[spark] class BlockManager(
executorId: String,
rpcEnv: RpcEnv,
val master: BlockManagerMaster,
- defaultSerializer: Serializer,
+ val serializerManager: SerializerManager,
val conf: SparkConf,
memoryManager: MemoryManager,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService,
+ val blockTransferService: BlockTransferService,
securityManager: SecurityManager,
numUsableCores: Int)
- extends BlockDataManager with Logging {
+ extends BlockDataManager with BlockEvictionHandler with Logging {
- val diskBlockManager = new DiskBlockManager(this, conf)
+ private[spark] val externalShuffleServiceEnabled =
+ conf.getBoolean("spark.shuffle.service.enabled", false)
- private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
+ val diskBlockManager = {
+ // Only perform cleanup if an external service is not serving our shuffle files.
+ val deleteFilesOnStop =
+ !externalShuffleServiceEnabled || executorId == SparkContext.DRIVER_IDENTIFIER
+ new DiskBlockManager(conf, deleteFilesOnStop)
+ }
+
+ // Visible for testing
+ private[storage] val blockInfoManager = new BlockInfoManager
private val futureExecutionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128))
// Actual storage of where blocks are kept
- private var externalBlockStoreInitialized = false
- private[spark] val memoryStore = new MemoryStore(this, memoryManager)
- private[spark] val diskStore = new DiskStore(this, diskBlockManager)
- private[spark] lazy val externalBlockStore: ExternalBlockStore = {
- externalBlockStoreInitialized = true
- new ExternalBlockStore(this, executorId)
- }
+ private[spark] val memoryStore =
+ new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
+ private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
memoryManager.setMemoryStore(memoryStore)
- // Note: depending on the memory manager, `maxStorageMemory` may actually vary over time.
+ // Note: depending on the memory manager, `maxMemory` may actually vary over time.
// However, since we use this only for reporting and logging, what we actually want here is
- // the absolute maximum value that `maxStorageMemory` can ever possibly reach. We may need
+ // the absolute maximum value that `maxMemory` can ever possibly reach. We may need
// to revisit whether reporting this value as the "max" is intuitive to the user.
- private val maxMemory = memoryManager.maxStorageMemory
-
- private[spark]
- val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
+ private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory
+ private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory
// Port used by the external shuffle service. In Yarn mode, this may be already be
// set through the Hadoop configuration as the server is launched in the Yarn NM.
@@ -123,21 +173,15 @@ private[spark] class BlockManager(
// Client to read other executors' shuffle files. This is either an external service, or just the
// standard BlockTransferService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
- val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
- new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
- securityManager.isSaslEncryptionEnabled())
+ val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores)
+ new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
} else {
blockTransferService
}
- // Whether to compress broadcast variables that are stored
- private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
- // Whether to compress shuffle output that are stored
- private val compressShuffle = conf.getBoolean("spark.shuffle.compress", true)
- // Whether to compress RDD partitions that are stored serialized
- private val compressRdds = conf.getBoolean("spark.rdd.compress", false)
- // Whether to compress shuffle output temporarily spilled to disk
- private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
+ // Max number of failures before this block manager refreshes the block locations from the driver
+ private val maxFailuresBeforeLocationRefresh =
+ conf.getInt("spark.block.failures.beforeLocationRefresh", 5)
private val slaveEndpoint = rpcEnv.setupEndpoint(
"BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next,
@@ -148,22 +192,12 @@ private[spark] class BlockManager(
private var asyncReregisterTask: Future[Unit] = null
private val asyncReregisterLock = new Object
- private val metadataCleaner = new MetadataCleaner(
- MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf)
- private val broadcastCleaner = new MetadataCleaner(
- MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf)
-
// Field related to peer block managers that are necessary for block replication
@volatile private var cachedPeers: Seq[BlockManagerId] = _
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
- /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
- * the initialization of the compression codec until it is first used. The reason is that a Spark
- * program could be using a user-defined codec in a third party jar, which is loaded in
- * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been
- * loaded yet. */
- private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
+ private var blockReplicationPolicy: BlockReplicationPolicy = _
/**
* Initializes the BlockManager with the given appId. This is not performed in the constructor as
@@ -178,8 +212,25 @@ private[spark] class BlockManager(
blockTransferService.init(this)
shuffleClient.init(appId)
- blockManagerId = BlockManagerId(
- executorId, blockTransferService.hostName, blockTransferService.port)
+ blockReplicationPolicy = {
+ val priorityClass = conf.get(
+ "spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName)
+ val clazz = Utils.classForName(priorityClass)
+ val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy]
+ logInfo(s"Using $priorityClass for block replication policy")
+ ret
+ }
+
+ val id =
+ BlockManagerId(executorId, blockTransferService.hostName, blockTransferService.port, None)
+
+ val idFromMaster = master.registerBlockManager(
+ id,
+ maxOnHeapMemory,
+ maxOffHeapMemory,
+ slaveEndpoint)
+
+ blockManagerId = if (idFromMaster != null) idFromMaster else id
shuffleServerId = if (externalShuffleServiceEnabled) {
logInfo(s"external shuffle service port = $externalShuffleServicePort")
@@ -188,12 +239,12 @@ private[spark] class BlockManager(
blockManagerId
}
- master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
-
// Register Executors' configuration with the local shuffle service, if one should exist.
if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
registerWithExternalShuffleServer()
}
+
+ logInfo(s"Initialized BlockManager: $blockManagerId")
}
private def registerWithExternalShuffleServer() {
@@ -217,6 +268,9 @@ private[spark] class BlockManager(
logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}"
+ s" more times after waiting $SLEEP_TIME_SECS seconds...", e)
Thread.sleep(SLEEP_TIME_SECS * 1000)
+ case NonFatal(e) =>
+ throw new SparkException("Unable to register with external shuffle server due to : " +
+ e.getMessage, e)
}
}
}
@@ -232,10 +286,10 @@ private[spark] class BlockManager(
* will be made then.
*/
private def reportAllBlocks(): Unit = {
- logInfo(s"Reporting ${blockInfo.size} blocks to the master.")
- for ((blockId, info) <- blockInfo) {
+ logInfo(s"Reporting ${blockInfoManager.size} blocks to the master.")
+ for ((blockId, info) <- blockInfoManager.entries) {
val status = getCurrentBlockStatus(blockId, info)
- if (!tryToReportBlockStatus(blockId, info, status)) {
+ if (info.tellMaster && !tryToReportBlockStatus(blockId, status)) {
logError(s"Failed to report $blockId to master; giving up.")
return
}
@@ -250,8 +304,8 @@ private[spark] class BlockManager(
*/
def reregister(): Unit = {
// TODO: We might need to rate limit re-registering.
- logInfo("BlockManager re-registering with master")
- master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
+ logInfo(s"BlockManager $blockManagerId re-registering with master")
+ master.registerBlockManager(blockManagerId, maxOnHeapMemory, maxOffHeapMemory, slaveEndpoint)
reportAllBlocks()
}
@@ -279,7 +333,12 @@ private[spark] class BlockManager(
def waitForAsyncReregister(): Unit = {
val task = asyncReregisterTask
if (task != null) {
- Await.ready(task, Duration.Inf)
+ try {
+ ThreadUtils.awaitReady(task, Duration.Inf)
+ } catch {
+ case NonFatal(t) =>
+ throw new Exception("Error occurred while waiting for async. reregistration", t)
+ }
}
}
@@ -291,34 +350,42 @@ private[spark] class BlockManager(
if (blockId.isShuffle) {
shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
- val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
- .asInstanceOf[Option[ByteBuffer]]
- if (blockBytesOpt.isDefined) {
- val buffer = blockBytesOpt.get
- new NioManagedBuffer(buffer)
- } else {
- throw new BlockNotFoundException(blockId.toString)
+ getLocalBytes(blockId) match {
+ case Some(blockData) =>
+ new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true)
+ case None =>
+ // If this block manager receives a request for a block that it doesn't have then it's
+ // likely that the master has outdated block statuses for this block. Therefore, we send
+ // an RPC so that this block is marked as being unavailable from this block manager.
+ reportBlockStatus(blockId, BlockStatus.empty)
+ throw new BlockNotFoundException(blockId.toString)
}
}
}
/**
* Put the block locally, using the given storage level.
+ *
+ * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing
+ * so may corrupt or change the data stored by the `BlockManager`.
*/
- override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = {
- putBytes(blockId, data.nioByteBuffer(), level)
+ override def putBlockData(
+ blockId: BlockId,
+ data: ManagedBuffer,
+ level: StorageLevel,
+ classTag: ClassTag[_]): Boolean = {
+ putBytes(blockId, new ChunkedByteBuffer(data.nioByteBuffer()), level)(classTag)
}
/**
* Get the BlockStatus for the block identified by the given ID, if it exists.
- * NOTE: This is mainly for testing, and it doesn't fetch information from external block store.
+ * NOTE: This is mainly for testing.
*/
def getStatus(blockId: BlockId): Option[BlockStatus] = {
- blockInfo.get(blockId).map { info =>
+ blockInfoManager.get(blockId).map { info =>
val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L
- // Assume that block is not in external block store
- BlockStatus(info.level, memSize, diskSize, 0L)
+ BlockStatus(info.level, memSize = memSize, diskSize = diskSize)
}
}
@@ -328,7 +395,12 @@ private[spark] class BlockManager(
* may not know of).
*/
def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = {
- (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq
+ // The `toArray` is necessary here in order to force the list to be materialized so that we
+ // don't try to serialize a lazy iterator when responding to client requests.
+ (blockInfoManager.entries.map(_._1) ++ diskBlockManager.getAllBlocks())
+ .filter(filter)
+ .toArray
+ .toSeq
}
/**
@@ -342,10 +414,9 @@ private[spark] class BlockManager(
*/
private def reportBlockStatus(
blockId: BlockId,
- info: BlockInfo,
status: BlockStatus,
droppedMemorySize: Long = 0L): Unit = {
- val needReregister = !tryToReportBlockStatus(blockId, info, status, droppedMemorySize)
+ val needReregister = !tryToReportBlockStatus(blockId, status, droppedMemorySize)
if (needReregister) {
logInfo(s"Got told to re-register updating block $blockId")
// Re-registering will report our new block for free.
@@ -361,19 +432,12 @@ private[spark] class BlockManager(
*/
private def tryToReportBlockStatus(
blockId: BlockId,
- info: BlockInfo,
status: BlockStatus,
droppedMemorySize: Long = 0L): Boolean = {
- if (info.tellMaster) {
- val storageLevel = status.storageLevel
- val inMemSize = Math.max(status.memSize, droppedMemorySize)
- val inExternalBlockStoreSize = status.externalBlockStoreSize
- val onDiskSize = status.diskSize
- master.updateBlockInfo(
- blockManagerId, blockId, storageLevel, inMemSize, onDiskSize, inExternalBlockStoreSize)
- } else {
- true
- }
+ val storageLevel = status.storageLevel
+ val inMemSize = Math.max(status.memSize, droppedMemorySize)
+ val onDiskSize = status.diskSize
+ master.updateBlockInfo(blockManagerId, blockId, storageLevel, inMemSize, onDiskSize)
}
/**
@@ -385,20 +449,21 @@ private[spark] class BlockManager(
info.synchronized {
info.level match {
case null =>
- BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)
+ BlockStatus.empty
case level =>
val inMem = level.useMemory && memoryStore.contains(blockId)
- val inExternalBlockStore = level.useOffHeap && externalBlockStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId)
val deserialized = if (inMem) level.deserialized else false
- val replication = if (inMem || inExternalBlockStore || onDisk) level.replication else 1
- val storageLevel =
- StorageLevel(onDisk, inMem, inExternalBlockStore, deserialized, replication)
+ val replication = if (inMem || onDisk) level.replication else 1
+ val storageLevel = StorageLevel(
+ useDisk = onDisk,
+ useMemory = inMem,
+ useOffHeap = level.useOffHeap,
+ deserialized = deserialized,
+ replication = replication)
val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
- val externalBlockStoreSize =
- if (inExternalBlockStore) externalBlockStore.getSize(blockId) else 0L
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
- BlockStatus(storageLevel, memSize, diskSize, externalBlockStoreSize)
+ BlockStatus(storageLevel, memSize, diskSize)
}
}
}
@@ -414,17 +479,72 @@ private[spark] class BlockManager(
}
/**
- * Get block from local block manager.
+ * Cleanup code run in response to a failed local read.
+ * Must be called while holding a read lock on the block.
+ */
+ private def handleLocalReadFailure(blockId: BlockId): Nothing = {
+ releaseLock(blockId)
+ // Remove the missing block so that its unavailability is reported to the driver
+ removeBlock(blockId)
+ throw new SparkException(s"Block $blockId was not found even though it's read-locked")
+ }
+
+ /**
+ * Get block from local block manager as an iterator of Java objects.
*/
- def getLocal(blockId: BlockId): Option[BlockResult] = {
+ def getLocalValues(blockId: BlockId): Option[BlockResult] = {
logDebug(s"Getting local block $blockId")
- doGetLocal(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]]
+ blockInfoManager.lockForReading(blockId) match {
+ case None =>
+ logDebug(s"Block $blockId was not found")
+ None
+ case Some(info) =>
+ val level = info.level
+ logDebug(s"Level for block $blockId is $level")
+ val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
+ if (level.useMemory && memoryStore.contains(blockId)) {
+ val iter: Iterator[Any] = if (level.deserialized) {
+ memoryStore.getValues(blockId).get
+ } else {
+ serializerManager.dataDeserializeStream(
+ blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
+ }
+ // We need to capture the current taskId in case the iterator completion is triggered
+ // from a different thread which does not have TaskContext set; see SPARK-18406 for
+ // discussion.
+ val ci = CompletionIterator[Any, Iterator[Any]](iter, {
+ releaseLock(blockId, taskAttemptId)
+ })
+ Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
+ } else if (level.useDisk && diskStore.contains(blockId)) {
+ val diskData = diskStore.getBytes(blockId)
+ val iterToReturn: Iterator[Any] = {
+ if (level.deserialized) {
+ val diskValues = serializerManager.dataDeserializeStream(
+ blockId,
+ diskData.toInputStream())(info.classTag)
+ maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
+ } else {
+ val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskData)
+ .map { _.toInputStream(dispose = false) }
+ .getOrElse { diskData.toInputStream() }
+ serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
+ }
+ }
+ val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, {
+ releaseLockAndDispose(blockId, diskData, taskAttemptId)
+ })
+ Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
+ } else {
+ handleLocalReadFailure(blockId)
+ }
+ }
}
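Editor's note: the read lock acquired in getLocalValues is held until the returned iterator is fully consumed; a completion callback attached to the iterator releases it, using the task attempt id captured before the iterator may be handed to another thread. A minimal completion-iterator sketch (not Spark's CompletionIterator class):

// Runs `onComplete` exactly once, when the underlying iterator is exhausted.
class CompletionIteratorSketch[A](underlying: Iterator[A], onComplete: () => Unit)
    extends Iterator[A] {
  private var completed = false
  override def hasNext: Boolean = {
    val more = underlying.hasNext
    if (!more && !completed) {
      completed = true
      onComplete()  // e.g. release the block's read lock
    }
    more
  }
  override def next(): A = underlying.next()
}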
/**
* Get block from the local block manager as serialized bytes.
*/
- def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
+ def getLocalBytes(blockId: BlockId): Option[BlockData] = {
logDebug(s"Getting local block $blockId as bytes")
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
@@ -432,186 +552,129 @@ private[spark] class BlockManager(
val shuffleBlockResolver = shuffleManager.shuffleBlockResolver
// TODO: This should gracefully handle case where local block is not available. Currently
// downstream code will throw an exception.
- Option(
+ val buf = new ChunkedByteBuffer(
shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
+ Some(new ByteBufferBlockData(buf, true))
} else {
- doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) }
}
}
- private def doGetLocal(blockId: BlockId, asBlockResult: Boolean): Option[Any] = {
- val info = blockInfo.get(blockId).orNull
- if (info != null) {
- info.synchronized {
- // Double check to make sure the block is still there. There is a small chance that the
- // block has been removed by removeBlock (which also synchronizes on the blockInfo object).
- // Note that this only checks metadata tracking. If user intentionally deleted the block
- // on disk or from off heap storage without using removeBlock, this conditional check will
- // still pass but eventually we will get an exception because we can't find the block.
- if (blockInfo.get(blockId).isEmpty) {
- logWarning(s"Block $blockId had been removed")
- return None
- }
-
- // If another thread is writing the block, wait for it to become ready.
- if (!info.waitForReady()) {
- // If we get here, the block write failed.
- logWarning(s"Block $blockId was marked as failure.")
- return None
- }
-
- val level = info.level
- logDebug(s"Level for block $blockId is $level")
-
- // Look for the block in memory
- if (level.useMemory) {
- logDebug(s"Getting block $blockId from memory")
- val result = if (asBlockResult) {
- memoryStore.getValues(blockId).map(new BlockResult(_, DataReadMethod.Memory, info.size))
- } else {
- memoryStore.getBytes(blockId)
- }
- result match {
- case Some(values) =>
- return result
- case None =>
- logDebug(s"Block $blockId not found in memory")
- }
- }
-
- // Look for the block in external block store
- if (level.useOffHeap) {
- logDebug(s"Getting block $blockId from ExternalBlockStore")
- if (externalBlockStore.contains(blockId)) {
- val result = if (asBlockResult) {
- externalBlockStore.getValues(blockId)
- .map(new BlockResult(_, DataReadMethod.Memory, info.size))
- } else {
- externalBlockStore.getBytes(blockId)
- }
- result match {
- case Some(values) =>
- return result
- case None =>
- logDebug(s"Block $blockId not found in ExternalBlockStore")
- }
- }
- }
-
- // Look for block on disk, potentially storing it back in memory if required
- if (level.useDisk) {
- logDebug(s"Getting block $blockId from disk")
- val bytes: ByteBuffer = diskStore.getBytes(blockId) match {
- case Some(b) => b
- case None =>
- throw new BlockException(
- blockId, s"Block $blockId not found on disk, though it should be")
- }
- assert(0 == bytes.position())
-
- if (!level.useMemory) {
- // If the block shouldn't be stored in memory, we can just return it
- if (asBlockResult) {
- return Some(new BlockResult(dataDeserialize(blockId, bytes), DataReadMethod.Disk,
- info.size))
- } else {
- return Some(bytes)
- }
- } else {
- // Otherwise, we also have to store something in the memory store
- if (!level.deserialized || !asBlockResult) {
- /* We'll store the bytes in memory if the block's storage level includes
- * "memory serialized", or if it should be cached as objects in memory
- * but we only requested its serialized bytes. */
- memoryStore.putBytes(blockId, bytes.limit, () => {
- // https://issues.apache.org/jira/browse/SPARK-6076
- // If the file size is bigger than the free memory, OOM will happen. So if we cannot
- // put it into MemoryStore, copyForMemory should not be created. That's why this
- // action is put into a `() => ByteBuffer` and created lazily.
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- })
- bytes.rewind()
- }
- if (!asBlockResult) {
- return Some(bytes)
- } else {
- val values = dataDeserialize(blockId, bytes)
- if (level.deserialized) {
- // Cache the values before returning them
- val putResult = memoryStore.putIterator(
- blockId, values, level, returnValues = true, allowPersistToDisk = false)
- // The put may or may not have succeeded, depending on whether there was enough
- // space to unroll the block. Either way, the put here should return an iterator.
- putResult.data match {
- case Left(it) =>
- return Some(new BlockResult(it, DataReadMethod.Disk, info.size))
- case _ =>
- // This only happens if we dropped the values back to disk (which is never)
- throw new SparkException("Memory store did not return an iterator!")
- }
- } else {
- return Some(new BlockResult(values, DataReadMethod.Disk, info.size))
- }
- }
- }
- }
+ /**
+ * Get block from the local block manager as serialized bytes.
+ *
+ * Must be called while holding a read lock on the block.
+ * Releases the read lock upon exception; keeps the read lock upon successful return.
+ */
+ private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): BlockData = {
+ val level = info.level
+ logDebug(s"Level for block $blockId is $level")
+ // In order, try to read the serialized bytes from memory, then from disk, then fall back to
+ // serializing in-memory objects, and, finally, throw an exception if the block does not exist.
+ if (level.deserialized) {
+ // Try to avoid expensive serialization by reading a pre-serialized copy from disk:
+ if (level.useDisk && diskStore.contains(blockId)) {
+ // Note: we purposely do not try to put the block back into memory here. Since this branch
+ // handles deserialized blocks, this block may only be cached in memory as objects, not
+ // serialized bytes. Because the caller only requested bytes, it doesn't make sense to
+ // cache the block's deserialized objects since that caching may not have a payoff.
+ diskStore.getBytes(blockId)
+ } else if (level.useMemory && memoryStore.contains(blockId)) {
+ // The block was not found on disk, so serialize an in-memory copy:
+ new ByteBufferBlockData(serializerManager.dataSerializeWithExplicitClassTag(
+ blockId, memoryStore.getValues(blockId).get, info.classTag), true)
+ } else {
+ handleLocalReadFailure(blockId)
+ }
+ } else { // storage level is serialized
+ if (level.useMemory && memoryStore.contains(blockId)) {
+ new ByteBufferBlockData(memoryStore.getBytes(blockId).get, false)
+ } else if (level.useDisk && diskStore.contains(blockId)) {
+ val diskData = diskStore.getBytes(blockId)
+ maybeCacheDiskBytesInMemory(info, blockId, level, diskData)
+ .map(new ByteBufferBlockData(_, false))
+ .getOrElse(diskData)
+ } else {
+ handleLocalReadFailure(blockId)
}
- } else {
- logDebug(s"Block $blockId not registered locally")
}
- None
}
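Editor's note: when the level stores deserialized objects, a pre-serialized copy on disk is cheaper than re-serializing from memory, so disk is consulted first; for serialized levels the in-memory copy is already bytes and wins. A tiny sketch of that preference, with the store lookups modeled as placeholder callbacks rather than the real MemoryStore/DiskStore calls:

// Choose the cheapest source of serialized bytes for a local block.
def localBytesSketch(
    deserialized: Boolean,
    onDisk: Boolean,
    inMemory: Boolean,
    readDiskBytes: () => Array[Byte],
    serializeMemoryValues: () => Array[Byte],
    readMemoryBytes: () => Array[Byte]): Option[Array[Byte]] = {
  if (deserialized) {
    // Avoid an expensive serialization if a serialized copy already exists on disk.
    if (onDisk) Some(readDiskBytes())
    else if (inMemory) Some(serializeMemoryValues())
    else None
  } else {
    // Serialized level: the in-memory copy is already bytes.
    if (inMemory) Some(readMemoryBytes())
    else if (onDisk) Some(readDiskBytes())
    else None
  }
}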
/**
* Get block from remote block managers.
+ *
+ * This does not acquire a lock on this block in this JVM.
*/
- def getRemote(blockId: BlockId): Option[BlockResult] = {
- logDebug(s"Getting remote block $blockId")
- doGetRemote(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]]
+ private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
+ val ct = implicitly[ClassTag[T]]
+ getRemoteBytes(blockId).map { data =>
+ val values =
+ serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct)
+ new BlockResult(values, DataReadMethod.Network, data.size)
+ }
}
/**
- * Get block from remote block managers as serialized bytes.
+ * Return a list of locations for the given block, prioritizing the local machine since
+ * multiple block managers can share the same host.
*/
- def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
- logDebug(s"Getting remote block $blockId as bytes")
- doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
+ val locs = Random.shuffle(master.getLocations(blockId))
+ val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host }
+ preferredLocs ++ otherLocs
}
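Editor's note: the location list is shuffled first and then split so that block managers on the local host come before remote ones. A standalone sketch of the same ordering, with BlockManagerId simplified to a plain host/port pair:

import scala.util.Random

// Illustrative stand-in for BlockManagerId: just host and port.
case class Location(host: String, port: Int)

def prioritizeLocal(localHost: String, locations: Seq[Location]): Seq[Location] = {
  // Shuffle first so load is spread across hosts, then move same-host entries to the front.
  val shuffled = Random.shuffle(locations)
  val (local, remote) = shuffled.partition(_.host == localHost)
  local ++ remote
}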
- private def doGetRemote(blockId: BlockId, asBlockResult: Boolean): Option[Any] = {
+ /**
+ * Get block from remote block managers as serialized bytes.
+ */
+ def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
+ logDebug(s"Getting remote block $blockId")
require(blockId != null, "BlockId is null")
- val locations = Random.shuffle(master.getLocations(blockId))
- var numFetchFailures = 0
- for (loc <- locations) {
+ var runningFailureCount = 0
+ var totalFailureCount = 0
+ val locations = getLocations(blockId)
+ val maxFetchFailures = locations.size
+ var locationIterator = locations.iterator
+ while (locationIterator.hasNext) {
+ val loc = locationIterator.next()
logDebug(s"Getting remote block $blockId from $loc")
val data = try {
blockTransferService.fetchBlockSync(
loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
} catch {
case NonFatal(e) =>
- numFetchFailures += 1
- if (numFetchFailures == locations.size) {
- // An exception is thrown while fetching this block from all locations
- throw new BlockFetchException(s"Failed to fetch block from" +
- s" ${locations.size} locations. Most recent failure cause:", e)
- } else {
- // This location failed, so we retry fetch from a different one by returning null here
- logWarning(s"Failed to fetch remote block $blockId " +
- s"from $loc (failed attempt $numFetchFailures)", e)
- null
+ runningFailureCount += 1
+ totalFailureCount += 1
+
+ if (totalFailureCount >= maxFetchFailures) {
+ // Give up trying any more locations. Either we've tried all of the original locations,
+ // or we've refreshed the list of locations from the master, and have still
+ // hit failures after trying locations from the refreshed list.
+ logWarning(s"Failed to fetch block after $totalFailureCount fetch failures. " +
+ s"Most recent failure cause:", e)
+ return None
+ }
+
+ logWarning(s"Failed to fetch remote block $blockId " +
+ s"from $loc (failed attempt $runningFailureCount)", e)
+
+ // If there are many executors, the locations list can contain many stale entries,
+ // causing a large number of retries that may take a significant amount of time.
+ // To get rid of these stale entries we refresh the block locations after a certain
+ // number of fetch failures.
+ if (runningFailureCount >= maxFailuresBeforeLocationRefresh) {
+ locationIterator = getLocations(blockId).iterator
+ logDebug(s"Refreshed locations from the driver " +
+ s"after ${runningFailureCount} fetch failures.")
+ runningFailureCount = 0
}
+
+ // This location failed, so we retry fetch from a different one by returning null here
+ null
}
if (data != null) {
- if (asBlockResult) {
- return Some(new BlockResult(
- dataDeserialize(blockId, data),
- DataReadMethod.Network,
- data.limit()))
- } else {
- return Some(data)
- }
+ return Some(new ChunkedByteBuffer(data))
}
logDebug(s"The value of block $blockId is null")
}
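Editor's note: the fetch loop tracks two counters: a running count that triggers a refresh of the locations from the driver after a configured number of consecutive failures, and a total count that aborts the fetch once every known location has failed. A condensed, self-contained sketch of that control flow; fetchFrom and refreshLocations are hypothetical callbacks, not Spark APIs:

// Retry-with-refresh loop over a list of candidate locations.
def fetchWithRetries[A](
    initialLocations: Seq[String],
    maxFailuresBeforeRefresh: Int,
    refreshLocations: () => Seq[String],
    fetchFrom: String => A): Option[A] = {
  val maxFetchFailures = initialLocations.size
  var runningFailures = 0
  var totalFailures = 0
  var it = initialLocations.iterator
  while (it.hasNext) {
    val loc = it.next()
    try {
      return Some(fetchFrom(loc))
    } catch {
      case scala.util.control.NonFatal(_) =>
        runningFailures += 1
        totalFailures += 1
        if (totalFailures >= maxFetchFailures) return None  // every location has failed
        if (runningFailures >= maxFailuresBeforeRefresh) {
          // Drop potentially stale entries and start over with a fresh list.
          it = refreshLocations().iterator
          runningFailures = 0
        }
    }
  }
  None
}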
@@ -621,14 +684,18 @@ private[spark] class BlockManager(
/**
* Get a block from the block manager (either local or remote).
+ *
+ * This acquires a read lock on the block if the block was stored locally and does not acquire
+ * any locks if the block was fetched from a remote block manager. The read lock will
+ * automatically be freed once the result's `data` iterator is fully consumed.
*/
- def get(blockId: BlockId): Option[BlockResult] = {
- val local = getLocal(blockId)
+ def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
+ val local = getLocalValues(blockId)
if (local.isDefined) {
logInfo(s"Found block $blockId locally")
return local
}
- val remote = getRemote(blockId)
+ val remote = getRemoteValues[T](blockId)
if (remote.isDefined) {
logInfo(s"Found block $blockId remotely")
return remote
@@ -636,14 +703,101 @@ private[spark] class BlockManager(
None
}
- def putIterator(
+ /**
+ * Downgrades an exclusive write lock to a shared read lock.
+ */
+ def downgradeLock(blockId: BlockId): Unit = {
+ blockInfoManager.downgradeLock(blockId)
+ }
+
+ /**
+ * Release a lock on the given block with explicit TID.
+ * The param `taskAttemptId` should be passed in cases where the correct TID cannot be obtained
+ * from TaskContext, for example, when the input iterator of a cached RDD is consumed to the end
+ * in a child thread.
+ */
+ def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = {
+ blockInfoManager.unlock(blockId, taskAttemptId)
+ }
+
+ /**
+ * Registers a task with the BlockManager in order to initialize per-task bookkeeping structures.
+ */
+ def registerTask(taskAttemptId: Long): Unit = {
+ blockInfoManager.registerTask(taskAttemptId)
+ }
+
+ /**
+ * Release all locks for the given task.
+ *
+ * @return the blocks whose locks were released.
+ */
+ def releaseAllLocksForTask(taskAttemptId: Long): Seq[BlockId] = {
+ blockInfoManager.releaseAllLocksForTask(taskAttemptId)
+ }
+
+ /**
+ * Retrieve the given block if it exists, otherwise call the provided `makeIterator` method
+ * to compute the block, persist it, and return its values.
+ *
+ * @return either a BlockResult if the block was successfully cached, or an iterator if the block
+ * could not be cached.
+ */
+ def getOrElseUpdate[T](
blockId: BlockId,
- values: Iterator[Any],
level: StorageLevel,
- tellMaster: Boolean = true,
- effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = {
+ classTag: ClassTag[T],
+ makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
+ // Attempt to read the block from local or remote storage. If it's present, then we don't need
+ // to go through the local-get-or-put path.
+ get[T](blockId)(classTag) match {
+ case Some(block) =>
+ return Left(block)
+ case _ =>
+ // Need to compute the block.
+ }
+ // Initially we hold no locks on this block.
+ doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match {
+ case None =>
+ // doPut() didn't hand work back to us, so the block already existed or was successfully
+ // stored. Therefore, we now hold a read lock on the block.
+ val blockResult = getLocalValues(blockId).getOrElse {
+ // Since we held a read lock between the doPut() and get() calls, the block should not
+ // have been evicted, so get() not returning the block indicates some internal error.
+ releaseLock(blockId)
+ throw new SparkException(s"get() failed for block $blockId even though we held a lock")
+ }
+ // We already hold a read lock on the block from the doPut() call and getLocalValues()
+ // acquires the lock again, so we need to call releaseLock() here so that the net number
+ // of lock acquisitions is 1 (since the caller will only call release() once).
+ releaseLock(blockId)
+ Left(blockResult)
+ case Some(iter) =>
+ // The put failed, likely because the data was too large to fit in memory and could not be
+ // dropped to disk. Therefore, we need to pass the input iterator back to the caller so
+ // that they can decide what to do with the values (e.g. process them without caching).
+ Right(iter)
+ }
+ }
+
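Editor's note: getOrElseUpdate either serves the cached block (Left) or, when caching fails, hands the computed iterator back to the caller (Right). The shape of that contract can be sketched with a generic cache wrapper; the Cache trait below is illustrative, not BlockManager's API:

// Left = value served from the cache, Right = value that could not be cached
// and must be consumed directly by the caller.
trait Cache[K, V] {
  def get(key: K): Option[V]
  def put(key: K, value: V): Boolean  // false if the value could not be stored
}

def getOrElseUpdateSketch[K, V](cache: Cache[K, V], key: K)(compute: () => V): Either[V, V] = {
  cache.get(key) match {
    case Some(v) => Left(v)            // already cached, no work to do
    case None =>
      val v = compute()
      if (cache.put(key, v)) Left(v)   // stored successfully; serve the cached copy
      else Right(v)                    // could not be cached; caller handles the raw value
  }
}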
+ /**
+ * @return true if the block was stored or false if an error occurred.
+ */
+ def putIterator[T: ClassTag](
+ blockId: BlockId,
+ values: Iterator[T],
+ level: StorageLevel,
+ tellMaster: Boolean = true): Boolean = {
require(values != null, "Values is null")
- doPut(blockId, IteratorValues(values), level, tellMaster, effectiveStorageLevel)
+ doPutIterator(blockId, () => values, level, implicitly[ClassTag[T]], tellMaster) match {
+ case None =>
+ true
+ case Some(iter) =>
+ // Caller doesn't care about the iterator values, so we can close the iterator here
+ // to free resources earlier
+ iter.close()
+ false
+ }
}
/**
@@ -657,224 +811,390 @@ private[spark] class BlockManager(
serializerInstance: SerializerInstance,
bufferSize: Int,
writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
- val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
- new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream,
- syncWrites, writeMetrics)
+ new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
+ syncWrites, writeMetrics, blockId)
}
/**
- * Put a new block of values to the block manager.
- * Return a list of blocks updated as a result of this put.
+ * Put a new block of serialized bytes to the block manager.
+ *
+ * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing
+ * so may corrupt or change the data stored by the `BlockManager`.
+ *
+ * @return true if the block was stored or false if an error occurred.
*/
- def putArray(
+ def putBytes[T: ClassTag](
blockId: BlockId,
- values: Array[Any],
+ bytes: ChunkedByteBuffer,
level: StorageLevel,
- tellMaster: Boolean = true,
- effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = {
- require(values != null, "Values is null")
- doPut(blockId, ArrayValues(values), level, tellMaster, effectiveStorageLevel)
+ tellMaster: Boolean = true): Boolean = {
+ require(bytes != null, "Bytes is null")
+ doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster)
}
/**
- * Put a new block of serialized bytes to the block manager.
- * Return a list of blocks updated as a result of this put.
+ * Put the given bytes according to the given level in one of the block stores, replicating
+ * the values if necessary.
+ *
+ * If the block already exists, this method will not overwrite it.
+ *
+ * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing
+ * so may corrupt or change the data stored by the `BlockManager`.
+ *
+ * @param keepReadLock if true, this method will hold the read lock when it returns (even if the
+ * block already exists). If false, this method will hold no locks when it
+ * returns.
+ * @return true if the block was already present or if the put succeeded, false otherwise.
*/
- def putBytes(
+ private def doPutBytes[T](
blockId: BlockId,
- bytes: ByteBuffer,
+ bytes: ChunkedByteBuffer,
level: StorageLevel,
+ classTag: ClassTag[T],
tellMaster: Boolean = true,
- effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = {
- require(bytes != null, "Bytes is null")
- doPut(blockId, ByteBufferValues(bytes), level, tellMaster, effectiveStorageLevel)
+ keepReadLock: Boolean = false): Boolean = {
+ doPut(blockId, level, classTag, tellMaster = tellMaster, keepReadLock = keepReadLock) { info =>
+ val startTimeMs = System.currentTimeMillis
+ // Since we're storing bytes, initiate the replication before storing them locally.
+ // This is faster as data is already serialized and ready to send.
+ val replicationFuture = if (level.replication > 1) {
+ Future {
+ // This is a blocking action and should run in futureExecutionContext which is a cached
+ // thread pool. The ByteBufferBlockData wrapper is not disposed of to avoid releasing
+ // buffers that are owned by the caller.
+ replicate(blockId, new ByteBufferBlockData(bytes, false), level, classTag)
+ }(futureExecutionContext)
+ } else {
+ null
+ }
+
+ val size = bytes.size
+
+ if (level.useMemory) {
+ // Put it in memory first, even if it also has useDisk set to true;
+ // We will drop it to disk later if the memory store can't hold it.
+ val putSucceeded = if (level.deserialized) {
+ val values =
+ serializerManager.dataDeserializeStream(blockId, bytes.toInputStream())(classTag)
+ memoryStore.putIteratorAsValues(blockId, values, classTag) match {
+ case Right(_) => true
+ case Left(iter) =>
+ // If putting deserialized values in memory failed, we will put the bytes directly to
+ // disk, so we don't need this iterator and can close it to free resources earlier.
+ iter.close()
+ false
+ }
+ } else {
+ val memoryMode = level.memoryMode
+ memoryStore.putBytes(blockId, size, memoryMode, () => {
+ if (memoryMode == MemoryMode.OFF_HEAP &&
+ bytes.chunks.exists(buffer => !buffer.isDirect)) {
+ bytes.copy(Platform.allocateDirectBuffer)
+ } else {
+ bytes
+ }
+ })
+ }
+ if (!putSucceeded && level.useDisk) {
+ logWarning(s"Persisting block $blockId to disk instead.")
+ diskStore.putBytes(blockId, bytes)
+ }
+ } else if (level.useDisk) {
+ diskStore.putBytes(blockId, bytes)
+ }
+
+ val putBlockStatus = getCurrentBlockStatus(blockId, info)
+ val blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid
+ if (blockWasSuccessfullyStored) {
+ // Now that the block is in either the memory or disk store,
+ // tell the master about it.
+ info.size = size
+ if (tellMaster && info.tellMaster) {
+ reportBlockStatus(blockId, putBlockStatus)
+ }
+ addUpdatedBlockStatusToTaskMetrics(blockId, putBlockStatus)
+ }
+ logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
+ if (level.replication > 1) {
+ // Wait for asynchronous replication to finish
+ try {
+ ThreadUtils.awaitReady(replicationFuture, Duration.Inf)
+ } catch {
+ case NonFatal(t) =>
+ throw new Exception("Error occurred while waiting for replication to finish", t)
+ }
+ }
+ if (blockWasSuccessfullyStored) {
+ None
+ } else {
+ Some(bytes)
+ }
+ }.isEmpty
}
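Editor's note: because the bytes are already serialized, doPutBytes starts replication in a Future before the local write and waits for it only at the end, overlapping the network transfer with local storage. A small sketch of that overlap; storeLocally and replicateRemotely are placeholder callbacks standing in for BlockManager internals:

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Overlap remote replication with the local write, then join before returning.
def putBytesSketch(
    bytes: Array[Byte],
    replication: Int,
    storeLocally: Array[Byte] => Boolean,
    replicateRemotely: Array[Byte] => Unit)(implicit ec: ExecutionContext): Boolean = {
  // Start replication first: the data is already serialized and ready to send.
  val replicationFuture =
    if (replication > 1) Future { replicateRemotely(bytes) } else Future.unit
  val stored = storeLocally(bytes)
  // Only block once the local write is done.
  Await.ready(replicationFuture, Duration.Inf)
  stored
}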
/**
- * Put the given block according to the given level in one of the block stores, replicating
- * the values if necessary.
+ * Helper method used to abstract common code from [[doPutBytes()]] and [[doPutIterator()]].
*
- * The effective storage level refers to the level according to which the block will actually be
- * handled. This allows the caller to specify an alternate behavior of doPut while preserving
- * the original level specified by the user.
+ * @param putBody a function which attempts the actual put() and returns None on success
+ * or Some on failure.
*/
- private def doPut(
+ private def doPut[T](
blockId: BlockId,
- data: BlockValues,
level: StorageLevel,
- tellMaster: Boolean = true,
- effectiveStorageLevel: Option[StorageLevel] = None)
- : Seq[(BlockId, BlockStatus)] = {
+ classTag: ClassTag[_],
+ tellMaster: Boolean,
+ keepReadLock: Boolean)(putBody: BlockInfo => Option[T]): Option[T] = {
require(blockId != null, "BlockId is null")
require(level != null && level.isValid, "StorageLevel is null or invalid")
- effectiveStorageLevel.foreach { level =>
- require(level != null && level.isValid, "Effective StorageLevel is null or invalid")
- }
-
- // Return value
- val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
- /* Remember the block's storage level so that we can correctly drop it to disk if it needs
- * to be dropped right after it got put into memory. Note, however, that other threads will
- * not be able to get() this block until we call markReady on its BlockInfo. */
val putBlockInfo = {
- val tinfo = new BlockInfo(level, tellMaster)
- // Do atomically !
- val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
- if (oldBlockOpt.isDefined) {
- if (oldBlockOpt.get.waitForReady()) {
- logWarning(s"Block $blockId already exists on this machine; not re-adding it")
- return updatedBlocks
- }
- // TODO: So the block info exists - but previous attempt to load it (?) failed.
- // What do we do now ? Retry on it ?
- oldBlockOpt.get
+ val newInfo = new BlockInfo(level, classTag, tellMaster)
+ if (blockInfoManager.lockNewBlockForWriting(blockId, newInfo)) {
+ newInfo
} else {
- tinfo
+ logWarning(s"Block $blockId already exists on this machine; not re-adding it")
+ if (!keepReadLock) {
+ // lockNewBlockForWriting returned a read lock on the existing block, so we must free it:
+ releaseLock(blockId)
+ }
+ return None
}
}
val startTimeMs = System.currentTimeMillis
-
- /* If we're storing values and we need to replicate the data, we'll want access to the values,
- * but because our put will read the whole iterator, there will be no values left. For the
- * case where the put serializes data, we'll remember the bytes, above; but for the case where
- * it doesn't, such as deserialized storage, let's rely on the put returning an Iterator. */
- var valuesAfterPut: Iterator[Any] = null
-
- // Ditto for the bytes after the put
- var bytesAfterPut: ByteBuffer = null
-
- // Size of the block in bytes
- var size = 0L
-
- // The level we actually use to put the block
- val putLevel = effectiveStorageLevel.getOrElse(level)
-
- // If we're storing bytes, then initiate the replication before storing them locally.
- // This is faster as data is already serialized and ready to send.
- val replicationFuture = data match {
- case b: ByteBufferValues if putLevel.replication > 1 =>
- // Duplicate doesn't copy the bytes, but just creates a wrapper
- val bufferView = b.buffer.duplicate()
- Future {
- // This is a blocking action and should run in futureExecutionContext which is a cached
- // thread pool
- replicate(blockId, bufferView, putLevel)
- }(futureExecutionContext)
- case _ => null
+ var exceptionWasThrown: Boolean = true
+ val result: Option[T] = try {
+ val res = putBody(putBlockInfo)
+ exceptionWasThrown = false
+ if (res.isEmpty) {
+ // the block was successfully stored
+ if (keepReadLock) {
+ blockInfoManager.downgradeLock(blockId)
+ } else {
+ blockInfoManager.unlock(blockId)
+ }
+ } else {
+ removeBlockInternal(blockId, tellMaster = false)
+ logWarning(s"Putting block $blockId failed")
+ }
+ res
+ } finally {
+ // This cleanup is performed in a finally block rather than a `catch` to avoid having to
+ // catch and properly re-throw InterruptedException.
+ if (exceptionWasThrown) {
+ logWarning(s"Putting block $blockId failed due to an exception")
+ // If an exception was thrown then it's possible that the code in `putBody` has already
+ // notified the master about the availability of this block, so we need to send an update
+ // to remove this block location.
+ removeBlockInternal(blockId, tellMaster = tellMaster)
+ // The `putBody` code may have also added a new block status to TaskMetrics, so we need
+ // to cancel that out by overwriting it with an empty block status. We only do this if
+ // the finally block was entered via an exception because doing this unconditionally would
+ // cause us to send empty block statuses for every block that failed to be cached due to
+ // a memory shortage (which is an expected failure, unlike an uncaught exception).
+ addUpdatedBlockStatusToTaskMetrics(blockId, BlockStatus.empty)
+ }
}
-
- putBlockInfo.synchronized {
- logTrace("Put for block %s took %s to get into synchronized block"
+ if (level.replication > 1) {
+ logDebug("Putting block %s with replication took %s"
.format(blockId, Utils.getUsedTimeMs(startTimeMs)))
+ } else {
+ logDebug("Putting block %s without replication took %s"
+ .format(blockId, Utils.getUsedTimeMs(startTimeMs)))
+ }
+ result
+ }
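Editor's note: doPut is effectively a bracket: it acquires the write lock, runs the supplied putBody, and then either releases (or downgrades) the lock on success or removes the half-written block on failure or exception. A generic sketch of that pattern, with the lock and cleanup operations passed in as hypothetical callbacks:

// Bracket-style helper: `body` returns None on success, Some(leftover) on failure.
def withWriteLock[T](
    acquireWriteLock: () => Boolean,
    releaseLock: () => Unit,
    removeBlock: () => Unit)(body: () => Option[T]): Option[T] = {
  if (!acquireWriteLock()) return None  // block already exists; nothing to do
  var exceptionWasThrown = true
  try {
    val result = body()
    exceptionWasThrown = false
    if (result.isEmpty) releaseLock()   // stored successfully: let readers in
    else removeBlock()                  // put failed: drop the partial block
    result
  } finally {
    // Cleanup runs here (not in a catch) so InterruptedException propagates unchanged.
    if (exceptionWasThrown) removeBlock()
  }
}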
- var marked = false
- try {
- // returnValues - Whether to return the values put
- // blockStore - The type of storage to put these values into
- val (returnValues, blockStore: BlockStore) = {
- if (putLevel.useMemory) {
- // Put it in memory first, even if it also has useDisk set to true;
- // We will drop it to disk later if the memory store can't hold it.
- (true, memoryStore)
- } else if (putLevel.useOffHeap) {
- // Use external block store
- (false, externalBlockStore)
- } else if (putLevel.useDisk) {
- // Don't get back the bytes from put unless we replicate them
- (putLevel.replication > 1, diskStore)
- } else {
- assert(putLevel == StorageLevel.NONE)
- throw new BlockException(
- blockId, s"Attempted to put block $blockId without specifying storage level!")
+ /**
+ * Put the given block according to the given level in one of the block stores, replicating
+ * the values if necessary.
+ *
+ * If the block already exists, this method will not overwrite it.
+ *
+ * @param keepReadLock if true, this method will hold the read lock when it returns (even if the
+ * block already exists). If false, this method will hold no locks when it
+ * returns.
+ * @return None if the block was already present or if the put succeeded, or Some(iterator)
+ * if the put failed.
+ */
+ private def doPutIterator[T](
+ blockId: BlockId,
+ iterator: () => Iterator[T],
+ level: StorageLevel,
+ classTag: ClassTag[T],
+ tellMaster: Boolean = true,
+ keepReadLock: Boolean = false): Option[PartiallyUnrolledIterator[T]] = {
+ doPut(blockId, level, classTag, tellMaster = tellMaster, keepReadLock = keepReadLock) { info =>
+ val startTimeMs = System.currentTimeMillis
+ var iteratorFromFailedMemoryStorePut: Option[PartiallyUnrolledIterator[T]] = None
+ // Size of the block in bytes
+ var size = 0L
+ if (level.useMemory) {
+ // Put it in memory first, even if it also has useDisk set to true;
+ // We will drop it to disk later if the memory store can't hold it.
+ if (level.deserialized) {
+ memoryStore.putIteratorAsValues(blockId, iterator(), classTag) match {
+ case Right(s) =>
+ size = s
+ case Left(iter) =>
+ // Not enough space to unroll this block; drop to disk if applicable
+ if (level.useDisk) {
+ logWarning(s"Persisting block $blockId to disk instead.")
+ diskStore.put(blockId) { channel =>
+ val out = Channels.newOutputStream(channel)
+ serializerManager.dataSerializeStream(blockId, out, iter)(classTag)
+ }
+ size = diskStore.getSize(blockId)
+ } else {
+ iteratorFromFailedMemoryStorePut = Some(iter)
+ }
+ }
+ } else { // !level.deserialized
+ memoryStore.putIteratorAsBytes(blockId, iterator(), classTag, level.memoryMode) match {
+ case Right(s) =>
+ size = s
+ case Left(partiallySerializedValues) =>
+ // Not enough space to unroll this block; drop to disk if applicable
+ if (level.useDisk) {
+ logWarning(s"Persisting block $blockId to disk instead.")
+ diskStore.put(blockId) { channel =>
+ val out = Channels.newOutputStream(channel)
+ partiallySerializedValues.finishWritingToStream(out)
+ }
+ size = diskStore.getSize(blockId)
+ } else {
+ iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator)
+ }
}
}
- // Actually put the values
- val result = data match {
- case IteratorValues(iterator) =>
- blockStore.putIterator(blockId, iterator, putLevel, returnValues)
- case ArrayValues(array) =>
- blockStore.putArray(blockId, array, putLevel, returnValues)
- case ByteBufferValues(bytes) =>
- bytes.rewind()
- blockStore.putBytes(blockId, bytes, putLevel)
- }
- size = result.size
- result.data match {
- case Left (newIterator) if putLevel.useMemory => valuesAfterPut = newIterator
- case Right (newBytes) => bytesAfterPut = newBytes
- case _ =>
+ } else if (level.useDisk) {
+ diskStore.put(blockId) { channel =>
+ val out = Channels.newOutputStream(channel)
+ serializerManager.dataSerializeStream(blockId, out, iterator())(classTag)
}
+ size = diskStore.getSize(blockId)
+ }
- // Keep track of which blocks are dropped from memory
- if (putLevel.useMemory) {
- result.droppedBlocks.foreach { updatedBlocks += _ }
+ val putBlockStatus = getCurrentBlockStatus(blockId, info)
+ val blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid
+ if (blockWasSuccessfullyStored) {
+ // Now that the block is in either the memory or disk store, tell the master about it.
+ info.size = size
+ if (tellMaster && info.tellMaster) {
+ reportBlockStatus(blockId, putBlockStatus)
}
-
- val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo)
- if (putBlockStatus.storageLevel != StorageLevel.NONE) {
- // Now that the block is in either the memory, externalBlockStore, or disk store,
- // let other threads read it, and tell the master about it.
- marked = true
- putBlockInfo.markReady(size)
- if (tellMaster) {
- reportBlockStatus(blockId, putBlockInfo, putBlockStatus)
+ addUpdatedBlockStatusToTaskMetrics(blockId, putBlockStatus)
+ logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
+ if (level.replication > 1) {
+ val remoteStartTime = System.currentTimeMillis
+ val bytesToReplicate = doGetLocalBytes(blockId, info)
+ // [SPARK-16550] Erase the typed classTag when using default serialization, since
+ // NettyBlockRpcServer crashes when deserializing repl-defined classes.
+ // TODO(ekl) remove this once the classloader issue on the remote end is fixed.
+ val remoteClassTag = if (!serializerManager.canUseKryo(classTag)) {
+ scala.reflect.classTag[Any]
+ } else {
+ classTag
}
- updatedBlocks += ((blockId, putBlockStatus))
- }
- } finally {
- // If we failed in putting the block to memory/disk, notify other possible readers
- // that it has failed, and then remove it from the block info map.
- if (!marked) {
- // Note that the remove must happen before markFailure otherwise another thread
- // could've inserted a new BlockInfo before we remove it.
- blockInfo.remove(blockId)
- putBlockInfo.markFailure()
- logWarning(s"Putting block $blockId failed")
+ try {
+ replicate(blockId, bytesToReplicate, level, remoteClassTag)
+ } finally {
+ bytesToReplicate.dispose()
+ }
+ logDebug("Put block %s remotely took %s"
+ .format(blockId, Utils.getUsedTimeMs(remoteStartTime)))
}
}
+ assert(blockWasSuccessfullyStored == iteratorFromFailedMemoryStorePut.isEmpty)
+ iteratorFromFailedMemoryStorePut
}
- logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
-
- // Either we're storing bytes and we asynchronously started replication, or we're storing
- // values and need to serialize and replicate them now:
- if (putLevel.replication > 1) {
- data match {
- case ByteBufferValues(bytes) =>
- if (replicationFuture != null) {
- Await.ready(replicationFuture, Duration.Inf)
+ }
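Editor's note: when the memory store cannot unroll the iterator, the block is streamed to disk if the level allows it; otherwise the partially unrolled iterator is returned to the caller. A compact sketch of that memory-first fallback; tryPutInMemory and writeToDisk are placeholder callbacks, and Left models the iterator handed back on an unroll failure:

// Memory-first put with disk fallback; Right carries the stored size.
def putIteratorSketch[T](
    values: Iterator[T],
    useMemory: Boolean,
    useDisk: Boolean,
    tryPutInMemory: Iterator[T] => Either[Iterator[T], Long],
    writeToDisk: Iterator[T] => Long): Either[Iterator[T], Long] = {
  if (useMemory) {
    tryPutInMemory(values) match {
      case Right(size) => Right(size)                                 // fully unrolled in memory
      case Left(leftover) if useDisk => Right(writeToDisk(leftover))  // spill to disk instead
      case Left(leftover) => Left(leftover)                           // caller decides what to do
    }
  } else if (useDisk) {
    Right(writeToDisk(values))
  } else {
    Left(values)
  }
}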
+
+ /**
+ * Attempts to cache spilled bytes read from disk into the MemoryStore in order to speed up
+ * subsequent reads. This method requires the caller to hold a read lock on the block.
+ *
+ * @return a copy of the bytes from the memory store if the put succeeded, otherwise None.
+ * If this returns bytes from the memory store then the original disk store bytes will
+ * automatically be disposed and the caller should not continue to use them. Otherwise,
+ * if this returns None then the original disk store bytes will be unaffected.
+ */
+ private def maybeCacheDiskBytesInMemory(
+ blockInfo: BlockInfo,
+ blockId: BlockId,
+ level: StorageLevel,
+ diskData: BlockData): Option[ChunkedByteBuffer] = {
+ require(!level.deserialized)
+ if (level.useMemory) {
+ // Synchronize on blockInfo to guard against a race condition where two readers both try to
+ // put values read from disk into the MemoryStore.
+ blockInfo.synchronized {
+ if (memoryStore.contains(blockId)) {
+ diskData.dispose()
+ Some(memoryStore.getBytes(blockId).get)
+ } else {
+ val allocator = level.memoryMode match {
+ case MemoryMode.ON_HEAP => ByteBuffer.allocate _
+ case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _
}
- case _ =>
- val remoteStartTime = System.currentTimeMillis
- // Serialize the block if not already done
- if (bytesAfterPut == null) {
- if (valuesAfterPut == null) {
- throw new SparkException(
- "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
- }
- bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
+ val putSucceeded = memoryStore.putBytes(blockId, diskData.size, level.memoryMode, () => {
+ // https://issues.apache.org/jira/browse/SPARK-6076
+ // If the file size is bigger than the free memory, OOM will happen. So if we
+ // cannot put it into MemoryStore, copyForMemory should not be created. That's why
+ // this action is put into a `() => ChunkedByteBuffer` and created lazily.
+ diskData.toChunkedByteBuffer(allocator)
+ })
+ if (putSucceeded) {
+ diskData.dispose()
+ Some(memoryStore.getBytes(blockId).get)
+ } else {
+ None
}
- replicate(blockId, bytesAfterPut, putLevel)
- logDebug("Put block %s remotely took %s"
- .format(blockId, Utils.getUsedTimeMs(remoteStartTime)))
+ }
}
+ } else {
+ None
}
+ }
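Editor's note: two readers can race to promote the same disk block into memory, so the promotion is guarded by synchronizing on the block's info object and re-checking membership inside the lock. A generic sketch of that double-checked promotion; the cache class, its maps, and loadFromDisk are illustrative, not Spark classes:

import scala.collection.mutable

// Double-checked promotion guarded by a per-key lock object.
class PromotingCache[K, V <: AnyRef](loadFromDisk: K => V) {
  private val memoryCache = new java.util.concurrent.ConcurrentHashMap[K, V]()
  private val locks = mutable.Map.empty[K, AnyRef]

  private def lockFor(key: K): AnyRef =
    locks.synchronized(locks.getOrElseUpdate(key, new AnyRef))

  def getOrPromote(key: K): V = lockFor(key).synchronized {
    // Re-check under the lock: another reader may have promoted this key already.
    val cached = memoryCache.get(key)
    if (cached != null) cached
    else {
      val promoted = loadFromDisk(key)
      memoryCache.put(key, promoted)
      promoted
    }
  }
}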
- BlockManager.dispose(bytesAfterPut)
-
- if (putLevel.replication > 1) {
- logDebug("Putting block %s with replication took %s"
- .format(blockId, Utils.getUsedTimeMs(startTimeMs)))
+ /**
+ * Attempts to cache spilled values read from disk into the MemoryStore in order to speed up
+ * subsequent reads. This method requires the caller to hold a read lock on the block.
+ *
+ * @return a copy of the iterator. The original iterator passed this method should no longer
+ * be used after this method returns.
+ */
+ private def maybeCacheDiskValuesInMemory[T](
+ blockInfo: BlockInfo,
+ blockId: BlockId,
+ level: StorageLevel,
+ diskIterator: Iterator[T]): Iterator[T] = {
+ require(level.deserialized)
+ val classTag = blockInfo.classTag.asInstanceOf[ClassTag[T]]
+ if (level.useMemory) {
+ // Synchronize on blockInfo to guard against a race condition where two readers both try to
+ // put values read from disk into the MemoryStore.
+ blockInfo.synchronized {
+ if (memoryStore.contains(blockId)) {
+ // Note: if we had a means to discard the disk iterator, we would do that here.
+ memoryStore.getValues(blockId).get
+ } else {
+ memoryStore.putIteratorAsValues(blockId, diskIterator, classTag) match {
+ case Left(iter) =>
+ // The memory store put() failed, so it returned the iterator back to us:
+ iter
+ case Right(_) =>
+ // The put() succeeded, so we can read the values back:
+ memoryStore.getValues(blockId).get
+ }
+ }
+ }.asInstanceOf[Iterator[T]]
} else {
- logDebug("Putting block %s without replication took %s"
- .format(blockId, Utils.getUsedTimeMs(startTimeMs)))
+ diskIterator
}
-
- updatedBlocks
}
/**
@@ -894,199 +1214,208 @@ private[spark] class BlockManager(
}
/**
- * Replicate block to another node. Not that this is a blocking call that returns after
+ * Called for pro-active replenishment of blocks lost due to executor failures.
+ *
+ * @param blockId blockId of the block being replicated
+ * @param existingReplicas existing block managers that have a replica
+ * @param maxReplicas maximum number of replicas needed
+ */
+ def replicateBlock(
+ blockId: BlockId,
+ existingReplicas: Set[BlockManagerId],
+ maxReplicas: Int): Unit = {
+ logInfo(s"Using $blockManagerId to pro-actively replicate $blockId")
+ blockInfoManager.lockForReading(blockId).foreach { info =>
+ val data = doGetLocalBytes(blockId, info)
+ val storageLevel = StorageLevel(
+ useDisk = info.level.useDisk,
+ useMemory = info.level.useMemory,
+ useOffHeap = info.level.useOffHeap,
+ deserialized = info.level.deserialized,
+ replication = maxReplicas)
+ // we know we are called as a result of an executor removal, so we refresh peer cache
+ // this way, we won't try to replicate to a missing executor with a stale reference
+ getPeers(forceFetch = true)
+ try {
+ replicate(blockId, data, storageLevel, info.classTag, existingReplicas)
+ } finally {
+ logDebug(s"Releasing lock for $blockId")
+ releaseLockAndDispose(blockId, data)
+ }
+ }
+ }
+
+ /**
+ * Replicate block to another node. Note that this is a blocking call that returns after
* the block has been replicated.
*/
- private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel): Unit = {
+ private def replicate(
+ blockId: BlockId,
+ data: BlockData,
+ level: StorageLevel,
+ classTag: ClassTag[_],
+ existingReplicas: Set[BlockManagerId] = Set.empty): Unit = {
+
val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1)
- val numPeersToReplicateTo = level.replication - 1
- val peersForReplication = new ArrayBuffer[BlockManagerId]
- val peersReplicatedTo = new ArrayBuffer[BlockManagerId]
- val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId]
val tLevel = StorageLevel(
- level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1)
- val startTime = System.currentTimeMillis
- val random = new Random(blockId.hashCode)
-
- var replicationFailed = false
- var failures = 0
- var done = false
-
- // Get cached list of peers
- peersForReplication ++= getPeers(forceFetch = false)
-
- // Get a random peer. Note that this selection of a peer is deterministic on the block id.
- // So assuming the list of peers does not change and no replication failures,
- // if there are multiple attempts in the same node to replicate the same block,
- // the same set of peers will be selected.
- def getRandomPeer(): Option[BlockManagerId] = {
- // If replication had failed, then force update the cached list of peers and remove the peers
- // that have been already used
- if (replicationFailed) {
- peersForReplication.clear()
- peersForReplication ++= getPeers(forceFetch = true)
- peersForReplication --= peersReplicatedTo
- peersForReplication --= peersFailedToReplicateTo
- }
- if (!peersForReplication.isEmpty) {
- Some(peersForReplication(random.nextInt(peersForReplication.size)))
- } else {
- None
- }
- }
+ useDisk = level.useDisk,
+ useMemory = level.useMemory,
+ useOffHeap = level.useOffHeap,
+ deserialized = level.deserialized,
+ replication = 1)
- // One by one choose a random peer and try uploading the block to it
- // If replication fails (e.g., target peer is down), force the list of cached peers
- // to be re-fetched from driver and then pick another random peer for replication. Also
- // temporarily black list the peer for which replication failed.
- //
- // This selection of a peer and replication is continued in a loop until one of the
- // following 3 conditions is fulfilled:
- // (i) specified number of peers have been replicated to
- // (ii) too many failures in replicating to peers
- // (iii) no peer left to replicate to
- //
- while (!done) {
- getRandomPeer() match {
- case Some(peer) =>
- try {
- val onePeerStartTime = System.currentTimeMillis
- data.rewind()
- logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
- blockTransferService.uploadBlockSync(
- peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel)
- logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
- .format(System.currentTimeMillis - onePeerStartTime))
- peersReplicatedTo += peer
- peersForReplication -= peer
- replicationFailed = false
- if (peersReplicatedTo.size == numPeersToReplicateTo) {
- done = true // specified number of peers have been replicated to
- }
- } catch {
- case e: Exception =>
- logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e)
- failures += 1
- replicationFailed = true
- peersFailedToReplicateTo += peer
- if (failures > maxReplicationFailures) { // too many failures in replcating to peers
- done = true
- }
+ val numPeersToReplicateTo = level.replication - 1
+ val startTime = System.nanoTime
+
+ var peersReplicatedTo = mutable.HashSet.empty ++ existingReplicas
+ var peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId]
+ var numFailures = 0
+
+ val initialPeers = getPeers(false).filterNot(existingReplicas.contains(_))
+
+ var peersForReplication = blockReplicationPolicy.prioritize(
+ blockManagerId,
+ initialPeers,
+ peersReplicatedTo,
+ blockId,
+ numPeersToReplicateTo)
+
+ while (numFailures <= maxReplicationFailures &&
+ !peersForReplication.isEmpty &&
+ peersReplicatedTo.size < numPeersToReplicateTo) {
+ val peer = peersForReplication.head
+ try {
+ val onePeerStartTime = System.nanoTime
+ logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer")
+ blockTransferService.uploadBlockSync(
+ peer.host,
+ peer.port,
+ peer.executorId,
+ blockId,
+ new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false),
+ tLevel,
+ classTag)
+ logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" +
+ s" in ${(System.nanoTime - onePeerStartTime).toDouble / 1e6} ms")
+ peersForReplication = peersForReplication.tail
+ peersReplicatedTo += peer
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Failed to replicate $blockId to $peer, failure #$numFailures", e)
+ peersFailedToReplicateTo += peer
+ // we have a failed replication, so we get the list of peers again
+ // we don't want peers we have already replicated to and the ones that
+ // have failed previously
+ val filteredPeers = getPeers(true).filter { p =>
+ !peersFailedToReplicateTo.contains(p) && !peersReplicatedTo.contains(p)
}
- case None => // no peer left to replicate to
- done = true
+
+ numFailures += 1
+ peersForReplication = blockReplicationPolicy.prioritize(
+ blockManagerId,
+ filteredPeers,
+ peersReplicatedTo,
+ blockId,
+ numPeersToReplicateTo - peersReplicatedTo.size)
}
}
- val timeTakeMs = (System.currentTimeMillis - startTime)
- logDebug(s"Replicating $blockId of ${data.limit()} bytes to " +
- s"${peersReplicatedTo.size} peer(s) took $timeTakeMs ms")
+ logDebug(s"Replicating $blockId of ${data.size} bytes to " +
+ s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms")
if (peersReplicatedTo.size < numPeersToReplicateTo) {
logWarning(s"Block $blockId replicated to only " +
s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers")
}
+
+ logDebug(s"block $blockId replicated to ${peersReplicatedTo.mkString(", ")}")
}
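Editor's note: the rewritten replicate loop asks the pluggable policy for a prioritized peer list, walks it, and on each failure re-fetches the peers, filters out those already used or failed, and re-prioritizes the remainder until the target count, the failure budget, or the peer list is exhausted. A condensed sketch of that loop; prioritize, upload, and fetchPeers are hypothetical callbacks standing in for the policy, the block push, and the driver lookup:

import scala.collection.mutable
import scala.util.control.NonFatal

// Prioritized replication with retry and peer-list refresh on failure.
def replicateSketch(
    numPeersToReplicateTo: Int,
    maxReplicationFailures: Int,
    fetchPeers: Boolean => Seq[String],
    prioritize: (Seq[String], Set[String], Int) => Seq[String],
    upload: String => Unit): Set[String] = {
  val replicatedTo = mutable.HashSet.empty[String]
  val failedPeers = mutable.HashSet.empty[String]
  var numFailures = 0
  var candidates = prioritize(fetchPeers(false), replicatedTo.toSet, numPeersToReplicateTo)

  while (numFailures <= maxReplicationFailures &&
      candidates.nonEmpty &&
      replicatedTo.size < numPeersToReplicateTo) {
    val peer = candidates.head
    try {
      upload(peer)
      replicatedTo += peer
      candidates = candidates.tail
    } catch {
      case NonFatal(_) =>
        numFailures += 1
        failedPeers += peer
        // Refresh the peer list and drop peers already replicated to or already failed.
        val remaining = fetchPeers(true)
          .filterNot(p => replicatedTo.contains(p) || failedPeers.contains(p))
        candidates = prioritize(
          remaining, replicatedTo.toSet, numPeersToReplicateTo - replicatedTo.size)
    }
  }
  replicatedTo.toSet
}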
/**
* Read a block consisting of a single object.
*/
- def getSingle(blockId: BlockId): Option[Any] = {
- get(blockId).map(_.data.next())
+ def getSingle[T: ClassTag](blockId: BlockId): Option[T] = {
+ get[T](blockId).map(_.data.next().asInstanceOf[T])
}
/**
* Write a block consisting of a single object.
+ *
+ * @return true if the block was stored or false if the block was already stored or an
+ * error occurred.
*/
- def putSingle(
+ def putSingle[T: ClassTag](
blockId: BlockId,
- value: Any,
+ value: T,
level: StorageLevel,
- tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = {
+ tellMaster: Boolean = true): Boolean = {
putIterator(blockId, Iterator(value), level, tellMaster)
}
- def dropFromMemory(
- blockId: BlockId,
- data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = {
- dropFromMemory(blockId, () => data)
- }
-
/**
* Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
* store reaches its limit and needs to free up space.
*
* If `data` is not put on disk, it won't be created.
*
- * Return the block status if the given block has been updated, else None.
+ * The caller of this method must hold a write lock on the block before calling this method.
+ * This method does not release the write lock.
+ *
+ * @return the block's new effective StorageLevel.
*/
- def dropFromMemory(
+ private[storage] override def dropFromMemory[T: ClassTag](
blockId: BlockId,
- data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = {
-
+ data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = {
logInfo(s"Dropping block $blockId from memory")
- val info = blockInfo.get(blockId).orNull
-
- // If the block has not already been dropped
- if (info != null) {
- info.synchronized {
- // required ? As of now, this will be invoked only for blocks which are ready
- // But in case this changes in future, adding for consistency sake.
- if (!info.waitForReady()) {
- // If we get here, the block write failed.
- logWarning(s"Block $blockId was marked as failure. Nothing to drop")
- return None
- } else if (blockInfo.get(blockId).isEmpty) {
- logWarning(s"Block $blockId was already dropped.")
- return None
- }
- var blockIsUpdated = false
- val level = info.level
-
- // Drop to disk, if storage level requires
- if (level.useDisk && !diskStore.contains(blockId)) {
- logInfo(s"Writing block $blockId to disk")
- data() match {
- case Left(elements) =>
- diskStore.putArray(blockId, elements, level, returnValues = false)
- case Right(bytes) =>
- diskStore.putBytes(blockId, bytes, level)
+ val info = blockInfoManager.assertBlockIsLockedForWriting(blockId)
+ var blockIsUpdated = false
+ val level = info.level
+
+ // Drop to disk, if storage level requires
+ if (level.useDisk && !diskStore.contains(blockId)) {
+ logInfo(s"Writing block $blockId to disk")
+ data() match {
+ case Left(elements) =>
+ diskStore.put(blockId) { channel =>
+ val out = Channels.newOutputStream(channel)
+ serializerManager.dataSerializeStream(
+ blockId,
+ out,
+ elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]])
}
- blockIsUpdated = true
- }
+ case Right(bytes) =>
+ diskStore.putBytes(blockId, bytes)
+ }
+ blockIsUpdated = true
+ }
- // Actually drop from memory store
- val droppedMemorySize =
- if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
- val blockIsRemoved = memoryStore.remove(blockId)
- if (blockIsRemoved) {
- blockIsUpdated = true
- } else {
- logWarning(s"Block $blockId could not be dropped from memory as it does not exist")
- }
+ // Actually drop from memory store
+ val droppedMemorySize =
+ if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
+ val blockIsRemoved = memoryStore.remove(blockId)
+ if (blockIsRemoved) {
+ blockIsUpdated = true
+ } else {
+ logWarning(s"Block $blockId could not be dropped from memory as it does not exist")
+ }
- val status = getCurrentBlockStatus(blockId, info)
- if (info.tellMaster) {
- reportBlockStatus(blockId, info, status, droppedMemorySize)
- }
- if (!level.useDisk) {
- // The block is completely gone from this node; forget it so we can put() it again later.
- blockInfo.remove(blockId)
- }
- if (blockIsUpdated) {
- return Some(status)
- }
- }
+ val status = getCurrentBlockStatus(blockId, info)
+ if (info.tellMaster) {
+ reportBlockStatus(blockId, status, droppedMemorySize)
}
- None
+ if (blockIsUpdated) {
+ addUpdatedBlockStatusToTaskMetrics(blockId, status)
+ }
+ status.storageLevel
}
/**
* Remove all blocks belonging to the given RDD.
+ *
* @return The number of blocks removed.
*/
def removeRdd(rddId: Int): Int = {
// TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
logInfo(s"Removing RDD $rddId")
- val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
+ val blocksToRemove = blockInfoManager.entries.flatMap(_._1.asRDDId).filter(_.rddId == rddId)
blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
blocksToRemove.size
}
@@ -1096,7 +1425,7 @@ private[spark] class BlockManager(
*/
def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
logDebug(s"Removing broadcast $broadcastId")
- val blocksToRemove = blockInfo.keys.collect {
+ val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
case bid @ BroadcastBlockId(`broadcastId`, _) => bid
}
blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
@@ -1108,128 +1437,45 @@ private[spark] class BlockManager(
*/
def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = {
logDebug(s"Removing block $blockId")
- val info = blockInfo.get(blockId).orNull
- if (info != null) {
- info.synchronized {
- // Removals are idempotent in disk store and memory store. At worst, we get a warning.
- val removedFromMemory = memoryStore.remove(blockId)
- val removedFromDisk = diskStore.remove(blockId)
- val removedFromExternalBlockStore =
- if (externalBlockStoreInitialized) externalBlockStore.remove(blockId) else false
- if (!removedFromMemory && !removedFromDisk && !removedFromExternalBlockStore) {
- logWarning(s"Block $blockId could not be removed as it was not found in either " +
- "the disk, memory, or external block store")
- }
- blockInfo.remove(blockId)
- if (tellMaster && info.tellMaster) {
- val status = getCurrentBlockStatus(blockId, info)
- reportBlockStatus(blockId, info, status)
- }
- }
- } else {
- // The block has already been removed; do nothing.
- logWarning(s"Asked to remove block $blockId, which does not exist")
+ blockInfoManager.lockForWriting(blockId) match {
+ case None =>
+ // The block has already been removed; do nothing.
+ logWarning(s"Asked to remove block $blockId, which does not exist")
+ case Some(info) =>
+ removeBlockInternal(blockId, tellMaster = tellMaster && info.tellMaster)
+ addUpdatedBlockStatusToTaskMetrics(blockId, BlockStatus.empty)
}
}
- private def dropOldNonBroadcastBlocks(cleanupTime: Long): Unit = {
- logInfo(s"Dropping non broadcast blocks older than $cleanupTime")
- dropOldBlocks(cleanupTime, !_.isBroadcast)
- }
-
- private def dropOldBroadcastBlocks(cleanupTime: Long): Unit = {
- logInfo(s"Dropping broadcast blocks older than $cleanupTime")
- dropOldBlocks(cleanupTime, _.isBroadcast)
- }
-
- private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)): Unit = {
- val iterator = blockInfo.getEntrySet.iterator
- while (iterator.hasNext) {
- val entry = iterator.next()
- val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp)
- if (time < cleanupTime && shouldDrop(id)) {
- info.synchronized {
- val level = info.level
- if (level.useMemory) { memoryStore.remove(id) }
- if (level.useDisk) { diskStore.remove(id) }
- if (level.useOffHeap) { externalBlockStore.remove(id) }
- iterator.remove()
- logInfo(s"Dropped block $id")
- }
- val status = getCurrentBlockStatus(id, info)
- reportBlockStatus(id, info, status)
- }
- }
- }
-
- private def shouldCompress(blockId: BlockId): Boolean = {
- blockId match {
- case _: ShuffleBlockId => compressShuffle
- case _: BroadcastBlockId => compressBroadcast
- case _: RDDBlockId => compressRdds
- case _: TempLocalBlockId => compressShuffleSpill
- case _: TempShuffleBlockId => compressShuffle
- case _ => false
- }
- }
-
- /**
- * Wrap an output stream for compression if block compression is enabled for its block type
- */
- def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
- if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
- }
-
/**
- * Wrap an input stream for compression if block compression is enabled for its block type
+ * Internal version of [[removeBlock()]] which assumes that the caller already holds a write
+ * lock on the block.
*/
- def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
- if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
- }
-
- /** Serializes into a stream. */
- def dataSerializeStream(
- blockId: BlockId,
- outputStream: OutputStream,
- values: Iterator[Any],
- serializer: Serializer = defaultSerializer): Unit = {
- val byteStream = new BufferedOutputStream(outputStream)
- val ser = serializer.newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
- }
-
- /** Serializes into a byte buffer. */
- def dataSerialize(
- blockId: BlockId,
- values: Iterator[Any],
- serializer: Serializer = defaultSerializer): ByteBuffer = {
- val byteStream = new ByteArrayOutputStream(4096)
- dataSerializeStream(blockId, byteStream, values, serializer)
- ByteBuffer.wrap(byteStream.toByteArray)
+ private def removeBlockInternal(blockId: BlockId, tellMaster: Boolean): Unit = {
+ // Removals are idempotent in disk store and memory store. At worst, we get a warning.
+ val removedFromMemory = memoryStore.remove(blockId)
+ val removedFromDisk = diskStore.remove(blockId)
+ if (!removedFromMemory && !removedFromDisk) {
+ logWarning(s"Block $blockId could not be removed as it was not found on disk or in memory")
+ }
+ blockInfoManager.removeBlock(blockId)
+ if (tellMaster) {
+ reportBlockStatus(blockId, BlockStatus.empty)
+ }
}
- /**
- * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
- * the iterator is reached.
- */
- def dataDeserialize(
- blockId: BlockId,
- bytes: ByteBuffer,
- serializer: Serializer = defaultSerializer): Iterator[Any] = {
- bytes.rewind()
- dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer)
+ private def addUpdatedBlockStatusToTaskMetrics(blockId: BlockId, status: BlockStatus): Unit = {
+ Option(TaskContext.get()).foreach { c =>
+ c.taskMetrics().incUpdatedBlockStatuses(blockId -> status)
+ }
}
- /**
- * Deserializes a InputStream into an iterator of values and disposes of it when the end of
- * the iterator is reached.
- */
- def dataDeserializeStream(
+ def releaseLockAndDispose(
blockId: BlockId,
- inputStream: InputStream,
- serializer: Serializer = defaultSerializer): Iterator[Any] = {
- val stream = new BufferedInputStream(inputStream)
- serializer.newInstance().deserializeStream(wrapForCompression(blockId, stream)).asIterator
+ data: BlockData,
+ taskAttemptId: Option[Long] = None): Unit = {
+ releaseLock(blockId, taskAttemptId)
+ data.dispose()
}
def stop(): Unit = {
@@ -1240,38 +1486,17 @@ private[spark] class BlockManager(
}
diskBlockManager.stop()
rpcEnv.stop(slaveEndpoint)
- blockInfo.clear()
+ blockInfoManager.clear()
memoryStore.clear()
- diskStore.clear()
- if (externalBlockStoreInitialized) {
- externalBlockStore.clear()
- }
- metadataCleaner.cancel()
- broadcastCleaner.cancel()
futureExecutionContext.shutdownNow()
logInfo("BlockManager stopped")
}
}
-private[spark] object BlockManager extends Logging {
+private[spark] object BlockManager {
private val ID_GENERATOR = new IdGenerator
- /**
- * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
- * might cause errors if one attempts to read from the unmapped buffer, but it's better than
- * waiting for the GC to find it because that could lead to huge numbers of open files. There's
- * unfortunately no standard API to do this.
- */
- def dispose(buffer: ByteBuffer): Unit = {
- if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
- logTrace(s"Unmapping $buffer")
- if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) {
- buffer.asInstanceOf[DirectBuffer].cleaner().clean()
- }
- }
- }
-
def blockIdsToHosts(
blockIds: Array[BlockId],
env: SparkEnv,
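
The rewritten `dropFromMemory` earlier in this file's diff shifts locking responsibility to the caller: the block must already be write-locked, the method only asserts that fact, and the lock stays held afterwards. A minimal, self-contained sketch of that discipline (toy types, not the real `BlockInfoManager` API):

```scala
import scala.collection.mutable

// Toy lock registry illustrating the caller-holds-write-lock contract: the caller
// acquires the write lock, the drop routine only asserts it, and the caller releases it.
class SimpleBlockLocks {
  private val writeLocked = mutable.Set.empty[String]

  def lockForWriting(blockId: String): Boolean =
    synchronized { if (writeLocked(blockId)) false else { writeLocked += blockId; true } }

  def assertLockedForWriting(blockId: String): Unit =
    synchronized { require(writeLocked(blockId), s"$blockId is not write-locked") }

  def unlock(blockId: String): Unit = synchronized { writeLocked -= blockId }
}

object DropFromMemorySketch {
  def main(args: Array[String]): Unit = {
    val locks = new SimpleBlockLocks
    val blockId = "rdd_0_0"
    require(locks.lockForWriting(blockId))        // caller takes the write lock first
    try {
      locks.assertLockedForWriting(blockId)       // drop routine asserts, never locks
      println(s"dropping $blockId from memory")   // ... write to disk, update block status ...
    } finally {
      locks.unlock(blockId)                       // the lock is released by the caller
    }
  }
}
```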
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 69ac37511e73..a416f08b5b19 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -18,7 +18,8 @@
package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.util.concurrent.ConcurrentHashMap
+
+import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
@@ -28,7 +29,7 @@ import org.apache.spark.util.Utils
* :: DeveloperApi ::
* This class represent an unique identifier for a BlockManager.
*
- * The first 2 constructors of this class is made private to ensure that BlockManagerId objects
+ * The first 2 constructors of this class are made private to ensure that BlockManagerId objects
* can be created only using the apply method in the companion object. This allows de-duplication
* of ID objects. Also, constructor parameters are private to ensure that parameters cannot be
* modified from outside this class.
@@ -37,14 +38,15 @@ import org.apache.spark.util.Utils
class BlockManagerId private (
private var executorId_ : String,
private var host_ : String,
- private var port_ : Int)
+ private var port_ : Int,
+ private var topologyInfo_ : Option[String])
extends Externalizable {
- private def this() = this(null, null, 0) // For deserialization only
+ private def this() = this(null, null, 0, None) // For deserialization only
def executorId: String = executorId_
- if (null != host_){
+ if (null != host_) {
Utils.checkHost(host_, "Expected hostname")
assert (port_ > 0)
}
@@ -60,6 +62,8 @@ class BlockManagerId private (
def port: Int = port_
+ def topologyInfo: Option[String] = topologyInfo_
+
def isDriver: Boolean = {
executorId == SparkContext.DRIVER_IDENTIFIER ||
executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER
@@ -69,24 +73,33 @@ class BlockManagerId private (
out.writeUTF(executorId_)
out.writeUTF(host_)
out.writeInt(port_)
+ out.writeBoolean(topologyInfo_.isDefined)
+ // we only write topologyInfo if we have it
+ topologyInfo.foreach(out.writeUTF(_))
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
executorId_ = in.readUTF()
host_ = in.readUTF()
port_ = in.readInt()
+ val isTopologyInfoAvailable = in.readBoolean()
+ topologyInfo_ = if (isTopologyInfoAvailable) Option(in.readUTF()) else None
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
- override def toString: String = s"BlockManagerId($executorId, $host, $port)"
+ override def toString: String = s"BlockManagerId($executorId, $host, $port, $topologyInfo)"
- override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port
+ override def hashCode: Int =
+ ((executorId.hashCode * 41 + host.hashCode) * 41 + port) * 41 + topologyInfo.hashCode
override def equals(that: Any): Boolean = that match {
case id: BlockManagerId =>
- executorId == id.executorId && port == id.port && host == id.host
+ executorId == id.executorId &&
+ port == id.port &&
+ host == id.host &&
+ topologyInfo == id.topologyInfo
case _ =>
false
}
@@ -101,10 +114,18 @@ private[spark] object BlockManagerId {
* @param execId ID of the executor.
* @param host Host name of the block manager.
* @param port Port of the block manager.
+ * @param topologyInfo topology information for the block manager, if available.
+ * This can be network topology information for use while choosing peers
+ * for replicating data blocks. More information is available here:
+ * [[org.apache.spark.storage.TopologyMapper]]
* @return A new [[org.apache.spark.storage.BlockManagerId]].
*/
- def apply(execId: String, host: String, port: Int): BlockManagerId =
- getCachedBlockManagerId(new BlockManagerId(execId, host, port))
+ def apply(
+ execId: String,
+ host: String,
+ port: Int,
+ topologyInfo: Option[String] = None): BlockManagerId =
+ getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo))
def apply(in: ObjectInput): BlockManagerId = {
val obj = new BlockManagerId()
@@ -112,10 +133,17 @@ private[spark] object BlockManagerId {
getCachedBlockManagerId(obj)
}
- val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
+ /**
+ * The max cache size is hardcoded to 10000, since the size of a BlockManagerId
+ * object is about 48B, the total memory cost should be below 1MB which is feasible.
+ */
+ val blockManagerIdCache = CacheBuilder.newBuilder()
+ .maximumSize(10000)
+ .build(new CacheLoader[BlockManagerId, BlockManagerId]() {
+ override def load(id: BlockManagerId) = id
+ })
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
- blockManagerIdCache.putIfAbsent(id, id)
blockManagerIdCache.get(id)
}
}
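
The Guava-backed `blockManagerIdCache` above bounds the memory used to de-duplicate `BlockManagerId` instances. A standalone sketch of the same interning pattern (assumes Guava is on the classpath; `Key` is a stand-in value type, not a Spark class):

```scala
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

// Interning small value objects with a bounded Guava cache: load() simply returns the key,
// so get(id) yields one canonical instance per distinct value, and maximumSize keeps the
// cache from growing without limit.
object InternExample {
  final case class Key(executorId: String, host: String, port: Int)

  private val cache: LoadingCache[Key, Key] = CacheBuilder.newBuilder()
    .maximumSize(10000)
    .build(new CacheLoader[Key, Key]() {
      override def load(id: Key): Key = id
    })

  def intern(id: Key): Key = cache.get(id)

  def main(args: Array[String]): Unit = {
    val a = intern(Key("1", "host-a", 7077))
    val b = intern(Key("1", "host-a", 7077))
    println(a eq b) // true: both calls return the same cached instance
  }
}
```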
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
new file mode 100644
index 000000000000..1ea0d378cbe8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.InputStream
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.util.io.ChunkedByteBuffer
+
+/**
+ * This [[ManagedBuffer]] wraps a [[BlockData]] instance retrieved from the [[BlockManager]]
+ * so that the corresponding block's read lock can be released once this buffer's references
+ * are released.
+ *
+ * If `dispose` is set to true, the [[BlockData]] will be disposed when the buffer's reference
+ * count drops to zero.
+ *
+ * This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks
+ * to the network layer's notion of retain / release counts.
+ */
+private[storage] class BlockManagerManagedBuffer(
+ blockInfoManager: BlockInfoManager,
+ blockId: BlockId,
+ data: BlockData,
+ dispose: Boolean) extends ManagedBuffer {
+
+ private val refCount = new AtomicInteger(1)
+
+ override def size(): Long = data.size
+
+ override def nioByteBuffer(): ByteBuffer = data.toByteBuffer()
+
+ override def createInputStream(): InputStream = data.toInputStream()
+
+ override def convertToNetty(): Object = data.toNetty()
+
+ override def retain(): ManagedBuffer = {
+ refCount.incrementAndGet()
+ val locked = blockInfoManager.lockForReading(blockId, blocking = false)
+ assert(locked.isDefined)
+ this
+ }
+
+ override def release(): ManagedBuffer = {
+ blockInfoManager.unlock(blockId)
+ if (refCount.decrementAndGet() == 0 && dispose) {
+ data.dispose()
+ }
+ this
+ }
+}
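
`BlockManagerManagedBuffer` above ties the network layer's retain/release counts to block read locks. A toy sketch of that bridge (hypothetical `ReadLocks` type; like the real class, it assumes the caller already holds one read lock for the initial reference):

```scala
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical stand-in for the read-lock side of a block info manager.
class ReadLocks {
  private var readers = 0
  def lockForReading(): Boolean = synchronized { readers += 1; true }
  def unlock(): Unit = synchronized { readers -= 1 }
}

// Every retain() takes one more read lock, every release() gives one back, and the
// underlying data is disposed only when the reference count drops to zero.
class RefCountedBuffer(locks: ReadLocks, dispose: () => Unit) {
  private val refCount = new AtomicInteger(1) // initial reference: lock already held by caller

  def retain(): this.type = {
    refCount.incrementAndGet()
    require(locks.lockForReading(), "could not re-acquire read lock")
    this
  }

  def release(): this.type = {
    locks.unlock()
    if (refCount.decrementAndGet() == 0) dispose()
    this
  }
}
```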
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index f45bff34d4db..ea5d8423a588 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -19,12 +19,13 @@ package org.apache.spark.storage
import scala.collection.Iterable
import scala.collection.generic.CanBuildFrom
-import scala.concurrent.{Await, Future}
+import scala.concurrent.Future
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.storage.BlockManagerMessages._
-import org.apache.spark.util.{ThreadUtils, RpcUtils}
+import org.apache.spark.util.{RpcUtils, ThreadUtils}
private[spark]
class BlockManagerMaster(
@@ -41,12 +42,29 @@ class BlockManagerMaster(
logInfo("Removed " + execId + " successfully in removeExecutor")
}
- /** Register the BlockManager's id with the driver. */
+ /** Request removal of a dead executor from the driver endpoint.
+ * This is only called on the driver side. Non-blocking.
+ */
+ def removeExecutorAsync(execId: String) {
+ driverEndpoint.ask[Boolean](RemoveExecutor(execId))
+ logInfo("Removal of executor " + execId + " requested")
+ }
+
+ /**
+ * Register the BlockManager's id with the driver. The input BlockManagerId does not contain
+ * topology information. This information is obtained from the master, which responds with an
+ * updated BlockManagerId fleshed out with this information.
+ */
def registerBlockManager(
- blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = {
- logInfo("Trying to register BlockManager")
- tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
- logInfo("Registered BlockManager")
+ blockManagerId: BlockManagerId,
+ maxOnHeapMemSize: Long,
+ maxOffHeapMemSize: Long,
+ slaveEndpoint: RpcEndpointRef): BlockManagerId = {
+ logInfo(s"Registering BlockManager $blockManagerId")
+ val updatedId = driverEndpoint.askSync[BlockManagerId](
+ RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint))
+ logInfo(s"Registered BlockManager $updatedId")
+ updatedId
}
def updateBlockInfo(
@@ -54,23 +72,21 @@ class BlockManagerMaster(
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long,
- externalBlockStoreSize: Long): Boolean = {
- val res = driverEndpoint.askWithRetry[Boolean](
- UpdateBlockInfo(blockManagerId, blockId, storageLevel,
- memSize, diskSize, externalBlockStoreSize))
+ diskSize: Long): Boolean = {
+ val res = driverEndpoint.askSync[Boolean](
+ UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
logDebug(s"Updated info of block $blockId")
res
}
/** Get locations of the blockId from the driver */
def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
- driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetLocations(blockId))
+ driverEndpoint.askSync[Seq[BlockManagerId]](GetLocations(blockId))
}
/** Get locations of multiple blockIds from the driver */
def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = {
- driverEndpoint.askWithRetry[IndexedSeq[Seq[BlockManagerId]]](
+ driverEndpoint.askSync[IndexedSeq[Seq[BlockManagerId]]](
GetLocationsMultipleBlockIds(blockIds))
}
@@ -84,11 +100,11 @@ class BlockManagerMaster(
/** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
- driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId))
+ driverEndpoint.askSync[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
- def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = {
- driverEndpoint.askWithRetry[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId))
+ def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = {
+ driverEndpoint.askSync[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId))
}
/**
@@ -96,12 +112,12 @@ class BlockManagerMaster(
* blocks that the driver knows about.
*/
def removeBlock(blockId: BlockId) {
- driverEndpoint.askWithRetry[Boolean](RemoveBlock(blockId))
+ driverEndpoint.askSync[Boolean](RemoveBlock(blockId))
}
/** Remove all blocks belonging to the given RDD. */
def removeRdd(rddId: Int, blocking: Boolean) {
- val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId))
+ val future = driverEndpoint.askSync[Future[Seq[Int]]](RemoveRdd(rddId))
future.onFailure {
case e: Exception =>
logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e)
@@ -113,7 +129,7 @@ class BlockManagerMaster(
/** Remove all blocks belonging to the given shuffle. */
def removeShuffle(shuffleId: Int, blocking: Boolean) {
- val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
+ val future = driverEndpoint.askSync[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
future.onFailure {
case e: Exception =>
logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e)
@@ -125,7 +141,7 @@ class BlockManagerMaster(
/** Remove all blocks belonging to the given broadcast. */
def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
- val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](
+ val future = driverEndpoint.askSync[Future[Seq[Int]]](
RemoveBroadcast(broadcastId, removeFromMaster))
future.onFailure {
case e: Exception =>
@@ -144,11 +160,11 @@ class BlockManagerMaster(
* amount of remaining memory.
*/
def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
- driverEndpoint.askWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
+ driverEndpoint.askSync[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}
def getStorageStatus: Array[StorageStatus] = {
- driverEndpoint.askWithRetry[Array[StorageStatus]](GetStorageStatus)
+ driverEndpoint.askSync[Array[StorageStatus]](GetStorageStatus)
}
/**
@@ -169,7 +185,7 @@ class BlockManagerMaster(
* master endpoint for a response to a prior message.
*/
val response = driverEndpoint.
- askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
+ askSync[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
val (blockManagerIds, futures) = response.unzip
implicit val sameThread = ThreadUtils.sameThread
val cbf =
@@ -199,7 +215,7 @@ class BlockManagerMaster(
filter: BlockId => Boolean,
askSlaves: Boolean): Seq[BlockId] = {
val msg = GetMatchingBlockIds(filter, askSlaves)
- val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg)
+ val future = driverEndpoint.askSync[Future[Seq[BlockId]]](msg)
timeout.awaitResult(future)
}
@@ -208,7 +224,7 @@ class BlockManagerMaster(
* since they are not reported the master.
*/
def hasCachedBlocks(executorId: String): Boolean = {
- driverEndpoint.askWithRetry[Boolean](HasCachedBlocks(executorId))
+ driverEndpoint.askSync[Boolean](HasCachedBlocks(executorId))
}
/** Stop the driver endpoint, called only on the Spark driver node */
@@ -222,7 +238,7 @@ class BlockManagerMaster(
/** Send a one-way message to the master endpoint, to which we expect it to reply with true. */
private def tell(message: Any) {
- if (!driverEndpoint.askWithRetry[Boolean](message)) {
+ if (!driverEndpoint.askSync[Boolean](message)) {
throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.")
}
}
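
The `askWithRetry` to `askSync` changes above replace retried asks with a single blocking ask that waits for the reply. Roughly, in plain Scala futures (not Spark's RPC types):

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object AskSyncSketch {
  implicit val ec: ExecutionContext = ExecutionContext.global

  // Hypothetical askSync: send one request and block until the reply arrives or the
  // timeout expires; there are no client-side retries, unlike the old askWithRetry.
  def askSync[T](ask: => Future[T], timeout: FiniteDuration = 120.seconds): T =
    Await.result(ask, timeout)

  def main(args: Array[String]): Unit = {
    val locations = askSync(Future { Seq("BlockManagerId(driver, localhost, 7077, None)") })
    println(locations)
  }
}
```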
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 7db6035553ae..6f85b9e4d6c7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -19,14 +19,15 @@ package org.apache.spark.storage
import java.util.{HashMap => JHashMap}
-import scala.collection.immutable.HashSet
import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
+import scala.util.Random
-import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler._
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}
@@ -55,15 +56,27 @@ class BlockManagerMasterEndpoint(
private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool")
private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
+ private val topologyMapper = {
+ val topologyMapperClassName = conf.get(
+ "spark.storage.replication.topologyMapper", classOf[DefaultTopologyMapper].getName)
+ val clazz = Utils.classForName(topologyMapperClassName)
+ val mapper =
+ clazz.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[TopologyMapper]
+ logInfo(s"Using $topologyMapperClassName for getting topology information")
+ mapper
+ }
+
+ val proactivelyReplicate = conf.get("spark.storage.replication.proactive", "false").toBoolean
+
+ logInfo("BlockManagerMasterEndpoint up")
+
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) =>
- register(blockManagerId, maxMemSize, slaveEndpoint)
- context.reply(true)
+ case RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) =>
+ context.reply(register(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint))
- case _updateBlockInfo @ UpdateBlockInfo(
- blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) =>
- context.reply(updateBlockInfo(
- blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize))
+ case _updateBlockInfo @
+ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+ context.reply(updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size))
listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo)))
case GetLocations(blockId) =>
@@ -75,8 +88,8 @@ class BlockManagerMasterEndpoint(
case GetPeers(blockManagerId) =>
context.reply(getPeers(blockManagerId))
- case GetRpcHostPortForExecutor(executorId) =>
- context.reply(getRpcHostPortForExecutor(executorId))
+ case GetExecutorEndpointRef(executorId) =>
+ context.reply(getExecutorEndpointRef(executorId))
case GetMemoryStatus =>
context.reply(memoryStatus)
@@ -185,17 +198,38 @@ class BlockManagerMasterEndpoint(
// Remove it from blockManagerInfo and remove all the blocks.
blockManagerInfo.remove(blockManagerId)
+
val iterator = info.blocks.keySet.iterator
while (iterator.hasNext) {
val blockId = iterator.next
val locations = blockLocations.get(blockId)
locations -= blockManagerId
+ // De-register the block if none of the block managers have it. Otherwise, if pro-active
+ // replication is enabled, and a block is either an RDD or a test block (the latter is used
+ // for unit testing), we send a message to a randomly chosen executor location to replicate
+ // the given block. Note that we ignore other block types (such as broadcast/shuffle blocks
+ // etc.) as replication doesn't make much sense in that context.
if (locations.size == 0) {
blockLocations.remove(blockId)
+ logWarning(s"No more replicas available for $blockId !")
+ } else if (proactivelyReplicate && (blockId.isRDD || blockId.isInstanceOf[TestBlockId])) {
+ // As a heuristic, assume a single executor failure to find out the number of replicas that
+ // existed before failure
+ val maxReplicas = locations.size + 1
+ val i = (new Random(blockId.hashCode)).nextInt(locations.size)
+ val blockLocations = locations.toSeq
+ val candidateBMId = blockLocations(i)
+ blockManagerInfo.get(candidateBMId).foreach { bm =>
+ val remainingLocations = locations.toSeq.filter(bm => bm != candidateBMId)
+ val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas)
+ bm.slaveEndpoint.ask[Boolean](replicateMsg)
+ }
}
}
+
listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId))
logInfo(s"Removing block manager $blockManagerId")
+
}
private def removeExecutor(execId: String) {
@@ -242,7 +276,8 @@ class BlockManagerMasterEndpoint(
private def storageStatus: Array[StorageStatus] = {
blockManagerInfo.map { case (blockManagerId, info) =>
- new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala)
+ new StorageStatus(blockManagerId, info.maxMem, Some(info.maxOnHeapMem),
+ Some(info.maxOffHeapMem), info.blocks.asScala)
}.toArray
}
@@ -299,7 +334,22 @@ class BlockManagerMasterEndpoint(
).map(_.flatten.toSeq)
}
- private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) {
+ /**
+ * Returns the BlockManagerId with topology information populated, if available.
+ */
+ private def register(
+ idWithoutTopologyInfo: BlockManagerId,
+ maxOnHeapMemSize: Long,
+ maxOffHeapMemSize: Long,
+ slaveEndpoint: RpcEndpointRef): BlockManagerId = {
+ // the dummy id is not expected to contain the topology information.
+ // we get that info here and respond back with a more fleshed out block manager id
+ val id = BlockManagerId(
+ idWithoutTopologyInfo.executorId,
+ idWithoutTopologyInfo.host,
+ idWithoutTopologyInfo.port,
+ topologyMapper.getTopologyForHost(idWithoutTopologyInfo.host))
+
val time = System.currentTimeMillis()
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
@@ -311,14 +361,16 @@ class BlockManagerMasterEndpoint(
case None =>
}
logInfo("Registering block manager %s with %s RAM, %s".format(
- id.hostPort, Utils.bytesToString(maxMemSize), id))
+ id.hostPort, Utils.bytesToString(maxOnHeapMemSize + maxOffHeapMemSize), id))
blockManagerIdByExecutor(id.executorId) = id
blockManagerInfo(id) = new BlockManagerInfo(
- id, System.currentTimeMillis(), maxMemSize, slaveEndpoint)
+ id, System.currentTimeMillis(), maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint)
}
- listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
+ listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize,
+ Some(maxOnHeapMemSize), Some(maxOffHeapMemSize)))
+ id
}
private def updateBlockInfo(
@@ -326,8 +378,7 @@ class BlockManagerMasterEndpoint(
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long,
- externalBlockStoreSize: Long): Boolean = {
+ diskSize: Long): Boolean = {
if (!blockManagerInfo.contains(blockManagerId)) {
if (blockManagerId.isDriver && !isLocal) {
@@ -344,8 +395,7 @@ class BlockManagerMasterEndpoint(
return true
}
- blockManagerInfo(blockManagerId).updateBlockInfo(
- blockId, storageLevel, memSize, diskSize, externalBlockStoreSize)
+ blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize)
var locations: mutable.HashSet[BlockManagerId] = null
if (blockLocations.containsKey(blockId)) {
@@ -388,15 +438,14 @@ class BlockManagerMasterEndpoint(
}
/**
- * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its
- * [[BlockManagerSlaveEndpoint]].
+ * Returns an [[RpcEndpointRef]] of the [[BlockManagerSlaveEndpoint]] for sending RPC messages.
*/
- private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ private def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = {
for (
blockManagerId <- blockManagerIdByExecutor.get(executorId);
info <- blockManagerInfo.get(blockManagerId)
) yield {
- (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port)
+ info.slaveEndpoint
}
}
@@ -406,26 +455,25 @@ class BlockManagerMasterEndpoint(
}
@DeveloperApi
-case class BlockStatus(
- storageLevel: StorageLevel,
- memSize: Long,
- diskSize: Long,
- externalBlockStoreSize: Long) {
- def isCached: Boolean = memSize + diskSize + externalBlockStoreSize > 0
+case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) {
+ def isCached: Boolean = memSize + diskSize > 0
}
@DeveloperApi
object BlockStatus {
- def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)
+ def empty: BlockStatus = BlockStatus(StorageLevel.NONE, memSize = 0L, diskSize = 0L)
}
private[spark] class BlockManagerInfo(
val blockManagerId: BlockManagerId,
timeMs: Long,
- val maxMem: Long,
+ val maxOnHeapMem: Long,
+ val maxOffHeapMem: Long,
val slaveEndpoint: RpcEndpointRef)
extends Logging {
+ val maxMem = maxOnHeapMem + maxOffHeapMem
+
private var _lastSeenMs: Long = timeMs
private var _remainingMem: Long = maxMem
@@ -445,16 +493,21 @@ private[spark] class BlockManagerInfo(
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long,
- externalBlockStoreSize: Long) {
+ diskSize: Long) {
updateLastSeenMs()
- if (_blocks.containsKey(blockId)) {
+ val blockExists = _blocks.containsKey(blockId)
+ var originalMemSize: Long = 0
+ var originalDiskSize: Long = 0
+ var originalLevel: StorageLevel = StorageLevel.NONE
+
+ if (blockExists) {
// The block exists on the slave already.
val blockStatus: BlockStatus = _blocks.get(blockId)
- val originalLevel: StorageLevel = blockStatus.storageLevel
- val originalMemSize: Long = blockStatus.memSize
+ originalLevel = blockStatus.storageLevel
+ originalMemSize = blockStatus.memSize
+ originalDiskSize = blockStatus.diskSize
if (originalLevel.useMemory) {
_remainingMem += originalMemSize
@@ -462,7 +515,7 @@ private[spark] class BlockManagerInfo(
}
if (storageLevel.isValid) {
- /* isValid means it is either stored in-memory, on-disk or on-externalBlockStore.
+ /* isValid means it is either stored in-memory or on-disk.
* The memSize here indicates the data size in or dropped from memory,
* externalBlockStoreSize here indicates the data size in or dropped from externalBlockStore,
* and the diskSize here indicates the data size in or dropped to disk.
@@ -470,46 +523,47 @@ private[spark] class BlockManagerInfo(
* Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */
var blockStatus: BlockStatus = null
if (storageLevel.useMemory) {
- blockStatus = BlockStatus(storageLevel, memSize, 0, 0)
+ blockStatus = BlockStatus(storageLevel, memSize = memSize, diskSize = 0)
_blocks.put(blockId, blockStatus)
_remainingMem -= memSize
- logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(memSize),
- Utils.bytesToString(_remainingMem)))
+ if (blockExists) {
+ logInfo(s"Updated $blockId in memory on ${blockManagerId.hostPort}" +
+ s" (current size: ${Utils.bytesToString(memSize)}," +
+ s" original size: ${Utils.bytesToString(originalMemSize)}," +
+ s" free: ${Utils.bytesToString(_remainingMem)})")
+ } else {
+ logInfo(s"Added $blockId in memory on ${blockManagerId.hostPort}" +
+ s" (size: ${Utils.bytesToString(memSize)}," +
+ s" free: ${Utils.bytesToString(_remainingMem)})")
+ }
}
if (storageLevel.useDisk) {
- blockStatus = BlockStatus(storageLevel, 0, diskSize, 0)
- _blocks.put(blockId, blockStatus)
- logInfo("Added %s on disk on %s (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize)))
- }
- if (storageLevel.useOffHeap) {
- blockStatus = BlockStatus(storageLevel, 0, 0, externalBlockStoreSize)
+ blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize)
_blocks.put(blockId, blockStatus)
- logInfo("Added %s on ExternalBlockStore on %s (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(externalBlockStoreSize)))
+ if (blockExists) {
+ logInfo(s"Updated $blockId on disk on ${blockManagerId.hostPort}" +
+ s" (current size: ${Utils.bytesToString(diskSize)}," +
+ s" original size: ${Utils.bytesToString(originalDiskSize)})")
+ } else {
+ logInfo(s"Added $blockId on disk on ${blockManagerId.hostPort}" +
+ s" (size: ${Utils.bytesToString(diskSize)})")
+ }
}
if (!blockId.isBroadcast && blockStatus.isCached) {
_cachedBlocks += blockId
}
- } else if (_blocks.containsKey(blockId)) {
+ } else if (blockExists) {
// If isValid is not true, drop the block.
- val blockStatus: BlockStatus = _blocks.get(blockId)
_blocks.remove(blockId)
_cachedBlocks -= blockId
- if (blockStatus.storageLevel.useMemory) {
- logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize),
- Utils.bytesToString(_remainingMem)))
- }
- if (blockStatus.storageLevel.useDisk) {
- logInfo("Removed %s on %s on disk (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize)))
+ if (originalLevel.useMemory) {
+ logInfo(s"Removed $blockId on ${blockManagerId.hostPort} in memory" +
+ s" (size: ${Utils.bytesToString(originalMemSize)}," +
+ s" free: ${Utils.bytesToString(_remainingMem)})")
}
- if (blockStatus.storageLevel.useOffHeap) {
- logInfo("Removed %s on %s on externalBlockStore (size: %s)".format(
- blockId, blockManagerId.hostPort,
- Utils.bytesToString(blockStatus.externalBlockStoreSize)))
+ if (originalLevel.useDisk) {
+ logInfo(s"Removed $blockId on ${blockManagerId.hostPort} on disk" +
+ s" (size: ${Utils.bytesToString(originalDiskSize)})")
}
}
}
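
The proactive-replication branch added to `removeBlockManager` above picks one random surviving holder of each lost block and asks it to re-replicate. A standalone sketch of that selection heuristic (toy types; the `+ 1` encodes the single-failure assumption noted in the comment):

```scala
import scala.util.Random

object ProactiveReplicationSketch {
  final case class Location(host: String)
  final case class ReplicateRequest(blockId: String, existingReplicas: Seq[Location], maxReplicas: Int)

  // Given the surviving locations of a block whose executor was just lost, pick one of them
  // (seeded by the block id, so the choice is deterministic per block) to drive re-replication.
  def planReplication(blockId: String, survivors: Seq[Location]): Option[ReplicateRequest] = {
    if (survivors.isEmpty) None
    else {
      val maxReplicas = survivors.size + 1 // assume exactly one executor failed
      val chosen = survivors(new Random(blockId.hashCode).nextInt(survivors.size))
      Some(ReplicateRequest(blockId, survivors.filterNot(_ == chosen), maxReplicas))
    }
  }

  def main(args: Array[String]): Unit = {
    val plan = planReplication("rdd_2_1", Seq(Location("host-a"), Location("host-b")))
    println(plan) // one survivor is chosen; the other is passed along as an existing replica
  }
}
```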
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 376e9eb48843..0c0ff144596a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -32,6 +32,10 @@ private[spark] object BlockManagerMessages {
// blocks that the master knows about.
case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave
+ // Replicate blocks that were lost due to executor failure
+ case class ReplicateBlock(blockId: BlockId, replicas: Seq[BlockManagerId], maxReplicas: Int)
+ extends ToBlockManagerSlave
+
// Remove all blocks belonging to a specific RDD.
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
@@ -42,6 +46,11 @@ private[spark] object BlockManagerMessages {
case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true)
extends ToBlockManagerSlave
+ /**
+ * Driver to Executor message to trigger a thread dump.
+ */
+ case object TriggerThreadDump extends ToBlockManagerSlave
+
//////////////////////////////////////////////////////////////////////////////////
// Messages from slaves to the master.
//////////////////////////////////////////////////////////////////////////////////
@@ -49,7 +58,8 @@ private[spark] object BlockManagerMessages {
case class RegisterBlockManager(
blockManagerId: BlockManagerId,
- maxMemSize: Long,
+ maxOnHeapMemSize: Long,
+ maxOffHeapMemSize: Long,
sender: RpcEndpointRef)
extends ToBlockManagerMaster
@@ -58,12 +68,11 @@ private[spark] object BlockManagerMessages {
var blockId: BlockId,
var storageLevel: StorageLevel,
var memSize: Long,
- var diskSize: Long,
- var externalBlockStoreSize: Long)
+ var diskSize: Long)
extends ToBlockManagerMaster
with Externalizable {
- def this() = this(null, null, null, 0, 0, 0) // For deserialization only
+ def this() = this(null, null, null, 0, 0) // For deserialization only
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
blockManagerId.writeExternal(out)
@@ -71,7 +80,6 @@ private[spark] object BlockManagerMessages {
storageLevel.writeExternal(out)
out.writeLong(memSize)
out.writeLong(diskSize)
- out.writeLong(externalBlockStoreSize)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -80,7 +88,6 @@ private[spark] object BlockManagerMessages {
storageLevel = StorageLevel(in)
memSize = in.readLong()
diskSize = in.readLong()
- externalBlockStoreSize = in.readLong()
}
}
@@ -90,7 +97,7 @@ private[spark] object BlockManagerMessages {
case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
- case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
+ case class GetExecutorEndpointRef(executorId: String) extends ToBlockManagerMaster
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
index e749631bf6f1..1aaa42459df6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
@@ -19,10 +19,11 @@ package org.apache.spark.storage
import scala.concurrent.{ExecutionContext, Future}
-import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint}
-import org.apache.spark.util.ThreadUtils
-import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
+import org.apache.spark.{MapOutputTracker, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
+import org.apache.spark.util.{ThreadUtils, Utils}
/**
* An RpcEndpoint to take commands from the master to execute options. For example,
@@ -70,6 +71,13 @@ class BlockManagerSlaveEndpoint(
case GetMatchingBlockIds(filter, _) =>
context.reply(blockManager.getMatchingBlockIds(filter))
+
+ case TriggerThreadDump =>
+ context.reply(Utils.getThreadDump())
+
+ case ReplicateBlock(blockId, replicas, maxReplicas) =>
+ context.reply(blockManager.replicateBlock(blockId, replicas.toSet, maxReplicas))
+
}
private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) {
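
The slave endpoint above now also answers `TriggerThreadDump` and `ReplicateBlock`. A rough, self-contained sketch of the dispatch pattern (a sealed message trait matched by a partial function, standing in for the real `receiveAndReply`):

```scala
object SlaveEndpointSketch {
  sealed trait ToSlave
  final case class RemoveBlock(blockId: String) extends ToSlave
  final case class ReplicateBlock(blockId: String, peers: Set[String], maxReplicas: Int) extends ToSlave
  case object TriggerThreadDump extends ToSlave

  // Each message is matched and answered with a reply value, mirroring the
  // receiveAndReply-style PartialFunction used by the real endpoint.
  val handle: PartialFunction[ToSlave, Any] = {
    case RemoveBlock(id)                => s"removed $id"
    case ReplicateBlock(id, peers, max) => s"replicating $id to at most $max replicas via $peers"
    case TriggerThreadDump              => Seq("<thread dump placeholder>")
  }

  def main(args: Array[String]): Unit = {
    println(handle(ReplicateBlock("rdd_1_0", Set("exec-2"), 2)))
    println(handle(TriggerThreadDump))
  }
}
```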
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
index c5ba9af3e265..197a01762c0c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
@@ -26,35 +26,39 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager)
override val metricRegistry = new MetricRegistry()
override val sourceName = "BlockManager"
- metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] {
- override def getValue: Long = {
- val storageStatusList = blockManager.master.getStorageStatus
- val maxMem = storageStatusList.map(_.maxMem).sum
- maxMem / 1024 / 1024
- }
- })
-
- metricRegistry.register(MetricRegistry.name("memory", "remainingMem_MB"), new Gauge[Long] {
- override def getValue: Long = {
- val storageStatusList = blockManager.master.getStorageStatus
- val remainingMem = storageStatusList.map(_.memRemaining).sum
- remainingMem / 1024 / 1024
- }
- })
-
- metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] {
- override def getValue: Long = {
- val storageStatusList = blockManager.master.getStorageStatus
- val memUsed = storageStatusList.map(_.memUsed).sum
- memUsed / 1024 / 1024
- }
- })
-
- metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed_MB"), new Gauge[Long] {
- override def getValue: Long = {
- val storageStatusList = blockManager.master.getStorageStatus
- val diskSpaceUsed = storageStatusList.map(_.diskUsed).sum
- diskSpaceUsed / 1024 / 1024
- }
- })
+ private def registerGauge(name: String, func: BlockManagerMaster => Long): Unit = {
+ metricRegistry.register(name, new Gauge[Long] {
+ override def getValue: Long = func(blockManager.master) / 1024 / 1024
+ })
+ }
+
+ registerGauge(MetricRegistry.name("memory", "maxMem_MB"),
+ _.getStorageStatus.map(_.maxMem).sum)
+
+ registerGauge(MetricRegistry.name("memory", "maxOnHeapMem_MB"),
+ _.getStorageStatus.map(_.maxOnHeapMem.getOrElse(0L)).sum)
+
+ registerGauge(MetricRegistry.name("memory", "maxOffHeapMem_MB"),
+ _.getStorageStatus.map(_.maxOffHeapMem.getOrElse(0L)).sum)
+
+ registerGauge(MetricRegistry.name("memory", "remainingMem_MB"),
+ _.getStorageStatus.map(_.memRemaining).sum)
+
+ registerGauge(MetricRegistry.name("memory", "remainingOnHeapMem_MB"),
+ _.getStorageStatus.map(_.onHeapMemRemaining.getOrElse(0L)).sum)
+
+ registerGauge(MetricRegistry.name("memory", "remainingOffHeapMem_MB"),
+ _.getStorageStatus.map(_.offHeapMemRemaining.getOrElse(0L)).sum)
+
+ registerGauge(MetricRegistry.name("memory", "memUsed_MB"),
+ _.getStorageStatus.map(_.memUsed).sum)
+
+ registerGauge(MetricRegistry.name("memory", "onHeapMemUsed_MB"),
+ _.getStorageStatus.map(_.onHeapMemUsed.getOrElse(0L)).sum)
+
+ registerGauge(MetricRegistry.name("memory", "offHeapMemUsed_MB"),
+ _.getStorageStatus.map(_.offHeapMemUsed.getOrElse(0L)).sum)
+
+ registerGauge(MetricRegistry.name("disk", "diskSpaceUsed_MB"),
+ _.getStorageStatus.map(_.diskUsed).sum)
}
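
`BlockManagerSource` above collapses the repeated anonymous `Gauge` registrations into one `registerGauge` helper. The same shape in a standalone form (real Dropwizard Metrics API; the stats source is hypothetical):

```scala
import com.codahale.metrics.{Gauge, MetricRegistry}

object GaugeRegistrationSketch {
  // Hypothetical stats source standing in for BlockManagerMaster.getStorageStatus.
  final case class Stats(maxMem: Long, memUsed: Long, diskUsed: Long)
  private def currentStats(): Stats = Stats(maxMem = 4L << 30, memUsed = 1L << 30, diskUsed = 2L << 30)

  val registry = new MetricRegistry()

  // One small helper instead of repeating the anonymous Gauge boilerplate per metric;
  // values are reported in MB, as in the original source.
  private def registerGauge(name: String, value: Stats => Long): Unit =
    registry.register(name, new Gauge[Long] {
      override def getValue: Long = value(currentStats()) / 1024 / 1024
    })

  registerGauge(MetricRegistry.name("memory", "maxMem_MB"), _.maxMem)
  registerGauge(MetricRegistry.name("memory", "memUsed_MB"), _.memUsed)
  registerGauge(MetricRegistry.name("disk", "diskSpaceUsed_MB"), _.diskUsed)

  def main(args: Array[String]): Unit =
    println(registry.getGauges.keySet()) // dotted metric names, e.g. memory.maxMem_MB
}
```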
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala
new file mode 100644
index 000000000000..353eac60df17
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
+
+/**
+ * ::DeveloperApi::
+ * BlockReplicationPrioritization provides logic for prioritizing a sequence of peers for
+ * replicating blocks. BlockManager will replicate to each peer returned in order until the
+ * desired replication order is reached. If a replication fails, prioritize() will be called
+ * again to get a fresh prioritization.
+ */
+@DeveloperApi
+trait BlockReplicationPolicy {
+
+ /**
+ * Method to prioritize a bunch of candidate peers of a block
+ *
+ * @param blockManagerId Id of the current BlockManager for self identification
+ * @param peers A list of peers of a BlockManager
+ * @param peersReplicatedTo Set of peers already replicated to
+ * @param blockId BlockId of the block being replicated. This can be used as a source of
+ * randomness if needed.
+ * @param numReplicas Number of peers we need to replicate to
+ * @return A prioritized list of peers. The lower the index of a peer, the higher its priority.
+ *         This returns a list of size at most `numReplicas`.
+ */
+ def prioritize(
+ blockManagerId: BlockManagerId,
+ peers: Seq[BlockManagerId],
+ peersReplicatedTo: mutable.HashSet[BlockManagerId],
+ blockId: BlockId,
+ numReplicas: Int): List[BlockManagerId]
+}
+
+object BlockReplicationUtils {
+ // scalastyle:off line.size.limit
+ /**
+ * Uses the sampling algorithm by Robert Floyd. Finds a random sample in O(n) while
+ * minimizing space usage. Please see
+ * here.
+ *
+ * @param n total number of indices
+ * @param m number of samples needed
+ * @param r random number generator
+ * @return list of m random unique indices
+ */
+ // scalastyle:on line.size.limit
+ private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = {
+ val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) =>
+ val t = r.nextInt(i) + 1
+ if (set.contains(t)) set + i else set + t
+ }
+ indices.map(_ - 1).toList
+ }
+
+ /**
+ * Get a random sample of size m from the elems
+ *
+ * @param elems the sequence of elements to sample from
+ * @param m number of samples needed
+ * @param r random number generator
+ * @tparam T the element type of `elems`
+ * @return a random list of size m. If there are fewer than m elements in elems, we just
+ * randomly shuffle elems
+ */
+ def getRandomSample[T](elems: Seq[T], m: Int, r: Random): List[T] = {
+ if (elems.size > m) {
+ getSampleIds(elems.size, m, r).map(elems(_))
+ } else {
+ r.shuffle(elems).toList
+ }
+ }
+}
+
+@DeveloperApi
+class RandomBlockReplicationPolicy
+ extends BlockReplicationPolicy
+ with Logging {
+
+ /**
+ * Method to prioritize a bunch of candidate peers of a block. This is a basic implementation
+ * that just makes sure we put blocks on different hosts, if possible.
+ *
+ * @param blockManagerId Id of the current BlockManager for self identification
+ * @param peers A list of peers of a BlockManager
+ * @param peersReplicatedTo Set of peers already replicated to
+ * @param blockId BlockId of the block being replicated. This can be used as a source of
+ * randomness if needed.
+ * @param numReplicas Number of peers we need to replicate to
+ * @return A prioritized list of peers. The lower the index of a peer, the higher its priority.
+ */
+ override def prioritize(
+ blockManagerId: BlockManagerId,
+ peers: Seq[BlockManagerId],
+ peersReplicatedTo: mutable.HashSet[BlockManagerId],
+ blockId: BlockId,
+ numReplicas: Int): List[BlockManagerId] = {
+ val random = new Random(blockId.hashCode)
+ logDebug(s"Input peers : ${peers.mkString(", ")}")
+ val prioritizedPeers = if (peers.size > numReplicas) {
+ BlockReplicationUtils.getRandomSample(peers, numReplicas, random)
+ } else {
+ if (peers.size < numReplicas) {
+ logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.")
+ }
+ random.shuffle(peers).toList
+ }
+ logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}")
+ prioritizedPeers
+ }
+}
+
+@DeveloperApi
+class BasicBlockReplicationPolicy
+ extends BlockReplicationPolicy
+ with Logging {
+
+ /**
+ * Method to prioritize a bunch of candidate peers of a block manager. This implementation
+ * replicates the behavior of block replication in HDFS. For a given number of replicas needed,
+ * we choose a peer within the rack, one outside the rack, and the remaining block managers at
+ * random, in that order, until we meet the number of replicas needed.
+ * This works best with a total replication factor of 3, like HDFS.
+ *
+ * @param blockManagerId Id of the current BlockManager for self identification
+ * @param peers A list of peers of a BlockManager
+ * @param peersReplicatedTo Set of peers already replicated to
+ * @param blockId BlockId of the block being replicated. This can be used as a source of
+ * randomness if needed.
+ * @param numReplicas Number of peers we need to replicate to
+ * @return A prioritized list of peers. The lower the index of a peer, the higher its priority.
+ */
+ override def prioritize(
+ blockManagerId: BlockManagerId,
+ peers: Seq[BlockManagerId],
+ peersReplicatedTo: mutable.HashSet[BlockManagerId],
+ blockId: BlockId,
+ numReplicas: Int): List[BlockManagerId] = {
+
+ logDebug(s"Input peers : $peers")
+ logDebug(s"BlockManagerId : $blockManagerId")
+
+ val random = new Random(blockId.hashCode)
+
+ // If the block manager doesn't have topology info, we can't do much, so we randomly shuffle.
+ // If it does, we look at what is already covered by peersReplicatedTo and, based on
+ // numReplicas, choose what is still needed.
+ if (blockManagerId.topologyInfo.isEmpty || numReplicas == 0) {
+ // no topology info for the block. The best we can do is randomly choose peers
+ BlockReplicationUtils.getRandomSample(peers, numReplicas, random)
+ } else {
+ // we have topology information, so we check what is left to be done from peersReplicatedTo
+ val doneWithinRack = peersReplicatedTo.exists(_.topologyInfo == blockManagerId.topologyInfo)
+ val doneOutsideRack = peersReplicatedTo.exists { p =>
+ p.topologyInfo.isDefined && p.topologyInfo != blockManagerId.topologyInfo
+ }
+
+ if (doneOutsideRack && doneWithinRack) {
+ // we are done, we just return a random sample
+ BlockReplicationUtils.getRandomSample(peers, numReplicas, random)
+ } else {
+ // we separate peers within and outside rack
+ val (inRackPeers, outOfRackPeers) = peers
+ .filter(_.host != blockManagerId.host)
+ .partition(_.topologyInfo == blockManagerId.topologyInfo)
+
+ val peerWithinRack = if (doneWithinRack) {
+ // we are done with in-rack replication, so we don't need any more peers
+ Seq.empty
+ } else {
+ if (inRackPeers.isEmpty) {
+ Seq.empty
+ } else {
+ Seq(inRackPeers(random.nextInt(inRackPeers.size)))
+ }
+ }
+
+ val peerOutsideRack = if (doneOutsideRack || numReplicas - peerWithinRack.size <= 0) {
+ Seq.empty
+ } else {
+ if (outOfRackPeers.isEmpty) {
+ Seq.empty
+ } else {
+ Seq(outOfRackPeers(random.nextInt(outOfRackPeers.size)))
+ }
+ }
+
+ val priorityPeers = peerWithinRack ++ peerOutsideRack
+ val numRemainingPeers = numReplicas - priorityPeers.size
+ val remainingPeers = if (numRemainingPeers > 0) {
+ val rPeers = peers.filter(p => !priorityPeers.contains(p))
+ BlockReplicationUtils.getRandomSample(rPeers, numRemainingPeers, random)
+ } else {
+ Seq.empty
+ }
+
+ (priorityPeers ++ remainingPeers).toList
+ }
+
+ }
+ }
+
+}
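
`BlockReplicationUtils.getSampleIds` above implements Robert Floyd's sampling algorithm: m distinct indices drawn without materializing the whole range. A standalone copy of the idea with a quick sanity check (local sketch, not an import of the new class):

```scala
import scala.collection.mutable
import scala.util.Random

object FloydSamplingSketch {
  // Floyd's algorithm: for i = n-m+1 .. n, draw t uniformly from [1, i]; keep t unless it was
  // already chosen, in which case keep i. This yields m distinct values in 1..n using O(m) space.
  def sampleIndices(n: Int, m: Int, r: Random): List[Int] = {
    val chosen = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) { case (set, i) =>
      val t = r.nextInt(i) + 1
      if (set.contains(t)) set + i else set + t
    }
    chosen.map(_ - 1).toList // shift to 0-based indices
  }

  def main(args: Array[String]): Unit = {
    val sample = sampleIndices(n = 10, m = 4, r = new Random(42))
    println(sample)                              // four distinct indices in [0, 9]
    assert(sample.distinct.size == 4)
    assert(sample.forall(i => i >= 0 && i < 10))
  }
}
```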
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala
index 2789e25b8d3a..0a14fcadf53e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala
@@ -26,8 +26,7 @@ private[spark] case class BlockUIData(
location: String,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long,
- externalBlockStoreSize: Long)
+ diskSize: Long)
/**
* The aggregated status of stream blocks in an executor
@@ -41,8 +40,6 @@ private[spark] case class ExecutorStreamBlockStatus(
def totalDiskSize: Long = blocks.map(_.diskSize).sum
- def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum
-
def numStreamBlocks: Int = blocks.size
}
@@ -62,7 +59,6 @@ private[spark] class BlockStatusListener extends SparkListener {
val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel
val memSize = blockUpdated.blockUpdatedInfo.memSize
val diskSize = blockUpdated.blockUpdatedInfo.diskSize
- val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize
synchronized {
// Drop the update info if the block manager is not registered
@@ -74,8 +70,7 @@ private[spark] class BlockStatusListener extends SparkListener {
blockManagerId.hostPort,
storageLevel,
memSize,
- diskSize,
- externalBlockStoreSize)
+ diskSize)
)
} else {
// If isValid is not true, it means we should drop the block.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
deleted file mode 100644
index 69985c9759e2..000000000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.Logging
-
-/**
- * Abstract class to store blocks.
- */
-private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging {
-
- def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult
-
- /**
- * Put in a block and, possibly, also return its content as either bytes or another Iterator.
- * This is used to efficiently write the values to multiple locations (e.g. for replication).
- *
- * @return a PutResult that contains the size of the data, as well as the values put if
- * returnValues is true (if not, the result's data field can be null)
- */
- def putIterator(
- blockId: BlockId,
- values: Iterator[Any],
- level: StorageLevel,
- returnValues: Boolean): PutResult
-
- def putArray(
- blockId: BlockId,
- values: Array[Any],
- level: StorageLevel,
- returnValues: Boolean): PutResult
-
- /**
- * Return the size of a block in bytes.
- */
- def getSize(blockId: BlockId): Long
-
- def getBytes(blockId: BlockId): Option[ByteBuffer]
-
- def getValues(blockId: BlockId): Option[Iterator[Any]]
-
- /**
- * Remove a block, if it exists.
- * @param blockId the block to remove.
- * @return True if the block was found and removed, False otherwise.
- */
- def remove(blockId: BlockId): Boolean
-
- def contains(blockId: BlockId): Boolean
-
- def clear() { }
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala
index a5790e4454a8..e070bf658acb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala
@@ -30,8 +30,7 @@ case class BlockUpdatedInfo(
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long,
- externalBlockStoreSize: Long)
+ diskSize: Long)
private[spark] object BlockUpdatedInfo {
@@ -41,7 +40,6 @@ private[spark] object BlockUpdatedInfo {
updateBlockInfo.blockId,
updateBlockInfo.storageLevel,
updateBlockInfo.memSize,
- updateBlockInfo.diskSize,
- updateBlockInfo.externalBlockStoreSize)
+ updateBlockInfo.diskSize)
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index f7e84a2c2e14..a69bcc925999 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -17,27 +17,24 @@
package org.apache.spark.storage
+import java.io.{File, IOException}
import java.util.UUID
-import java.io.{IOException, File}
-import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.SparkConf
import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.internal.Logging
import org.apache.spark.util.{ShutdownHookManager, Utils}
/**
* Creates and maintains the logical mapping between logical blocks and physical on-disk
- * locations. By default, one block is mapped to one file with a name given by its BlockId.
- * However, it is also possible to have a block map to only a segment of a file, by calling
- * mapBlockToFileSegment().
+ * locations. One block is mapped to one file with a name given by its BlockId.
*
* Block files are hashed among the directories listed in spark.local.dir (or in
* SPARK_LOCAL_DIRS, if it's set).
*/
-private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf)
- extends Logging {
+private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolean) extends Logging {
- private[spark]
- val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
+ private[spark] val subDirsPerLocalDir = conf.getInt("spark.diskStore.subDirectories", 64)
/* Create one local directory for each path mentioned in spark.local.dir; then, inside this
* directory, create multiple subdirectories that we will hash files into, in order to avoid
@@ -103,7 +100,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
/** List all the blocks currently stored on disk by the disk manager. */
def getAllBlocks(): Seq[BlockId] = {
- getAllFiles().map(f => BlockId(f.getName))
+ getAllFiles().flatMap { f =>
+ try {
+ Some(BlockId(f.getName))
+ } catch {
+ case _: UnrecognizedBlockId =>
+ // Skip files which do not correspond to blocks, for example temporary
+ // files created by [[SortShuffleWriter]].
+ None
+ }
+ }
}
/** Produces a unique block id and File suitable for storing local intermediate results. */
@@ -144,6 +150,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
}
private def addShutdownHook(): AnyRef = {
+ logDebug("Adding shutdown hook") // force eager creation of logger
ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () =>
logInfo("Shutdown hook called")
DiskBlockManager.this.doStop()
@@ -163,10 +170,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
}
private def doStop(): Unit = {
- // Only perform cleanup if an external service is not serving our shuffle files.
- // Also blockManagerId could be null if block manager is not initialized properly.
- if (!blockManager.externalShuffleServiceEnabled ||
- (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) {
+ if (deleteFilesOnStop) {
localDirs.foreach { localDir =>
if (localDir.isDirectory() && localDir.exists()) {
try {
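For reference, the new getAllBlocks() above now tolerates stray files in the block directories by skipping any name that does not parse as a block ID. A minimal standalone sketch of that flatMap-and-recover shape, using a hypothetical parseBlockName helper and UnrecognizedName exception in place of Spark's BlockId(...) and UnrecognizedBlockId:

import java.io.File

// Hypothetical stand-ins for BlockId / UnrecognizedBlockId, used only for illustration.
final case class ParsedBlock(name: String)
final class UnrecognizedName(name: String) extends Exception(name)

object ListBlocksSketch {
  // Treat names with a temp_ prefix as non-blocks, the way temporary shuffle files would be.
  def parseBlockName(name: String): ParsedBlock =
    if (name.startsWith("temp_")) throw new UnrecognizedName(name) else ParsedBlock(name)

  // Same shape as the new getAllBlocks(): keep parseable names, silently drop the rest.
  def allBlocks(files: Seq[File]): Seq[ParsedBlock] =
    files.flatMap { f =>
      try Some(parseBlockName(f.getName))
      catch { case _: UnrecognizedName => None }
    }

  def main(args: Array[String]): Unit = {
    val files = Seq(new File("rdd_0_0"), new File("temp_shuffle_12"), new File("broadcast_1"))
    println(allBlocks(files)) // List(ParsedBlock(rdd_0_0), ParsedBlock(broadcast_1))
  }
}

Converting the failure into None rather than letting it escape keeps the listing robust when temporary files live alongside real block files, which is the motivation stated in the comment above.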
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 80d426fadc65..a024c83d8d8b 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -17,130 +17,185 @@
package org.apache.spark.storage
-import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream}
+import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}
import java.nio.channels.FileChannel
-import org.apache.spark.Logging
-import org.apache.spark.serializer.{SerializerInstance, SerializationStream}
import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.internal.Logging
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.util.Utils
/**
* A class for writing JVM objects directly to a file on disk. This class allows data to be appended
- * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to
- * revert partial writes.
+ * to an existing block. For efficiency, it retains the underlying file channel across
+ * multiple commits. This channel is kept open until close() is called. In case of faults,
+ * callers should instead close with revertPartialWritesAndClose() to atomically revert the
+ * uncommitted partial writes.
*
* This class does not support concurrent writes. Also, once the writer has been opened it cannot be
* reopened again.
*/
private[spark] class DiskBlockObjectWriter(
- file: File,
+ val file: File,
+ serializerManager: SerializerManager,
serializerInstance: SerializerInstance,
bufferSize: Int,
- compressStream: OutputStream => OutputStream,
syncWrites: Boolean,
// These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
- writeMetrics: ShuffleWriteMetrics)
+ writeMetrics: ShuffleWriteMetrics,
+ val blockId: BlockId = null)
extends OutputStream
with Logging {
+ /**
+ * Guards against close calls, e.g. from a wrapping stream.
+ * Call manualClose to close the stream that was extended by this trait.
+ * Commit uses this trait to close object streams without paying the
+ * cost of closing and opening the underlying file.
+ */
+ private trait ManualCloseOutputStream extends OutputStream {
+ abstract override def close(): Unit = {
+ flush()
+ }
+
+ def manualClose(): Unit = {
+ super.close()
+ }
+ }
+
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
+ private var mcs: ManualCloseOutputStream = null
private var bs: OutputStream = null
private var fos: FileOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
private var initialized = false
+ private var streamOpen = false
private var hasBeenClosed = false
- private var commitAndCloseHasBeenCalled = false
/**
* Cursors used to represent positions in the file.
*
- * xxxxxxxx|--------|--- |
- * ^ ^ ^
- * | | finalPosition
- * | reportedPosition
- * initialPosition
+ * xxxxxxxxxx|----------|-----|
+ * ^ ^ ^
+ * | | channel.position()
+ * | reportedPosition
+ * committedPosition
*
- * initialPosition: Offset in the file where we start writing. Immutable.
* reportedPosition: Position at the time of the last update to the write metrics.
- * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed.
+ * committedPosition: Offset after last committed write.
* -----: Current writes to the underlying file.
- * xxxxx: Existing contents of the file.
+ * xxxxx: Committed contents of the file.
*/
- private val initialPosition = file.length()
- private var finalPosition: Long = -1
- private var reportedPosition = initialPosition
+ private var committedPosition = file.length()
+ private var reportedPosition = committedPosition
/**
* Keep track of number of records written and also use this to periodically
* output bytes written since the latter is expensive to do for each record.
+ * We reset it after every call to commitAndGet().
*/
private var numRecordsWritten = 0
+ private def initialize(): Unit = {
+ fos = new FileOutputStream(file, true)
+ channel = fos.getChannel()
+ ts = new TimeTrackingOutputStream(writeMetrics, fos)
+ class ManualCloseBufferedOutputStream
+ extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream
+ mcs = new ManualCloseBufferedOutputStream
+ }
+
def open(): DiskBlockObjectWriter = {
if (hasBeenClosed) {
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
- fos = new FileOutputStream(file, true)
- ts = new TimeTrackingOutputStream(writeMetrics, fos)
- channel = fos.getChannel()
- bs = compressStream(new BufferedOutputStream(ts, bufferSize))
+ if (!initialized) {
+ initialize()
+ initialized = true
+ }
+
+ bs = serializerManager.wrapStream(blockId, mcs)
objOut = serializerInstance.serializeStream(bs)
- initialized = true
+ streamOpen = true
this
}
- override def close() {
+ /**
+ * Close and clean up all resources.
+ * Should be called after committing or reverting partial writes.
+ */
+ private def closeResources(): Unit = {
if (initialized) {
Utils.tryWithSafeFinally {
- if (syncWrites) {
- // Force outstanding writes to disk and track how long it takes
- objOut.flush()
- val start = System.nanoTime()
- fos.getFD.sync()
- writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
- }
+ mcs.manualClose()
} {
- objOut.close()
+ channel = null
+ mcs = null
+ bs = null
+ fos = null
+ ts = null
+ objOut = null
+ initialized = false
+ streamOpen = false
+ hasBeenClosed = true
}
-
- channel = null
- bs = null
- fos = null
- ts = null
- objOut = null
- initialized = false
- hasBeenClosed = true
}
}
- def isOpen: Boolean = objOut != null
+ /**
+ * Commits any remaining partial writes and closes resources.
+ */
+ override def close() {
+ if (initialized) {
+ Utils.tryWithSafeFinally {
+ commitAndGet()
+ } {
+ closeResources()
+ }
+ }
+ }
/**
* Flush the partial writes and commit them as a single atomic block.
+ * A commit may write additional bytes to frame the atomic block.
+ *
+ * @return the file segment of bytes committed by this call, with its starting offset and length.
*/
- def commitAndClose(): Unit = {
- if (initialized) {
+ def commitAndGet(): FileSegment = {
+ if (streamOpen) {
// NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
// serializer stream and the lower level stream.
objOut.flush()
bs.flush()
- close()
- finalPosition = file.length()
- // In certain compression codecs, more bytes are written after close() is called
- writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
+ objOut.close()
+ streamOpen = false
+
+ if (syncWrites) {
+ // Force outstanding writes to disk and track how long it takes
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ writeMetrics.incWriteTime(System.nanoTime() - start)
+ }
+
+ val pos = channel.position()
+ val fileSegment = new FileSegment(file, committedPosition, pos - committedPosition)
+ committedPosition = pos
+ // In certain compression codecs, more bytes are written after streams are closed
+ writeMetrics.incBytesWritten(committedPosition - reportedPosition)
+ reportedPosition = committedPosition
+ numRecordsWritten = 0
+ fileSegment
} else {
- finalPosition = file.length()
+ new FileSegment(file, committedPosition, 0)
}
- commitAndCloseHasBeenCalled = true
}
/**
- * Reverts writes that haven't been flushed yet. Callers should invoke this function
+ * Reverts writes that haven't been committed yet. Callers should invoke this function
* when there are runtime exceptions. This method will not throw, though it may be
* unsuccessful in truncating written data.
*
@@ -149,34 +204,36 @@ private[spark] class DiskBlockObjectWriter(
def revertPartialWritesAndClose(): File = {
// Discard current writes. We do this by flushing the outstanding writes and then
// truncating the file to its initial position.
- try {
+ Utils.tryWithSafeFinally {
if (initialized) {
- writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
- writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
- objOut.flush()
- bs.flush()
- close()
+ writeMetrics.decBytesWritten(reportedPosition - committedPosition)
+ writeMetrics.decRecordsWritten(numRecordsWritten)
+ streamOpen = false
+ closeResources()
}
-
- val truncateStream = new FileOutputStream(file, true)
+ } {
+ var truncateStream: FileOutputStream = null
try {
- truncateStream.getChannel.truncate(initialPosition)
- file
+ truncateStream = new FileOutputStream(file, true)
+ truncateStream.getChannel.truncate(committedPosition)
+ } catch {
+ case e: Exception =>
+ logError("Uncaught exception while reverting partial writes to file " + file, e)
} finally {
- truncateStream.close()
+ if (truncateStream != null) {
+ truncateStream.close()
+ truncateStream = null
+ }
}
- } catch {
- case e: Exception =>
- logError("Uncaught exception while reverting partial writes to file " + file, e)
- file
}
+ file
}
/**
* Writes a key-value pair.
*/
def write(key: Any, value: Any) {
- if (!initialized) {
+ if (!streamOpen) {
open()
}
@@ -188,7 +245,7 @@ private[spark] class DiskBlockObjectWriter(
override def write(b: Int): Unit = throw new UnsupportedOperationException()
override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
- if (!initialized) {
+ if (!streamOpen) {
open()
}
@@ -200,32 +257,20 @@ private[spark] class DiskBlockObjectWriter(
*/
def recordWritten(): Unit = {
numRecordsWritten += 1
- writeMetrics.incShuffleRecordsWritten(1)
+ writeMetrics.incRecordsWritten(1)
- if (numRecordsWritten % 32 == 0) {
+ if (numRecordsWritten % 16384 == 0) {
updateBytesWritten()
}
}
- /**
- * Returns the file segment of committed data that this Writer has written.
- * This is only valid after commitAndClose() has been called.
- */
- def fileSegment(): FileSegment = {
- if (!commitAndCloseHasBeenCalled) {
- throw new IllegalStateException(
- "fileSegment() is only valid after commitAndClose() has been called")
- }
- new FileSegment(file, initialPosition, finalPosition - initialPosition)
- }
-
/**
* Report the number of bytes written in this writer's shuffle write metrics.
* Note that this is only valid before the underlying streams are closed.
*/
private def updateBytesWritten() {
val pos = channel.position()
- writeMetrics.incShuffleBytesWritten(pos - reportedPosition)
+ writeMetrics.incBytesWritten(pos - reportedPosition)
reportedPosition = pos
}
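The ManualCloseOutputStream trait introduced above is a stackable-trait trick: close() calls coming from a wrapping stream only flush, and only the owner's manualClose() really closes the underlying stream, which is what lets commitAndGet() tear down serializer streams without reopening the file. A small self-contained sketch of the trick, assuming a plain FilterOutputStream over a ByteArrayOutputStream instead of the writer's buffered file stream:

import java.io.{ByteArrayOutputStream, FilterOutputStream, OutputStream}

object ManualCloseSketch {
  // close() from a wrapper only flushes; manualClose() performs the real close once.
  trait ManualCloseOutputStream extends OutputStream {
    abstract override def close(): Unit = flush()
    def manualClose(): Unit = super.close()
  }

  def main(args: Array[String]): Unit = {
    val sink = new ByteArrayOutputStream()
    // Mix the trait into a concrete stream, as the writer does with BufferedOutputStream.
    class GuardedStream extends FilterOutputStream(sink) with ManualCloseOutputStream
    val out = new GuardedStream

    out.write("committed".getBytes("UTF-8"))
    out.close()        // a wrapper's close: data is flushed, the stream stays usable
    out.write('!')     // still writable after the guarded close
    out.manualClose()  // the owner's real close
    println(sink.toString("UTF-8")) // committed!
  }
}

The abstract override on close() is what allows the mixed-in trait to intercept the call before it reaches the concrete stream underneath.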
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index c008b9dc1632..c6656341fcd1 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,158 +17,273 @@
package org.apache.spark.storage
-import java.io.{IOException, File, FileOutputStream, RandomAccessFile}
+import java.io._
import java.nio.ByteBuffer
+import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel}
import java.nio.channels.FileChannel.MapMode
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.concurrent.ConcurrentHashMap
-import org.apache.spark.Logging
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.Utils
+import scala.collection.mutable.ListBuffer
+
+import com.google.common.io.{ByteStreams, Closeables, Files}
+import io.netty.channel.FileRegion
+import io.netty.util.AbstractReferenceCounted
+
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.security.CryptoStreamUtils
+import org.apache.spark.util.{ByteBufferInputStream, Utils}
+import org.apache.spark.util.io.ChunkedByteBuffer
/**
* Stores BlockManager blocks on disk.
*/
-private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
- extends BlockStore(blockManager) with Logging {
+private[spark] class DiskStore(
+ conf: SparkConf,
+ diskManager: DiskBlockManager,
+ securityManager: SecurityManager) extends Logging {
- val minMemoryMapBytes = blockManager.conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m")
+ private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m")
+ private val blockSizes = new ConcurrentHashMap[String, Long]()
- override def getSize(blockId: BlockId): Long = {
- diskManager.getFile(blockId.name).length
- }
+ def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name)
- override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = {
- // So that we do not modify the input offsets !
- // duplicate does not copy buffer, so inexpensive
- val bytes = _bytes.duplicate()
+ /**
+ * Invokes the provided callback function to write the specific block.
+ *
+ * @throws IllegalStateException if the block already exists in the disk store.
+ */
+ def put(blockId: BlockId)(writeFunc: WritableByteChannel => Unit): Unit = {
+ if (contains(blockId)) {
+ throw new IllegalStateException(s"Block $blockId is already present in the disk store")
+ }
logDebug(s"Attempting to put block $blockId")
val startTime = System.currentTimeMillis
val file = diskManager.getFile(blockId)
- val channel = new FileOutputStream(file).getChannel
- Utils.tryWithSafeFinally {
- while (bytes.remaining > 0) {
- channel.write(bytes)
+ val out = new CountingWritableChannel(openForWrite(file))
+ var threwException: Boolean = true
+ try {
+ writeFunc(out)
+ blockSizes.put(blockId.name, out.getCount)
+ threwException = false
+ } finally {
+ try {
+ out.close()
+ } catch {
+ case ioe: IOException =>
+ if (!threwException) {
+ threwException = true
+ throw ioe
+ }
+ } finally {
+ if (threwException) {
+ remove(blockId)
+ }
}
- } {
- channel.close()
}
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
- file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime))
- PutResult(bytes.limit(), Right(bytes.duplicate()))
+ file.getName,
+ Utils.bytesToString(file.length()),
+ finishTime - startTime))
}
- override def putArray(
- blockId: BlockId,
- values: Array[Any],
- level: StorageLevel,
- returnValues: Boolean): PutResult = {
- putIterator(blockId, values.toIterator, level, returnValues)
+ def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = {
+ put(blockId) { channel =>
+ bytes.writeFully(channel)
+ }
}
- override def putIterator(
- blockId: BlockId,
- values: Iterator[Any],
- level: StorageLevel,
- returnValues: Boolean): PutResult = {
+ def getBytes(blockId: BlockId): BlockData = {
+ val file = diskManager.getFile(blockId.name)
+ val blockSize = getSize(blockId)
- logDebug(s"Attempting to write values for block $blockId")
- val startTime = System.currentTimeMillis
- val file = diskManager.getFile(blockId)
- val outputStream = new FileOutputStream(file)
- try {
- Utils.tryWithSafeFinally {
- blockManager.dataSerializeStream(blockId, outputStream, values)
- } {
- // Close outputStream here because it should be closed before file is deleted.
- outputStream.close()
- }
- } catch {
- case e: Throwable =>
- if (file.exists()) {
- if (!file.delete()) {
- logWarning(s"Error deleting ${file}")
+ securityManager.getIOEncryptionKey() match {
+ case Some(key) =>
+ // Encrypted blocks cannot be memory mapped; return a special object that does decryption
+ // and provides InputStream / FileRegion implementations for reading the data.
+ new EncryptedBlockData(file, blockSize, conf, key)
+
+ case _ =>
+ val channel = new FileInputStream(file).getChannel()
+ if (blockSize < minMemoryMapBytes) {
+ // For small files, directly read rather than memory map.
+ Utils.tryWithSafeFinally {
+ val buf = ByteBuffer.allocate(blockSize.toInt)
+ JavaUtils.readFully(channel, buf)
+ buf.flip()
+ new ByteBufferBlockData(new ChunkedByteBuffer(buf), true)
+ } {
+ channel.close()
+ }
+ } else {
+ Utils.tryWithSafeFinally {
+ new ByteBufferBlockData(
+ new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)), true)
+ } {
+ channel.close()
}
}
- throw e
}
+ }
- val length = file.length
+ def remove(blockId: BlockId): Boolean = {
+ blockSizes.remove(blockId.name)
+ val file = diskManager.getFile(blockId.name)
+ if (file.exists()) {
+ val ret = file.delete()
+ if (!ret) {
+ logWarning(s"Error deleting ${file.getPath()}")
+ }
+ ret
+ } else {
+ false
+ }
+ }
- val timeTaken = System.currentTimeMillis - startTime
- logDebug("Block %s stored as %s file on disk in %d ms".format(
- file.getName, Utils.bytesToString(length), timeTaken))
+ def contains(blockId: BlockId): Boolean = {
+ val file = diskManager.getFile(blockId.name)
+ file.exists()
+ }
- if (returnValues) {
- // Return a byte buffer for the contents of the file
- val buffer = getBytes(blockId).get
- PutResult(length, Right(buffer))
- } else {
- PutResult(length, null)
+ private def openForWrite(file: File): WritableByteChannel = {
+ val out = new FileOutputStream(file).getChannel()
+ try {
+ securityManager.getIOEncryptionKey().map { key =>
+ CryptoStreamUtils.createWritableChannel(out, conf, key)
+ }.getOrElse(out)
+ } catch {
+ case e: Exception =>
+ Closeables.close(out, true)
+ file.delete()
+ throw e
}
}
- private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = {
- val channel = new RandomAccessFile(file, "r").getChannel
- Utils.tryWithSafeFinally {
- // For small files, directly read rather than memory map
- if (length < minMemoryMapBytes) {
- val buf = ByteBuffer.allocate(length.toInt)
- channel.position(offset)
- while (buf.remaining() != 0) {
- if (channel.read(buf) == -1) {
- throw new IOException("Reached EOF before filling buffer\n" +
- s"offset=$offset\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
- }
- }
- buf.flip()
- Some(buf)
- } else {
- Some(channel.map(MapMode.READ_ONLY, offset, length))
+}
+
+private class EncryptedBlockData(
+ file: File,
+ blockSize: Long,
+ conf: SparkConf,
+ key: Array[Byte]) extends BlockData {
+
+ override def toInputStream(): InputStream = Channels.newInputStream(open())
+
+ override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize)
+
+ override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
+ val source = open()
+ try {
+ var remaining = blockSize
+ val chunks = new ListBuffer[ByteBuffer]()
+ while (remaining > 0) {
+ val chunkSize = math.min(remaining, Int.MaxValue)
+ val chunk = allocator(chunkSize.toInt)
+ remaining -= chunkSize
+ JavaUtils.readFully(source, chunk)
+ chunk.flip()
+ chunks += chunk
}
- } {
- channel.close()
+
+ new ChunkedByteBuffer(chunks.toArray)
+ } finally {
+ source.close()
}
}
- override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
- val file = diskManager.getFile(blockId.name)
- getBytes(file, 0, file.length)
+ override def toByteBuffer(): ByteBuffer = {
+ // This is used by the block transfer service to replicate blocks. The upload code reads
+ // all bytes into memory to send the block to the remote executor, so it's ok to do this
+ // as long as the block fits in a Java array.
+ assert(blockSize <= Int.MaxValue, "Block is too large to be wrapped in a byte buffer.")
+ val dst = ByteBuffer.allocate(blockSize.toInt)
+ val in = open()
+ try {
+ JavaUtils.readFully(in, dst)
+ dst.flip()
+ dst
+ } finally {
+ Closeables.close(in, true)
+ }
}
- def getBytes(segment: FileSegment): Option[ByteBuffer] = {
- getBytes(segment.file, segment.offset, segment.length)
- }
+ override def size: Long = blockSize
- override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
- getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
- }
+ override def dispose(): Unit = { }
- /**
- * A version of getValues that allows a custom serializer. This is used as part of the
- * shuffle short-circuit code.
- */
- def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
- // TODO: Should bypass getBytes and use a stream based implementation, so that
- // we won't use a lot of memory during e.g. external sort merge.
- getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
+ private def open(): ReadableByteChannel = {
+ val channel = new FileInputStream(file).getChannel()
+ try {
+ CryptoStreamUtils.createReadableChannel(channel, conf, key)
+ } catch {
+ case e: Exception =>
+ Closeables.close(channel, true)
+ throw e
+ }
}
- override def remove(blockId: BlockId): Boolean = {
- val file = diskManager.getFile(blockId.name)
- if (file.exists()) {
- val ret = file.delete()
- if (!ret) {
- logWarning(s"Error deleting ${file.getPath()}")
+}
+
+private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long)
+ extends AbstractReferenceCounted with FileRegion {
+
+ private var _transferred = 0L
+
+ private val buffer = ByteBuffer.allocateDirect(64 * 1024)
+ buffer.flip()
+
+ override def count(): Long = blockSize
+
+ override def position(): Long = 0
+
+ override def transfered(): Long = _transferred
+
+ override def transferTo(target: WritableByteChannel, pos: Long): Long = {
+ assert(pos == transfered(), "Invalid position.")
+
+ var written = 0L
+ var lastWrite = -1L
+ while (lastWrite != 0) {
+ if (!buffer.hasRemaining()) {
+ buffer.clear()
+ source.read(buffer)
+ buffer.flip()
+ }
+ if (buffer.hasRemaining()) {
+ lastWrite = target.write(buffer)
+ written += lastWrite
+ } else {
+ lastWrite = 0
}
- ret
- } else {
- false
}
+
+ _transferred += written
+ written
}
- override def contains(blockId: BlockId): Boolean = {
- val file = diskManager.getFile(blockId.name)
- file.exists()
+ override def deallocate(): Unit = source.close()
+}
+
+private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel {
+
+ private var count = 0L
+
+ def getCount: Long = count
+
+ override def write(src: ByteBuffer): Int = {
+ val written = sink.write(src)
+ if (written > 0) {
+ count += written
+ }
+ written
}
+
+ override def isOpen(): Boolean = sink.isOpen()
+
+ override def close(): Unit = sink.close()
+
}
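The new DiskStore.put above loans a WritableByteChannel to the caller and records the block size through a counting wrapper rather than re-reading the file length afterwards (the separate blockSizes map suggests the on-disk length need not equal the logical size once encryption is involved). A compact sketch of that loan pattern plus counting wrapper, writing to an in-memory channel and tracking sizes in a plain map; the names here are illustrative, not Spark's:

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.nio.channels.{Channels, WritableByteChannel}

import scala.collection.mutable

object PutSketch {
  // Counts bytes as they pass through to the wrapped channel.
  final class CountingChannel(sink: WritableByteChannel) extends WritableByteChannel {
    private var count = 0L
    def getCount: Long = count
    override def write(src: ByteBuffer): Int = {
      val n = sink.write(src)
      if (n > 0) count += n
      n
    }
    override def isOpen: Boolean = sink.isOpen
    override def close(): Unit = sink.close()
  }

  private val sizes = mutable.Map.empty[String, Long]

  // Loan pattern: the store opens and closes the channel, the caller only writes to it.
  def put(blockId: String)(writeFunc: WritableByteChannel => Unit): Unit = {
    val out = new CountingChannel(Channels.newChannel(new ByteArrayOutputStream()))
    try {
      writeFunc(out)
      sizes(blockId) = out.getCount
    } finally out.close()
  }

  def main(args: Array[String]): Unit = {
    put("rdd_0_0") { ch => ch.write(ByteBuffer.wrap("hello disk store".getBytes("UTF-8"))) }
    println(sizes("rdd_0_0")) // 16
  }
}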
diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala
deleted file mode 100644
index f39325a12d24..000000000000
--- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-/**
- * An abstract class that the concrete external block manager has to inherit.
- * The class has to have a no-argument constructor, and will be initialized by init,
- * which is invoked by ExternalBlockStore. The main input parameter is blockId for all
- * the methods, which is the unique identifier for Block in one Spark application.
- *
- * The underlying external block manager should avoid any name space conflicts among multiple
- * Spark applications. For example, creating different directory for different applications
- * by randomUUID
- *
- */
-private[spark] abstract class ExternalBlockManager {
-
- protected var blockManager: BlockManager = _
-
- override def toString: String = {"External Block Store"}
-
- /**
- * Initialize a concrete block manager implementation. Subclass should initialize its internal
- * data structure, e.g, file system, in this function, which is invoked by ExternalBlockStore
- * right after the class is constructed. The function should throw IOException on failure
- *
- * @throws java.io.IOException if there is any file system failure during the initialization.
- */
- def init(blockManager: BlockManager, executorId: String): Unit = {
- this.blockManager = blockManager
- }
-
- /**
- * Drop the block from underlying external block store, if it exists..
- * @return true on successfully removing the block
- * false if the block could not be removed as it was not found
- *
- * @throws java.io.IOException if there is any file system failure in removing the block.
- */
- def removeBlock(blockId: BlockId): Boolean
-
- /**
- * Used by BlockManager to check the existence of the block in the underlying external
- * block store.
- * @return true if the block exists.
- * false if the block does not exists.
- *
- * @throws java.io.IOException if there is any file system failure in checking
- * the block existence.
- */
- def blockExists(blockId: BlockId): Boolean
-
- /**
- * Put the given block to the underlying external block store. Note that in normal case,
- * putting a block should never fail unless something wrong happens to the underlying
- * external block store, e.g., file system failure, etc. In this case, IOException
- * should be thrown.
- *
- * @throws java.io.IOException if there is any file system failure in putting the block.
- */
- def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit
-
- def putValues(blockId: BlockId, values: Iterator[_]): Unit = {
- val bytes = blockManager.dataSerialize(blockId, values)
- putBytes(blockId, bytes)
- }
-
- /**
- * Retrieve the block bytes.
- * @return Some(ByteBuffer) if the block bytes is successfully retrieved
- * None if the block does not exist in the external block store.
- *
- * @throws java.io.IOException if there is any file system failure in getting the block.
- */
- def getBytes(blockId: BlockId): Option[ByteBuffer]
-
- /**
- * Retrieve the block data.
- * @return Some(Iterator[Any]) if the block data is successfully retrieved
- * None if the block does not exist in the external block store.
- *
- * @throws java.io.IOException if there is any file system failure in getting the block.
- */
- def getValues(blockId: BlockId): Option[Iterator[_]] = {
- getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
- }
-
- /**
- * Get the size of the block saved in the underlying external block store,
- * which is saved before by putBytes.
- * @return size of the block
- * 0 if the block does not exist
- *
- * @throws java.io.IOException if there is any file system failure in getting the block size.
- */
- def getSize(blockId: BlockId): Long
-
- /**
- * Clean up any information persisted in the underlying external block store,
- * e.g., the directory, files, etc,which is invoked by the shutdown hook of ExternalBlockStore
- * during system shutdown.
- *
- */
- def shutdown()
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
deleted file mode 100644
index db965d54bafd..000000000000
--- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
+++ /dev/null
@@ -1,217 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-import scala.util.control.NonFatal
-
-import org.apache.spark.Logging
-import org.apache.spark.util.Utils
-
-
-/**
- * Stores BlockManager blocks on ExternalBlockStore.
- * We capture any potential exception from underlying implementation
- * and return with the expected failure value
- */
-private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: String)
- extends BlockStore(blockManager: BlockManager) with Logging {
-
- lazy val externalBlockManager: Option[ExternalBlockManager] = createBlkManager()
-
- logInfo("ExternalBlockStore started")
-
- override def getSize(blockId: BlockId): Long = {
- try {
- externalBlockManager.map(_.getSize(blockId)).getOrElse(0)
- } catch {
- case NonFatal(t) =>
- logError(s"Error in getSize($blockId)", t)
- 0L
- }
- }
-
- override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = {
- putIntoExternalBlockStore(blockId, bytes, returnValues = true)
- }
-
- override def putArray(
- blockId: BlockId,
- values: Array[Any],
- level: StorageLevel,
- returnValues: Boolean): PutResult = {
- putIntoExternalBlockStore(blockId, values.toIterator, returnValues)
- }
-
- override def putIterator(
- blockId: BlockId,
- values: Iterator[Any],
- level: StorageLevel,
- returnValues: Boolean): PutResult = {
- putIntoExternalBlockStore(blockId, values, returnValues)
- }
-
- private def putIntoExternalBlockStore(
- blockId: BlockId,
- values: Iterator[_],
- returnValues: Boolean): PutResult = {
- logTrace(s"Attempting to put block $blockId into ExternalBlockStore")
- // we should never hit here if externalBlockManager is None. Handle it anyway for safety.
- try {
- val startTime = System.currentTimeMillis
- if (externalBlockManager.isDefined) {
- externalBlockManager.get.putValues(blockId, values)
- val size = getSize(blockId)
- val data = if (returnValues) {
- Left(getValues(blockId).get)
- } else {
- null
- }
- val finishTime = System.currentTimeMillis
- logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format(
- blockId, Utils.bytesToString(size), finishTime - startTime))
- PutResult(size, data)
- } else {
- logError(s"Error in putValues($blockId): no ExternalBlockManager has been configured")
- PutResult(-1, null, Seq((blockId, BlockStatus.empty)))
- }
- } catch {
- case NonFatal(t) =>
- logError(s"Error in putValues($blockId)", t)
- PutResult(-1, null, Seq((blockId, BlockStatus.empty)))
- }
- }
-
- private def putIntoExternalBlockStore(
- blockId: BlockId,
- bytes: ByteBuffer,
- returnValues: Boolean): PutResult = {
- logTrace(s"Attempting to put block $blockId into ExternalBlockStore")
- // we should never hit here if externalBlockManager is None. Handle it anyway for safety.
- try {
- val startTime = System.currentTimeMillis
- if (externalBlockManager.isDefined) {
- val byteBuffer = bytes.duplicate()
- byteBuffer.rewind()
- externalBlockManager.get.putBytes(blockId, byteBuffer)
- val size = bytes.limit()
- val data = if (returnValues) {
- Right(bytes)
- } else {
- null
- }
- val finishTime = System.currentTimeMillis
- logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format(
- blockId, Utils.bytesToString(size), finishTime - startTime))
- PutResult(size, data)
- } else {
- logError(s"Error in putBytes($blockId): no ExternalBlockManager has been configured")
- PutResult(-1, null, Seq((blockId, BlockStatus.empty)))
- }
- } catch {
- case NonFatal(t) =>
- logError(s"Error in putBytes($blockId)", t)
- PutResult(-1, null, Seq((blockId, BlockStatus.empty)))
- }
- }
-
- // We assume the block is removed even if exception thrown
- override def remove(blockId: BlockId): Boolean = {
- try {
- externalBlockManager.map(_.removeBlock(blockId)).getOrElse(true)
- } catch {
- case NonFatal(t) =>
- logError(s"Error in removeBlock($blockId)", t)
- true
- }
- }
-
- override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
- try {
- externalBlockManager.flatMap(_.getValues(blockId))
- } catch {
- case NonFatal(t) =>
- logError(s"Error in getValues($blockId)", t)
- None
- }
- }
-
- override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
- try {
- externalBlockManager.flatMap(_.getBytes(blockId))
- } catch {
- case NonFatal(t) =>
- logError(s"Error in getBytes($blockId)", t)
- None
- }
- }
-
- override def contains(blockId: BlockId): Boolean = {
- try {
- val ret = externalBlockManager.map(_.blockExists(blockId)).getOrElse(false)
- if (!ret) {
- logInfo(s"Remove block $blockId")
- blockManager.removeBlock(blockId, true)
- }
- ret
- } catch {
- case NonFatal(t) =>
- logError(s"Error in getBytes($blockId)", t)
- false
- }
- }
-
- private def addShutdownHook() {
- Runtime.getRuntime.addShutdownHook(new Thread("ExternalBlockStore shutdown hook") {
- override def run(): Unit = Utils.logUncaughtExceptions {
- logDebug("Shutdown hook called")
- externalBlockManager.map(_.shutdown())
- }
- })
- }
-
- // Create concrete block manager and fall back to Tachyon by default for backward compatibility.
- private def createBlkManager(): Option[ExternalBlockManager] = {
- val clsName = blockManager.conf.getOption(ExternalBlockStore.BLOCK_MANAGER_NAME)
- .getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME)
-
- try {
- val instance = Utils.classForName(clsName)
- .newInstance()
- .asInstanceOf[ExternalBlockManager]
- instance.init(blockManager, executorId)
- addShutdownHook();
- Some(instance)
- } catch {
- case NonFatal(t) =>
- logError("Cannot initialize external block store", t)
- None
- }
- }
-}
-
-private[spark] object ExternalBlockStore extends Logging {
- val MAX_DIR_CREATION_ATTEMPTS = 10
- val SUB_DIRS_PER_DIR = "64"
- val BASE_DIR = "spark.externalBlockStore.baseDir"
- val FOLD_NAME = "spark.externalBlockStore.folderName"
- val MASTER_URL = "spark.externalBlockStore.url"
- val BLOCK_MANAGER_NAME = "spark.externalBlockStore.blockManager"
- val DEFAULT_BLOCK_MANAGER_NAME = "org.apache.spark.storage.TachyonBlockManager"
-}
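For context on the removal above: ExternalBlockStore located its backend purely by reflection on a configured class name and degraded to None on any failure. A bare-bones sketch of that load-by-name pattern, using Class.forName and a hypothetical Backend trait in place of Spark's Utils.classForName and ExternalBlockManager:

import scala.util.control.NonFatal

trait Backend { def name: String }
// An example implementation that would normally be supplied on the classpath by the user.
class DummyBackend extends Backend { def name: String = "dummy" }

object PluggableBackendSketch {
  // Instantiate the configured class by name; None if it is missing or fails to construct.
  def load(className: String): Option[Backend] =
    try {
      Some(Class.forName(className).getDeclaredConstructor().newInstance().asInstanceOf[Backend])
    } catch {
      case NonFatal(t) =>
        Console.err.println(s"Cannot initialize backend $className: $t")
        None
    }

  def main(args: Array[String]): Unit = {
    println(load("DummyBackend").map(_.name))    // Some(dummy) when compiled in the default package
    println(load("no.such.Backend").map(_.name)) // None
  }
}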
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
deleted file mode 100644
index 4dbac388e098..000000000000
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ /dev/null
@@ -1,625 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-import java.util.LinkedHashMap
-
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.TaskContext
-import org.apache.spark.memory.MemoryManager
-import org.apache.spark.util.{SizeEstimator, Utils}
-import org.apache.spark.util.collection.SizeTrackingVector
-
-private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean)
-
-/**
- * Stores blocks in memory, either as Arrays of deserialized Java objects or as
- * serialized ByteBuffers.
- */
-private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: MemoryManager)
- extends BlockStore(blockManager) {
-
- // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and
- // acquiring or releasing unroll memory, must be synchronized on `memoryManager`!
-
- private val conf = blockManager.conf
- private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true)
-
- // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes)
- // All accesses of this map are assumed to have manually synchronized on `memoryManager`
- private val unrollMemoryMap = mutable.HashMap[Long, Long]()
- // Same as `unrollMemoryMap`, but for pending unroll memory as defined below.
- // Pending unroll memory refers to the intermediate memory occupied by a task
- // after the unroll but before the actual putting of the block in the cache.
- // This chunk of memory is expected to be released *as soon as* we finish
- // caching the corresponding block as opposed to until after the task finishes.
- // This is only used if a block is successfully unrolled in its entirety in
- // memory (SPARK-4777).
- private val pendingUnrollMemoryMap = mutable.HashMap[Long, Long]()
-
- // Initial memory to request before unrolling any block
- private val unrollMemoryThreshold: Long =
- conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024)
-
- /** Total amount of memory available for storage, in bytes. */
- private def maxMemory: Long = memoryManager.maxStorageMemory
-
- if (maxMemory < unrollMemoryThreshold) {
- logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " +
- s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " +
- s"memory. Please configure Spark with more memory.")
- }
-
- logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory)))
-
- /** Total storage memory used including unroll memory, in bytes. */
- private def memoryUsed: Long = memoryManager.storageMemoryUsed
-
- /**
- * Amount of storage memory, in bytes, used for caching blocks.
- * This does not include memory used for unrolling.
- */
- private def blocksMemoryUsed: Long = memoryManager.synchronized {
- memoryUsed - currentUnrollMemory
- }
-
- override def getSize(blockId: BlockId): Long = {
- entries.synchronized {
- entries.get(blockId).size
- }
- }
-
- override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = {
- // Work on a duplicate - since the original input might be used elsewhere.
- val bytes = _bytes.duplicate()
- bytes.rewind()
- if (level.deserialized) {
- val values = blockManager.dataDeserialize(blockId, bytes)
- putIterator(blockId, values, level, returnValues = true)
- } else {
- val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
- tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks)
- PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks)
- }
- }
-
- /**
- * Use `size` to test if there is enough space in MemoryStore. If so, create the ByteBuffer and
- * put it into MemoryStore. Otherwise, the ByteBuffer won't be created.
- *
- * The caller should guarantee that `size` is correct.
- */
- def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = {
- // Work on a duplicate - since the original input might be used elsewhere.
- lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer]
- val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
- val putSuccess = tryToPut(blockId, () => bytes, size, deserialized = false, droppedBlocks)
- val data =
- if (putSuccess) {
- assert(bytes.limit == size)
- Right(bytes.duplicate())
- } else {
- null
- }
- PutResult(size, data, droppedBlocks)
- }
-
- override def putArray(
- blockId: BlockId,
- values: Array[Any],
- level: StorageLevel,
- returnValues: Boolean): PutResult = {
- val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
- if (level.deserialized) {
- val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef])
- tryToPut(blockId, values, sizeEstimate, deserialized = true, droppedBlocks)
- PutResult(sizeEstimate, Left(values.iterator), droppedBlocks)
- } else {
- val bytes = blockManager.dataSerialize(blockId, values.iterator)
- tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks)
- PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks)
- }
- }
-
- override def putIterator(
- blockId: BlockId,
- values: Iterator[Any],
- level: StorageLevel,
- returnValues: Boolean): PutResult = {
- putIterator(blockId, values, level, returnValues, allowPersistToDisk = true)
- }
-
- /**
- * Attempt to put the given block in memory store.
- *
- * There may not be enough space to fully unroll the iterator in memory, in which case we
- * optionally drop the values to disk if
- * (1) the block's storage level specifies useDisk, and
- * (2) `allowPersistToDisk` is true.
- *
- * One scenario in which `allowPersistToDisk` is false is when the BlockManager reads a block
- * back from disk and attempts to cache it in memory. In this case, we should not persist the
- * block back on disk again, as it is already in disk store.
- */
- private[storage] def putIterator(
- blockId: BlockId,
- values: Iterator[Any],
- level: StorageLevel,
- returnValues: Boolean,
- allowPersistToDisk: Boolean): PutResult = {
- val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
- val unrolledValues = unrollSafely(blockId, values, droppedBlocks)
- unrolledValues match {
- case Left(arrayValues) =>
- // Values are fully unrolled in memory, so store them as an array
- val res = putArray(blockId, arrayValues, level, returnValues)
- droppedBlocks ++= res.droppedBlocks
- PutResult(res.size, res.data, droppedBlocks)
- case Right(iteratorValues) =>
- // Not enough space to unroll this block; drop to disk if applicable
- if (level.useDisk && allowPersistToDisk) {
- logWarning(s"Persisting block $blockId to disk instead.")
- val res = blockManager.diskStore.putIterator(blockId, iteratorValues, level, returnValues)
- PutResult(res.size, res.data, droppedBlocks)
- } else {
- PutResult(0, Left(iteratorValues), droppedBlocks)
- }
- }
- }
-
- override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
- val entry = entries.synchronized {
- entries.get(blockId)
- }
- if (entry == null) {
- None
- } else if (entry.deserialized) {
- Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[Array[Any]].iterator))
- } else {
- Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data
- }
- }
-
- override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
- val entry = entries.synchronized {
- entries.get(blockId)
- }
- if (entry == null) {
- None
- } else if (entry.deserialized) {
- Some(entry.value.asInstanceOf[Array[Any]].iterator)
- } else {
- val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data
- Some(blockManager.dataDeserialize(blockId, buffer))
- }
- }
-
- override def remove(blockId: BlockId): Boolean = memoryManager.synchronized {
- val entry = entries.synchronized { entries.remove(blockId) }
- if (entry != null) {
- memoryManager.releaseStorageMemory(entry.size)
- logDebug(s"Block $blockId of size ${entry.size} dropped " +
- s"from memory (free ${maxMemory - blocksMemoryUsed})")
- true
- } else {
- false
- }
- }
-
- override def clear(): Unit = memoryManager.synchronized {
- entries.synchronized {
- entries.clear()
- }
- unrollMemoryMap.clear()
- pendingUnrollMemoryMap.clear()
- memoryManager.releaseAllStorageMemory()
- logInfo("MemoryStore cleared")
- }
-
- /**
- * Unroll the given block in memory safely.
- *
- * The safety of this operation refers to avoiding potential OOM exceptions caused by
- * unrolling the entirety of the block in memory at once. This is achieved by periodically
- * checking whether the memory restrictions for unrolling blocks are still satisfied,
- * stopping immediately if not. This check is a safeguard against the scenario in which
- * there is not enough free memory to accommodate the entirety of a single block.
- *
- * This method returns either an array with the contents of the entire block or an iterator
- * containing the values of the block (if the array would have exceeded available memory).
- */
- def unrollSafely(
- blockId: BlockId,
- values: Iterator[Any],
- droppedBlocks: ArrayBuffer[(BlockId, BlockStatus)])
- : Either[Array[Any], Iterator[Any]] = {
-
- // Number of elements unrolled so far
- var elementsUnrolled = 0
- // Whether there is still enough memory for us to continue unrolling this block
- var keepUnrolling = true
- // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing.
- val initialMemoryThreshold = unrollMemoryThreshold
- // How often to check whether we need to request more memory
- val memoryCheckPeriod = 16
- // Memory currently reserved by this task for this particular unrolling operation
- var memoryThreshold = initialMemoryThreshold
- // Memory to request as a multiple of current vector size
- val memoryGrowthFactor = 1.5
- // Previous unroll memory held by this task, for releasing later (only at the very end)
- val previousMemoryReserved = currentUnrollMemoryForThisTask
- // Underlying vector for unrolling the block
- var vector = new SizeTrackingVector[Any]
-
- // Request enough memory to begin unrolling
- keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, droppedBlocks)
-
- if (!keepUnrolling) {
- logWarning(s"Failed to reserve initial memory threshold of " +
- s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
- }
-
- // Unroll this block safely, checking whether we have exceeded our threshold periodically
- try {
- while (values.hasNext && keepUnrolling) {
- vector += values.next()
- if (elementsUnrolled % memoryCheckPeriod == 0) {
- // If our vector's size has exceeded the threshold, request more memory
- val currentSize = vector.estimateSize()
- if (currentSize >= memoryThreshold) {
- val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
- keepUnrolling = reserveUnrollMemoryForThisTask(
- blockId, amountToRequest, droppedBlocks)
- // New threshold is currentSize * memoryGrowthFactor
- memoryThreshold += amountToRequest
- }
- }
- elementsUnrolled += 1
- }
-
- if (keepUnrolling) {
- // We successfully unrolled the entirety of this block
- Left(vector.toArray)
- } else {
- // We ran out of space while unrolling the values for this block
- logUnrollFailureMessage(blockId, vector.estimateSize())
- Right(vector.iterator ++ values)
- }
-
- } finally {
- // If we return an array, the values returned here will be cached in `tryToPut` later.
- // In this case, we should release the memory only after we cache the block there.
- if (keepUnrolling) {
- val taskAttemptId = currentTaskAttemptId()
- memoryManager.synchronized {
- // Since we continue to hold onto the array until we actually cache it, we cannot
- // release the unroll memory yet. Instead, we transfer it to pending unroll memory
- // so `tryToPut` can further transfer it to normal storage memory later.
- // TODO: we can probably express this without pending unroll memory (SPARK-10907)
- val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved
- unrollMemoryMap(taskAttemptId) -= amountToTransferToPending
- pendingUnrollMemoryMap(taskAttemptId) =
- pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending
- }
- } else {
- // Otherwise, if we return an iterator, we can only release the unroll memory when
- // the task finishes since we don't know when the iterator will be consumed.
- }
- }
- }
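The unrollSafely method being deleted here is worth reading as a pattern: materialize an iterator only while a growing memory reservation keeps being granted, and otherwise hand back the partially built buffer chained with the untouched remainder. A condensed sketch of that pattern with a fixed budget and a caller-supplied size function standing in for SizeTrackingVector and the MemoryManager:

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object UnrollSketch {
  // Try to materialize `values`; give up once the (fake) reservation cannot grow within maxBudget.
  def unrollSafely[T: ClassTag](
      values: Iterator[T],
      sizeOf: T => Long,
      maxBudget: Long,
      checkPeriod: Int = 16,
      growthFactor: Double = 1.5): Either[Array[T], Iterator[T]] = {
    val buffer = new ArrayBuffer[T]
    var used = 0L            // running size estimate, standing in for SizeTrackingVector
    var reserved = 1L << 20  // initial reservation, like spark.storage.unrollMemoryThreshold
    var keepUnrolling = reserved <= maxBudget
    var n = 0
    while (values.hasNext && keepUnrolling) {
      val v = values.next()
      buffer += v
      used += sizeOf(v)
      n += 1
      if (n % checkPeriod == 0 && used >= reserved) {
        // Request more room in proportion to the current estimate, as the original did.
        val request = (used * growthFactor - reserved).toLong
        keepUnrolling = reserved + request <= maxBudget
        if (keepUnrolling) reserved += request
      }
    }
    if (keepUnrolling) Left(buffer.toArray)  // fully materialized
    else Right(buffer.iterator ++ values)    // partial buffer plus the rest, still lazy
  }

  def main(args: Array[String]): Unit = {
    val sizeOf = (a: Array[Byte]) => a.length.toLong
    val small = unrollSafely(Iterator.fill(10)(Array.fill(1024)(0: Byte)), sizeOf, 64L << 20)
    val big = unrollSafely(Iterator.fill(100000)(Array.fill(1024)(0: Byte)), sizeOf, 4L << 20)
    println((small.isLeft, big.isRight)) // (true,true)
  }
}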
-
- /**
- * Return the RDD ID that a given block ID is from, or None if it is not an RDD block.
- */
- private def getRddId(blockId: BlockId): Option[Int] = {
- blockId.asRDDId.map(_.rddId)
- }
-
- private def tryToPut(
- blockId: BlockId,
- value: Any,
- size: Long,
- deserialized: Boolean,
- droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
- tryToPut(blockId, () => value, size, deserialized, droppedBlocks)
- }
-
- /**
- * Try to put in a set of values, if we can free up enough space. The value should either be
- * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size
- * must also be passed by the caller.
- *
- * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be
- * created to avoid OOM since it may be a big ByteBuffer.
- *
- * Synchronize on `memoryManager` to ensure that all the put requests and its associated block
- * dropping is done by only on thread at a time. Otherwise while one thread is dropping
- * blocks to free memory for one block, another thread may use up the freed space for
- * another block.
- *
- * All blocks evicted in the process, if any, will be added to `droppedBlocks`.
- *
- * @return whether put was successful.
- */
- private def tryToPut(
- blockId: BlockId,
- value: () => Any,
- size: Long,
- deserialized: Boolean,
- droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
-
- /* TODO: Its possible to optimize the locking by locking entries only when selecting blocks
- * to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has
- * been released, it must be ensured that those to-be-dropped blocks are not double counted
- * for freeing up more space for another block that needs to be put. Only then the actually
- * dropping of blocks (and writing to disk if necessary) can proceed in parallel. */
-
- memoryManager.synchronized {
- // Note: if we have previously unrolled this block successfully, then pending unroll
- // memory should be non-zero. This is the amount that we already reserved during the
- // unrolling process. In this case, we can just reuse this space to cache our block.
- // The synchronization on `memoryManager` here guarantees that the release and acquire
- // happen atomically. This relies on the assumption that all memory acquisitions are
- // synchronized on the same lock.
- releasePendingUnrollMemoryForThisTask()
- val enoughMemory = memoryManager.acquireStorageMemory(blockId, size, droppedBlocks)
- if (enoughMemory) {
- // We acquired enough memory for the block, so go ahead and put it
- val entry = new MemoryEntry(value(), size, deserialized)
- entries.synchronized {
- entries.put(blockId, entry)
- }
- val valuesOrBytes = if (deserialized) "values" else "bytes"
- logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format(
- blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed)))
- } else {
- // Tell the block manager that we couldn't put it in memory so that it can drop it to
- // disk if the block allows disk storage.
- lazy val data = if (deserialized) {
- Left(value().asInstanceOf[Array[Any]])
- } else {
- Right(value().asInstanceOf[ByteBuffer].duplicate())
- }
- val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data)
- droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
- }
- enoughMemory
- }
- }
-
- /**
- * Try to free up a given amount of space by evicting existing blocks.
- *
- * @param space the amount of memory to free, in bytes
- * @param droppedBlocks a holder for blocks evicted in the process
- * @return whether the requested free space is freed.
- */
- private[spark] def ensureFreeSpace(
- space: Long,
- droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
- ensureFreeSpace(None, space, droppedBlocks)
- }
-
- /**
- * Try to free up a given amount of space to store a block by evicting existing ones.
- *
- * @param space the amount of memory to free, in bytes
- * @param droppedBlocks a holder for blocks evicted in the process
- * @return whether the requested free space is freed.
- */
- private[spark] def ensureFreeSpace(
- blockId: BlockId,
- space: Long,
- droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
- ensureFreeSpace(Some(blockId), space, droppedBlocks)
- }
-
- /**
- * Try to free up a given amount of space to store a particular block, but can fail if
- * either the block is bigger than our memory or it would require replacing another block
- * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
- * don't fit into memory that we want to avoid).
- *
- * @param blockId the ID of the block we are freeing space for, if any
- * @param space the size of this block
- * @param droppedBlocks a holder for blocks evicted in the process
- * @return whether the requested free space is freed.
- */
- private def ensureFreeSpace(
- blockId: Option[BlockId],
- space: Long,
- droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
- memoryManager.synchronized {
- val freeMemory = maxMemory - memoryUsed
- val rddToAdd = blockId.flatMap(getRddId)
- val selectedBlocks = new ArrayBuffer[BlockId]
- var selectedMemory = 0L
-
- logInfo(s"Ensuring $space bytes of free space " +
- blockId.map { id => s"for block $id" }.getOrElse("") +
- s"(free: $freeMemory, max: $maxMemory)")
-
- // Fail fast if the block simply won't fit
- if (space > maxMemory) {
- logInfo("Will not " + blockId.map { id => s"store $id" }.getOrElse("free memory") +
- s" as the required space ($space bytes) exceeds our memory limit ($maxMemory bytes)")
- return false
- }
-
- // No need to evict anything if there is already enough free space
- if (freeMemory >= space) {
- return true
- }
-
- // This is synchronized to ensure that the set of entries is not changed
- // (because of getValue or getBytes) while traversing the iterator, as that
- // can lead to exceptions.
- entries.synchronized {
- val iterator = entries.entrySet().iterator()
- while (freeMemory + selectedMemory < space && iterator.hasNext) {
- val pair = iterator.next()
- val blockId = pair.getKey
- if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) {
- selectedBlocks += blockId
- selectedMemory += pair.getValue.size
- }
- }
- }
-
- if (freeMemory + selectedMemory >= space) {
- logInfo(s"${selectedBlocks.size} blocks selected for dropping")
- for (blockId <- selectedBlocks) {
- val entry = entries.synchronized { entries.get(blockId) }
- // This should never be null as only one task should be dropping
- // blocks and removing entries. However the check is still here for
- // future safety.
- if (entry != null) {
- val data = if (entry.deserialized) {
- Left(entry.value.asInstanceOf[Array[Any]])
- } else {
- Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
- }
- val droppedBlockStatus = blockManager.dropFromMemory(blockId, data)
- droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
- }
- }
- true
- } else {
- blockId.foreach { id =>
- logInfo(s"Will not store $id as it would require dropping another block " +
- "from the same RDD")
- }
- false
- }
- }
- }
-
- override def contains(blockId: BlockId): Boolean = {
- entries.synchronized { entries.containsKey(blockId) }
- }
-
- private def currentTaskAttemptId(): Long = {
- // In case this is called on the driver, return an invalid task attempt id.
- Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
- }
-
- /**
- * Reserve memory for unrolling the given block for this task.
- * @return whether the request is granted.
- */
- def reserveUnrollMemoryForThisTask(
- blockId: BlockId,
- memory: Long,
- droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
- memoryManager.synchronized {
- val success = memoryManager.acquireUnrollMemory(blockId, memory, droppedBlocks)
- if (success) {
- val taskAttemptId = currentTaskAttemptId()
- unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory
- }
- success
- }
- }
-
- /**
- * Release memory used by this task for unrolling blocks.
- * If the amount is not specified, remove the current task's allocation altogether.
- */
- def releaseUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = {
- val taskAttemptId = currentTaskAttemptId()
- memoryManager.synchronized {
- if (unrollMemoryMap.contains(taskAttemptId)) {
- val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId))
- if (memoryToRelease > 0) {
- unrollMemoryMap(taskAttemptId) -= memoryToRelease
- if (unrollMemoryMap(taskAttemptId) == 0) {
- unrollMemoryMap.remove(taskAttemptId)
- }
- memoryManager.releaseUnrollMemory(memoryToRelease)
- }
- }
- }
- }
-
- /**
- * Release pending unroll memory of current unroll successful block used by this task
- */
- def releasePendingUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = {
- val taskAttemptId = currentTaskAttemptId()
- memoryManager.synchronized {
- if (pendingUnrollMemoryMap.contains(taskAttemptId)) {
- val memoryToRelease = math.min(memory, pendingUnrollMemoryMap(taskAttemptId))
- if (memoryToRelease > 0) {
- pendingUnrollMemoryMap(taskAttemptId) -= memoryToRelease
- if (pendingUnrollMemoryMap(taskAttemptId) == 0) {
- pendingUnrollMemoryMap.remove(taskAttemptId)
- }
- memoryManager.releaseUnrollMemory(memoryToRelease)
- }
- }
- }
- }
-
- /**
- * Return the amount of memory currently occupied for unrolling blocks across all tasks.
- */
- def currentUnrollMemory: Long = memoryManager.synchronized {
- unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum
- }
-
- /**
- * Return the amount of memory currently occupied for unrolling blocks by this task.
- */
- def currentUnrollMemoryForThisTask: Long = memoryManager.synchronized {
- unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L)
- }
-
- /**
- * Return the number of tasks currently unrolling blocks.
- */
- private def numTasksUnrolling: Int = memoryManager.synchronized { unrollMemoryMap.keys.size }
-
- /**
- * Log information about current memory usage.
- */
- private def logMemoryUsage(): Unit = {
- logInfo(
- s"Memory use = ${Utils.bytesToString(blocksMemoryUsed)} (blocks) + " +
- s"${Utils.bytesToString(currentUnrollMemory)} (scratch space shared across " +
- s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(memoryUsed)}. " +
- s"Storage limit = ${Utils.bytesToString(maxMemory)}."
- )
- }
-
- /**
- * Log a warning for failing to unroll a block.
- *
- * @param blockId ID of the block we are trying to unroll.
- * @param finalVectorSize Final size of the vector before unrolling failed.
- */
- private def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = {
- logWarning(
- s"Not enough space to cache $blockId in memory! " +
- s"(computed ${Utils.bytesToString(finalVectorSize)} so far)"
- )
- logMemoryUsage()
- }
-}
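
For reference, a minimal self-contained sketch of the eviction-selection loop that the removed ensureFreeSpace implemented: walk the entries in insertion order, skip blocks belonging to the RDD being cached, and stop once enough memory has been selected. Entry, selectBlocksToEvict and the field names are hypothetical stand-ins, not the real MemoryStore API.

import scala.collection.mutable

object EvictionSketch {
  final case class Entry(rddId: Option[Int], size: Long)

  // Returns the block ids to drop, or None if dropping non-conflicting blocks cannot free enough space.
  def selectBlocksToEvict(
      entries: mutable.LinkedHashMap[String, Entry],
      freeMemory: Long,
      space: Long,
      rddToAdd: Option[Int]): Option[Seq[String]] = {
    val selected = mutable.ArrayBuffer[String]()
    var selectedMemory = 0L
    val it = entries.iterator
    while (freeMemory + selectedMemory < space && it.hasNext) {
      val (id, entry) = it.next()
      // Never evict a block of the RDD we are caching, to avoid cyclic replacement.
      if (rddToAdd.isEmpty || rddToAdd != entry.rddId) {
        selected += id
        selectedMemory += entry.size
      }
    }
    if (freeMemory + selectedMemory >= space) Some(selected.toSeq) else None
  }
}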
diff --git a/core/src/main/scala/org/apache/spark/storage/PutResult.scala b/core/src/main/scala/org/apache/spark/storage/PutResult.scala
deleted file mode 100644
index f0eac7594ecf..000000000000
--- a/core/src/main/scala/org/apache/spark/storage/PutResult.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-/**
- * Result of adding a block into a BlockStore. This case class contains a few things:
- * (1) The estimated size of the put,
- * (2) The values put if the caller asked for them to be returned (e.g. for chaining
- * replication), and
- * (3) A list of blocks dropped as a result of this put. This is always empty for DiskStore.
- */
-private[spark] case class PutResult(
- size: Long,
- data: Either[Iterator[_], ByteBuffer],
- droppedBlocks: Seq[(BlockId, BlockStatus)] = Seq.empty)
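
For context on the deletion above, a hypothetical consumer of the removed PutResult shape; the simplified droppedBlocks element type and the describe helper are illustrative only, not Spark APIs.

import java.nio.ByteBuffer

case class PutResultSketch(
    size: Long,
    data: Either[Iterator[_], ByteBuffer],
    droppedBlocks: Seq[String] = Seq.empty)

def describe(result: PutResultSketch): String = result.data match {
  case Left(_) => s"stored ${result.size} bytes as deserialized values"
  case Right(bytes) => s"stored ${result.size} bytes as a ${bytes.remaining()}-byte buffer"
}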
diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala
index 96062626b504..e5abbf745cc4 100644
--- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala
+++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala
@@ -18,16 +18,17 @@
package org.apache.spark.storage
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.{RDDOperationScope, RDD}
+import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.util.Utils
@DeveloperApi
class RDDInfo(
val id: Int,
- val name: String,
+ var name: String,
val numPartitions: Int,
var storageLevel: StorageLevel,
val parentIds: Seq[Int],
+ val callSite: String = "",
val scope: Option[RDDOperationScope] = None)
extends Ordered[RDDInfo] {
@@ -36,15 +37,14 @@ class RDDInfo(
var diskSize = 0L
var externalBlockStoreSize = 0L
- def isCached: Boolean =
- (memSize + diskSize + externalBlockStoreSize > 0) && numCachedPartitions > 0
+ def isCached: Boolean = (memSize + diskSize > 0) && numCachedPartitions > 0
override def toString: String = {
import Utils.bytesToString
("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " +
- "MemorySize: %s; ExternalBlockStoreSize: %s; DiskSize: %s").format(
+ "MemorySize: %s; DiskSize: %s").format(
name, id, storageLevel.toString, numCachedPartitions, numPartitions,
- bytesToString(memSize), bytesToString(externalBlockStoreSize), bytesToString(diskSize))
+ bytesToString(memSize), bytesToString(diskSize))
}
override def compare(that: RDDInfo): Int = {
@@ -56,6 +56,7 @@ private[spark] object RDDInfo {
def fromRdd(rdd: RDD[_]): RDDInfo = {
val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd))
val parentIds = rdd.dependencies.map(_.rdd.id)
- new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel, parentIds, rdd.scope)
+ new RDDInfo(rdd.id, rddName, rdd.partitions.length,
+ rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope)
}
}
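
A minimal usage sketch of the updated constructor shown in this hunk (note the extra callSite argument); the values are illustrative only, and scope is left at its default of None.

import org.apache.spark.storage.{RDDInfo, StorageLevel}

val info = new RDDInfo(
  id = 1,
  name = "myCachedRdd",
  numPartitions = 8,
  storageLevel = StorageLevel.MEMORY_ONLY,
  parentIds = Seq(0),
  callSite = "map at Example.scala:42")

// Nothing has been cached yet, so memSize + diskSize == 0 and the RDD is not reported as cached.
assert(!info.isCached)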
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 0d0448feb5b0..a91bbf71b68d 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,17 +17,21 @@
package org.apache.spark.storage
-import java.io.InputStream
+import java.io.{File, InputStream, IOException}
+import java.nio.ByteBuffer
import java.util.concurrent.LinkedBlockingQueue
+import javax.annotation.concurrent.GuardedBy
-import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.control.NonFatal
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
-import org.apache.spark.{Logging, SparkException, TaskContext}
-import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBufferOutputStream
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -36,7 +40,7 @@ import org.apache.spark.util.Utils
* This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks
* in a pipelined fashion as they are received.
*
- * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid
+ * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
* using too much memory.
*
* @param context [[TaskContext]], used for metrics update
@@ -45,7 +49,13 @@ import org.apache.spark.util.Utils
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
+ * @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
+ * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
+ * for a given remote host:port.
+ * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
+ * @param detectCorrupt whether to detect any corruption in fetched blocks.
*/
private[spark]
final class ShuffleBlockFetcherIterator(
@@ -53,8 +63,13 @@ final class ShuffleBlockFetcherIterator(
shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
- maxBytesInFlight: Long)
- extends Iterator[(BlockId, InputStream)] with Logging {
+ streamWrapper: (BlockId, InputStream) => InputStream,
+ maxBytesInFlight: Long,
+ maxReqsInFlight: Int,
+ maxBlocksInFlightPerAddress: Int,
+ maxReqSizeShuffleToMem: Long,
+ detectCorrupt: Boolean)
+ extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging {
import ShuffleBlockFetcherIterator._
@@ -67,7 +82,7 @@ final class ShuffleBlockFetcherIterator(
private[this] var numBlocksToFetch = 0
/**
- * The number of blocks proccessed by the caller. The iterator is exhausted when
+ * The number of blocks processed by the caller. The iterator is exhausted when
* [[numBlocksProcessed]] == [[numBlocksToFetch]].
*/
private[this] var numBlocksProcessed = 0
@@ -90,7 +105,7 @@ final class ShuffleBlockFetcherIterator(
* Current [[FetchResult]] being processed. We track this so we can release the current buffer
* in case of a runtime exception when processing the current buffer.
*/
- @volatile private[this] var currentResult: FetchResult = null
+ @volatile private[this] var currentResult: SuccessFetchResult = null
/**
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -98,16 +113,42 @@ final class ShuffleBlockFetcherIterator(
*/
private[this] val fetchRequests = new Queue[FetchRequest]
+ /**
+ * Queue of fetch requests which could not be issued the first time they were dequeued. These
+ * requests are tried again when the fetch constraints are satisfied.
+ */
+ private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]()
+
/** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L
- private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()
+ /** Current number of requests in flight */
+ private[this] var reqsInFlight = 0
+
+ /** Current number of blocks in flight per host:port */
+ private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()
+
+ /**
+ * The blocks that couldn't be decompressed successfully. This is used to guarantee that we
+ * retry at most once for those corrupted blocks.
+ */
+ private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
+ private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
/**
* Whether the iterator is still active. If isZombie is true, the callback interface will no
* longer place fetched blocks into [[results]].
*/
- @volatile private[this] var isZombie = false
+ @GuardedBy("this")
+ private[this] var isZombie = false
+
+ /**
+ * A set to store the files used for shuffling remote huge blocks. Files in this set will be
+ * deleted during cleanup. This is a layer of defensiveness against disk file leaks.
+ */
+ @GuardedBy("this")
+ private[this] val shuffleFilesSet = mutable.HashSet[File]()
initialize()
@@ -115,62 +156,100 @@ final class ShuffleBlockFetcherIterator(
// The currentResult is set to null to prevent releasing the buffer again on cleanup()
private[storage] def releaseCurrentResultBuffer(): Unit = {
// Release the current buffer if necessary
- currentResult match {
- case SuccessFetchResult(_, _, _, buf) => buf.release()
- case _ =>
+ if (currentResult != null) {
+ currentResult.buf.release()
}
currentResult = null
}
+ override def createTempShuffleFile(): File = {
+ blockManager.diskBlockManager.createTempLocalBlock()._2
+ }
+
+ override def registerTempShuffleFileToClean(file: File): Boolean = synchronized {
+ if (isZombie) {
+ false
+ } else {
+ shuffleFilesSet += file
+ true
+ }
+ }
+
/**
* Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
*/
private[this] def cleanup() {
- isZombie = true
+ synchronized {
+ isZombie = true
+ }
releaseCurrentResultBuffer()
// Release buffers in the results queue
val iter = results.iterator()
while (iter.hasNext) {
val result = iter.next()
result match {
- case SuccessFetchResult(_, _, _, buf) => buf.release()
+ case SuccessFetchResult(_, address, _, buf, _) =>
+ if (address != blockManager.blockManagerId) {
+ shuffleMetrics.incRemoteBytesRead(buf.size)
+ shuffleMetrics.incRemoteBlocksFetched(1)
+ }
+ buf.release()
case _ =>
}
}
+ shuffleFilesSet.foreach { file =>
+ if (!file.delete()) {
+ logWarning("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath())
+ }
+ }
}
private[this] def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
bytesInFlight += req.size
+ reqsInFlight += 1
// so we can look up the size of each blockID
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
+ val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
val blockIds = req.blocks.map(_._1.toString)
-
val address = req.address
- shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
- new BlockFetchingListener {
- override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
- // Only add the buffer to results queue if the iterator is not zombie,
- // i.e. cleanup() has not been called yet.
+
+ val blockFetchingListener = new BlockFetchingListener {
+ override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+ // Only add the buffer to results queue if the iterator is not zombie,
+ // i.e. cleanup() has not been called yet.
+ ShuffleBlockFetcherIterator.this.synchronized {
if (!isZombie) {
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
- results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
- shuffleMetrics.incRemoteBytesRead(buf.size)
- shuffleMetrics.incRemoteBlocksFetched(1)
+ remainingBlocks -= blockId
+ results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
+ remainingBlocks.isEmpty))
+ logDebug("remainingBlocks: " + remainingBlocks)
}
- logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
+ logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ }
- override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
- logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- results.put(new FailureFetchResult(BlockId(blockId), address, e))
- }
+ override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
+ logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
+ results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
- )
+ }
+
+ // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
+ // already encrypted and compressed over the wire (w.r.t. the related configs), we can just fetch
+ // the data and write it to file directly.
+ if (req.size > maxReqSizeShuffleToMem) {
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
+ blockFetchingListener, this)
+ } else {
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
+ blockFetchingListener, null)
+ }
}
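
A simplified, standalone sketch of the size-based routing above: requests larger than maxReqSizeShuffleToMem hand a temp-file manager to the transport layer so the blocks stream to disk, while smaller requests stay in memory. TransportLike and TempFileManagerLike are hypothetical stand-ins, not the real ShuffleClient/TempShuffleFileManager interfaces.

import java.io.File

trait TempFileManagerLike {
  def createTempShuffleFile(): File
}

trait TransportLike {
  // A null tempFileManager means "buffer the fetched blocks in memory".
  def fetch(blockIds: Array[String], tempFileManager: TempFileManagerLike): Unit
}

def sendRequestSketch(
    transport: TransportLike,
    blockIds: Array[String],
    reqSize: Long,
    maxReqSizeShuffleToMem: Long,
    fileManager: TempFileManagerLike): Unit = {
  if (reqSize > maxReqSizeShuffleToMem) {
    transport.fetch(blockIds, fileManager) // too large: spill to temp shuffle files on disk
  } else {
    transport.fetch(blockIds, null) // small enough: keep in memory
  }
}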
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
@@ -178,7 +257,8 @@ final class ShuffleBlockFetcherIterator(
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
+ logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
+ + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
@@ -207,11 +287,13 @@ final class ShuffleBlockFetcherIterator(
} else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
- if (curRequestSize >= targetRequestSize) {
+ if (curRequestSize >= targetRequestSize ||
+ curBlocks.size >= maxBlocksInFlightPerAddress) {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
+ logDebug(s"Creating fetch request of $curRequestSize at $address "
+ + s"with ${curBlocks.size} blocks")
curBlocks = new ArrayBuffer[(BlockId, Long)]
- logDebug(s"Creating fetch request of $curRequestSize at $address")
curRequestSize = 0
}
}
@@ -227,7 +309,7 @@ final class ShuffleBlockFetcherIterator(
/**
* Fetch the local blocks while we are fetching remote blocks. This is ok because
- * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we
+ * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we
* track in-memory are the ManagedBuffer references themselves.
*/
private[this] def fetchLocalBlocks() {
@@ -239,7 +321,7 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
- results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
+ results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
@@ -258,6 +340,9 @@ final class ShuffleBlockFetcherIterator(
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
+ assert ((0 == reqsInFlight) == (0 == bytesInFlight),
+ "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
+ ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
// Send out initial requests for blocks, up to our maxBytesInFlight
fetchUpToMaxBytes()
@@ -281,39 +366,144 @@ final class ShuffleBlockFetcherIterator(
* Throws a FetchFailedException if the next block could not be fetched.
*/
override def next(): (BlockId, InputStream) = {
- numBlocksProcessed += 1
- val startFetchWait = System.currentTimeMillis()
- currentResult = results.take()
- val result = currentResult
- val stopFetchWait = System.currentTimeMillis()
- shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
-
- result match {
- case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size
- case _ =>
+ if (!hasNext) {
+ throw new NoSuchElementException
}
- // Send fetch requests up to maxBytesInFlight
- fetchUpToMaxBytes()
- result match {
- case FailureFetchResult(blockId, address, e) =>
- throwFetchFailedException(blockId, address, e)
+ numBlocksProcessed += 1
+
+ var result: FetchResult = null
+ var input: InputStream = null
+ // Take the next fetched result and try to decompress it to detect data corruption. If the
+ // block is corrupt, fetch it one more time; if the second fetch is also corrupt, throw a
+ // FetchFailedException so that the previous stage can be retried.
+ // For local shuffle blocks, the first IOException already triggers a FetchFailedException.
+ while (result == null) {
+ val startFetchWait = System.currentTimeMillis()
+ result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
- case SuccessFetchResult(blockId, address, _, buf) =>
- try {
- (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
- } catch {
- case NonFatal(t) =>
- throwFetchFailedException(blockId, address, t)
- }
+ result match {
+ case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+ if (address != blockManager.blockManagerId) {
+ numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+ shuffleMetrics.incRemoteBytesRead(buf.size)
+ shuffleMetrics.incRemoteBlocksFetched(1)
+ }
+ bytesInFlight -= size
+ if (isNetworkReqDone) {
+ reqsInFlight -= 1
+ logDebug("Number of requests in flight " + reqsInFlight)
+ }
+
+ val in = try {
+ buf.createInputStream()
+ } catch {
+ // This exception can only be thrown by a local shuffle block
+ case e: IOException =>
+ assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+ logError("Failed to create input stream from local block", e)
+ buf.release()
+ throwFetchFailedException(blockId, address, e)
+ }
+
+ input = streamWrapper(blockId, in)
+ // Only copy the stream if it's wrapped by compression or encryption and the size of the
+ // block is small (the decompressed block is smaller than maxBytesInFlight)
+ if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+ val originalInput = input
+ val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+ try {
+ // Decompress the whole block at once to detect any corruption, which could increase
+ // the memory usage and potentially increase the chance of OOM.
+ // TODO: manage the memory used here, and spill it into disk in case of OOM.
+ Utils.copyStream(input, out)
+ out.close()
+ input = out.toChunkedByteBuffer.toInputStream(dispose = true)
+ } catch {
+ case e: IOException =>
+ buf.release()
+ if (buf.isInstanceOf[FileSegmentManagedBuffer]
+ || corruptedBlocks.contains(blockId)) {
+ throwFetchFailedException(blockId, address, e)
+ } else {
+ logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
+ corruptedBlocks += blockId
+ fetchRequests += FetchRequest(address, Array((blockId, size)))
+ result = null
+ }
+ } finally {
+ // TODO: release the buf here to free memory earlier
+ originalInput.close()
+ in.close()
+ }
+ }
+
+ case FailureFetchResult(blockId, address, e) =>
+ throwFetchFailedException(blockId, address, e)
+ }
+
+ // Send fetch requests up to maxBytesInFlight
+ fetchUpToMaxBytes()
}
+
+ currentResult = result.asInstanceOf[SuccessFetchResult]
+ (currentResult.blockId, new BufferReleasingInputStream(input, this))
}
private def fetchUpToMaxBytes(): Unit = {
- // Send fetch requests up to maxBytesInFlight
- while (fetchRequests.nonEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
+ // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
+ // immediately, defer the request until the next time it can be processed.
+
+ // Process any outstanding deferred fetch requests if possible.
+ if (deferredFetchRequests.nonEmpty) {
+ for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
+ while (isRemoteBlockFetchable(defReqQueue) &&
+ !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
+ val request = defReqQueue.dequeue()
+ logDebug(s"Processing deferred fetch request for $remoteAddress with "
+ + s"${request.blocks.length} blocks")
+ send(remoteAddress, request)
+ if (defReqQueue.isEmpty) {
+ deferredFetchRequests -= remoteAddress
+ }
+ }
+ }
+ }
+
+ // Process any regular fetch requests if possible.
+ while (isRemoteBlockFetchable(fetchRequests)) {
+ val request = fetchRequests.dequeue()
+ val remoteAddress = request.address
+ if (isRemoteAddressMaxedOut(remoteAddress, request)) {
+ logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
+ val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
+ defReqQueue.enqueue(request)
+ deferredFetchRequests(remoteAddress) = defReqQueue
+ } else {
+ send(remoteAddress, request)
+ }
+ }
+
+ def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
+ sendRequest(request)
+ numBlocksInFlightPerAddress(remoteAddress) =
+ numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
+ }
+
+ def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
+ fetchReqQueue.nonEmpty &&
+ (bytesInFlight == 0 ||
+ (reqsInFlight + 1 <= maxReqsInFlight &&
+ bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
+ }
+
+ // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
+ // given remote address.
+ def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
+ numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
+ maxBlocksInFlightPerAddress
}
}
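
A condensed, self-contained version of the two admission checks introduced above (global bytes/requests in flight plus the per-address block cap), using plain strings instead of BlockManagerId and a trimmed-down request type; all names here are hypothetical.

import scala.collection.mutable

final case class Req(address: String, numBlocks: Int, size: Long)

class FetchThrottleSketch(
    maxBytesInFlight: Long,
    maxReqsInFlight: Int,
    maxBlocksInFlightPerAddress: Int) {

  private var bytesInFlight = 0L
  private var reqsInFlight = 0
  private val blocksInFlightPerAddress = mutable.HashMap[String, Int]()

  // Global limits: always allow one request, otherwise respect the byte and request caps.
  def isFetchable(queue: mutable.Queue[Req]): Boolean = {
    queue.nonEmpty &&
      (bytesInFlight == 0 ||
        (reqsInFlight + 1 <= maxReqsInFlight &&
          bytesInFlight + queue.front.size <= maxBytesInFlight))
  }

  // Per-address limit: would this request exceed the block cap for its host?
  def isAddressMaxedOut(req: Req): Boolean =
    blocksInFlightPerAddress.getOrElse(req.address, 0) + req.numBlocks > maxBlocksInFlightPerAddress

  def send(req: Req): Unit = {
    bytesInFlight += req.size
    reqsInFlight += 1
    blocksInFlightPerAddress(req.address) =
      blocksInFlightPerAddress.getOrElse(req.address, 0) + req.numBlocks
  }
}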
@@ -329,7 +519,7 @@ final class ShuffleBlockFetcherIterator(
}
/**
- * Helper class that ensures a ManagedBuffer is release upon InputStream.close()
+ * Helper class that ensures a ManagedBuffer is released upon InputStream.close()
*/
private class BufferReleasingInputStream(
private val delegate: InputStream,
@@ -389,14 +579,15 @@ object ShuffleBlockFetcherIterator {
* @param address BlockManager that the block was fetched from.
* @param size estimated size of the block, used to calculate bytesInFlight.
* Note that this is NOT the exact bytes.
- * @param buf [[ManagedBuffer]] for the content.
+ * @param buf `ManagedBuffer` for the content.
+ * @param isNetworkReqDone whether this result completes the network request to its host.
*/
private[storage] case class SuccessFetchResult(
blockId: BlockId,
address: BlockManagerId,
size: Long,
- buf: ManagedBuffer)
- extends FetchResult {
+ buf: ManagedBuffer,
+ isNetworkReqDone: Boolean) extends FetchResult {
require(buf != null)
require(size >= 0)
}
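
A stripped-down sketch of the decompress-to-detect-corruption path in next() above: fully drain the wrapped (decompressing) stream and allow exactly one re-fetch per corrupted block. readFully, checkAndMaybeRetry and refetch are hypothetical helpers, not Spark APIs.

import java.io.{ByteArrayOutputStream, IOException, InputStream}
import scala.collection.mutable

object CorruptionCheckSketch {
  private val corruptedBlocks = mutable.HashSet[String]()

  private def readFully(in: InputStream): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val buf = new Array[Byte](64 * 1024)
    var n = in.read(buf)
    while (n != -1) { out.write(buf, 0, n); n = in.read(buf) }
    out.toByteArray
  }

  // Returns the block bytes, or None when the block was corrupt and has been queued for one retry.
  def checkAndMaybeRetry(blockId: String, wrapped: InputStream)(refetch: String => Unit): Option[Array[Byte]] = {
    try {
      Some(readFully(wrapped)) // decompressing the whole block surfaces corruption early
    } catch {
      case _: IOException if !corruptedBlocks.contains(blockId) =>
        corruptedBlocks += blockId // retry this block at most once
        refetch(blockId)
        None
      // a second failure is not caught here, so the caller can fail the fetch and retry the stage
    } finally {
      wrapped.close()
    }
  }
}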
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 703bce3e6b85..4c6998d7a8e2 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -21,6 +21,7 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.memory.MemoryMode
import org.apache.spark.util.Utils
/**
@@ -30,7 +31,7 @@ import org.apache.spark.util.Utils
* ExternalBlockStore, whether to keep the data in memory in a serialized format, and whether
* to replicate the RDD partitions on multiple nodes.
*
- * The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants
+ * The [[org.apache.spark.storage.StorageLevel]] singleton object contains some static constants
* for commonly useful storage levels. To create your own storage level object, use the
* factory method of the singleton object (`StorageLevel(...)`).
*/
@@ -59,10 +60,12 @@ class StorageLevel private(
assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes")
if (useOffHeap) {
- require(!useDisk, "Off-heap storage level does not support using disk")
- require(!useMemory, "Off-heap storage level does not support using heap memory")
require(!deserialized, "Off-heap storage level does not support deserialized storage")
- require(replication == 1, "Off-heap storage level does not support multiple replication")
+ }
+
+ private[spark] def memoryMode: MemoryMode = {
+ if (useOffHeap) MemoryMode.OFF_HEAP
+ else MemoryMode.ON_HEAP
}
override def clone(): StorageLevel = {
@@ -80,7 +83,7 @@ class StorageLevel private(
false
}
- def isValid: Boolean = (useMemory || useDisk || useOffHeap) && (replication > 0)
+ def isValid: Boolean = (useMemory || useDisk) && (replication > 0)
def toInt: Int = {
var ret = 0
@@ -117,7 +120,14 @@ class StorageLevel private(
private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this)
override def toString: String = {
- s"StorageLevel($useDisk, $useMemory, $useOffHeap, $deserialized, $replication)"
+ val disk = if (useDisk) "disk" else ""
+ val memory = if (useMemory) "memory" else ""
+ val heap = if (useOffHeap) "offheap" else ""
+ val deserialize = if (deserialized) "deserialized" else ""
+
+ val output =
+ Seq(disk, memory, heap, deserialize, s"$replication replicas").filter(_.nonEmpty)
+ s"StorageLevel(${output.mkString(", ")})"
}
override def hashCode(): Int = toInt * 41 + replication
@@ -125,8 +135,9 @@ class StorageLevel private(
def description: String = {
var result = ""
result += (if (useDisk) "Disk " else "")
- result += (if (useMemory) "Memory " else "")
- result += (if (useOffHeap) "ExternalBlockStore " else "")
+ if (useMemory) {
+ result += (if (useOffHeap) "Memory (off heap) " else "Memory ")
+ }
result += (if (deserialized) "Deserialized " else "Serialized ")
result += s"${replication}x Replicated"
result
@@ -150,7 +161,7 @@ object StorageLevel {
val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2)
val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false)
val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2)
- val OFF_HEAP = new StorageLevel(false, false, true, false)
+ val OFF_HEAP = new StorageLevel(true, true, true, false, 1)
/**
* :: DeveloperApi ::
@@ -175,7 +186,7 @@ object StorageLevel {
/**
* :: DeveloperApi ::
- * Create a new StorageLevel object without setting useOffHeap.
+ * Create a new StorageLevel object.
*/
@DeveloperApi
def apply(
@@ -190,7 +201,7 @@ object StorageLevel {
/**
* :: DeveloperApi ::
- * Create a new StorageLevel object.
+ * Create a new StorageLevel object without setting useOffHeap.
*/
@DeveloperApi
def apply(
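
A quick illustration of the behaviour changed above: OFF_HEAP now also uses disk and serialized memory, and (Spark-internally) its memoryMode resolves to MemoryMode.OFF_HEAP. The println output is only indicative of the new toString format.

import org.apache.spark.storage.StorageLevel

val level = StorageLevel.OFF_HEAP
assert(level.useMemory && level.useDisk && level.useOffHeap && !level.deserialized)
println(level) // e.g. StorageLevel(disk, memory, offheap, 1 replicas)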
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
index ec711480ebf3..ac60f795915a 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
import scala.collection.mutable
+import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
@@ -29,14 +30,21 @@ import org.apache.spark.scheduler._
* This class is thread-safe (unlike JobProgressListener)
*/
@DeveloperApi
-class StorageStatusListener extends SparkListener {
+@deprecated("This class will be removed in a future release.", "2.2.0")
+class StorageStatusListener(conf: SparkConf) extends SparkListener {
// This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE)
private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]()
+ private[storage] val deadExecutorStorageStatus = new mutable.ListBuffer[StorageStatus]()
+ private[this] val retainedDeadExecutors = conf.getInt("spark.ui.retainedDeadExecutors", 100)
def storageStatusList: Seq[StorageStatus] = synchronized {
executorIdToStorageStatus.values.toSeq
}
+ def deadStorageStatusList: Seq[StorageStatus] = synchronized {
+ deadExecutorStorageStatus
+ }
+
/** Update storage status list to reflect updated block statuses */
private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) {
executorIdToStorageStatus.get(execId).foreach { storageStatus =>
@@ -59,17 +67,6 @@ class StorageStatusListener extends SparkListener {
}
}
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
- val info = taskEnd.taskInfo
- val metrics = taskEnd.taskMetrics
- if (info != null && metrics != null) {
- val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
- if (updatedBlocks.length > 0) {
- updateStorageStatus(info.executorId, updatedBlocks)
- }
- }
- }
-
override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized {
updateStorageStatus(unpersistRDD.rddId)
}
@@ -78,17 +75,37 @@ class StorageStatusListener extends SparkListener {
synchronized {
val blockManagerId = blockManagerAdded.blockManagerId
val executorId = blockManagerId.executorId
- val maxMem = blockManagerAdded.maxMem
- val storageStatus = new StorageStatus(blockManagerId, maxMem)
+ // The onHeap and offHeap memory are always defined for new applications,
+ // but they can be missing if we are replaying old event logs.
+ val storageStatus = new StorageStatus(blockManagerId, blockManagerAdded.maxMem,
+ blockManagerAdded.maxOnHeapMem, blockManagerAdded.maxOffHeapMem)
executorIdToStorageStatus(executorId) = storageStatus
+
+ // Try to remove the dead storage status if the same executor registers the block manager twice.
+ deadExecutorStorageStatus.zipWithIndex.find(_._1.blockManagerId.executorId == executorId)
+ .foreach(toRemoveExecutor => deadExecutorStorageStatus.remove(toRemoveExecutor._2))
}
}
override def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) {
synchronized {
val executorId = blockManagerRemoved.blockManagerId.executorId
- executorIdToStorageStatus.remove(executorId)
+ executorIdToStorageStatus.remove(executorId).foreach { status =>
+ deadExecutorStorageStatus += status
+ }
+ if (deadExecutorStorageStatus.size > retainedDeadExecutors) {
+ deadExecutorStorageStatus.trimStart(1)
+ }
}
}
+ override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+ val executorId = blockUpdated.blockUpdatedInfo.blockManagerId.executorId
+ val blockId = blockUpdated.blockUpdatedInfo.blockId
+ val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel
+ val memSize = blockUpdated.blockUpdatedInfo.memSize
+ val diskSize = blockUpdated.blockUpdatedInfo.diskSize
+ val blockStatus = BlockStatus(storageLevel, memSize, diskSize)
+ updateStorageStatus(executorId, Seq((blockId, blockStatus)))
+ }
}
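
A minimal usage sketch of the updated (now deprecated) listener, assuming only the constructor and the spark.ui.retainedDeadExecutors key shown above; the cap value is illustrative.

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageStatusListener

val conf = new SparkConf().set("spark.ui.retainedDeadExecutors", "50") // cap the dead-executor history

val listener = new StorageStatusListener(conf)
// After a block manager is removed, its last StorageStatus is retained here (up to the cap):
val deadExecutors = listener.deadStorageStatusList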
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index c4ac30092f80..e9694fdbca2d 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -17,10 +17,15 @@
package org.apache.spark.storage
+import java.nio.{ByteBuffer, MappedByteBuffer}
+
import scala.collection.Map
import scala.collection.mutable
+import sun.nio.ch.DirectBuffer
+
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
/**
* :: DeveloperApi ::
@@ -30,7 +35,12 @@ import org.apache.spark.annotation.DeveloperApi
* class cannot mutate the source of the information. Accesses are not thread-safe.
*/
@DeveloperApi
-class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
+@deprecated("This class may be removed or made private in a future release.", "2.2.0")
+class StorageStatus(
+ val blockManagerId: BlockManagerId,
+ val maxMemory: Long,
+ val maxOnHeapMem: Option[Long],
+ val maxOffHeapMem: Option[Long]) {
/**
* Internal representation of the blocks stored in this block manager.
@@ -41,32 +51,28 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]]
private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus]
- /**
- * Storage information of the blocks that entails memory, disk, and off-heap memory usage.
- *
- * As with the block maps, we store the storage information separately for RDD blocks and
- * non-RDD blocks for the same reason. In particular, RDD storage information is stored
- * in a map indexed by the RDD ID to the following 4-tuple:
- *
- * (memory size, disk size, off-heap size, storage level)
- *
- * We assume that all the blocks that belong to the same RDD have the same storage level.
- * This field is not relevant to non-RDD blocks, however, so the storage information for
- * non-RDD blocks contains only the first 3 fields (in the same order).
- */
- private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, Long, StorageLevel)]
- private var _nonRddStorageInfo: (Long, Long, Long) = (0L, 0L, 0L)
+ private case class RddStorageInfo(memoryUsage: Long, diskUsage: Long, level: StorageLevel)
+ private val _rddStorageInfo = new mutable.HashMap[Int, RddStorageInfo]
+
+ private case class NonRddStorageInfo(var onHeapUsage: Long, var offHeapUsage: Long,
+ var diskUsage: Long)
+ private val _nonRddStorageInfo = NonRddStorageInfo(0L, 0L, 0L)
/** Create a storage status with an initial set of blocks, leaving the source unmodified. */
- def this(bmid: BlockManagerId, maxMem: Long, initialBlocks: Map[BlockId, BlockStatus]) {
- this(bmid, maxMem)
+ def this(
+ bmid: BlockManagerId,
+ maxMemory: Long,
+ maxOnHeapMem: Option[Long],
+ maxOffHeapMem: Option[Long],
+ initialBlocks: Map[BlockId, BlockStatus]) {
+ this(bmid, maxMemory, maxOnHeapMem, maxOffHeapMem)
initialBlocks.foreach { case (bid, bstatus) => addBlock(bid, bstatus) }
}
/**
* Return the blocks stored in this block manager.
*
- * Note that this is somewhat expensive, as it involves cloning the underlying maps and then
+ * @note This is somewhat expensive, as it involves cloning the underlying maps and then
* concatenating them together. Much faster alternatives exist for common operations such as
* contains, get, and size.
*/
@@ -75,16 +81,14 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
/**
* Return the RDD blocks stored in this block manager.
*
- * Note that this is somewhat expensive, as it involves cloning the underlying maps and then
+ * @note This is somewhat expensive, as it involves cloning the underlying maps and then
* concatenating them together. Much faster alternatives exist for common operations such as
* getting the memory, disk, and off-heap memory sizes occupied by this RDD.
*/
def rddBlocks: Map[BlockId, BlockStatus] = _rddBlocks.flatMap { case (_, blocks) => blocks }
/** Return the blocks that belong to the given RDD stored in this block manager. */
- def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = {
- _rddBlocks.get(rddId).getOrElse(Map.empty)
- }
+ def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = _rddBlocks.getOrElse(rddId, Map.empty)
/** Add the given block to this storage status. If it already exists, overwrite it. */
private[spark] def addBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = {
@@ -125,7 +129,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
/**
* Return whether the given block is stored in this block manager in O(1) time.
- * Note that this is much faster than `this.blocks.contains`, which is O(blocks) time.
+ *
+ * @note This is much faster than `this.blocks.contains`, which is O(blocks) time.
*/
def containsBlock(blockId: BlockId): Boolean = {
blockId match {
@@ -138,12 +143,13 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
/**
* Return the given block stored in this block manager in O(1) time.
- * Note that this is much faster than `this.blocks.get`, which is O(blocks) time.
+ *
+ * @note This is much faster than `this.blocks.get`, which is O(blocks) time.
*/
def getBlock(blockId: BlockId): Option[BlockStatus] = {
blockId match {
case RDDBlockId(rddId, _) =>
- _rddBlocks.get(rddId).map(_.get(blockId)).flatten
+ _rddBlocks.get(rddId).flatMap(_.get(blockId))
case _ =>
_nonRddBlocks.get(blockId)
}
@@ -151,46 +157,77 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
/**
* Return the number of blocks stored in this block manager in O(RDDs) time.
- * Note that this is much faster than `this.blocks.size`, which is O(blocks) time.
+ *
+ * @note This is much faster than `this.blocks.size`, which is O(blocks) time.
*/
def numBlocks: Int = _nonRddBlocks.size + numRddBlocks
/**
* Return the number of RDD blocks stored in this block manager in O(RDDs) time.
- * Note that this is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time.
+ *
+ * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time.
*/
def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum
/**
* Return the number of blocks that belong to the given RDD in O(1) time.
- * Note that this is much faster than `this.rddBlocksById(rddId).size`, which is
+ *
+ * @note This is much faster than `this.rddBlocksById(rddId).size`, which is
* O(blocks in this RDD) time.
*/
def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0)
+ /** Return the max memory can be used by this block manager. */
+ def maxMem: Long = maxMemory
+
/** Return the memory remaining in this block manager. */
def memRemaining: Long = maxMem - memUsed
+ /** Return the memory used by caching RDDs */
+ def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L)
+
/** Return the memory used by this block manager. */
- def memUsed: Long = _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum
+ def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L)
- /** Return the disk space used by this block manager. */
- def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum
+ /** Return the on-heap memory remaining in this block manager. */
+ def onHeapMemRemaining: Option[Long] =
+ for (m <- maxOnHeapMem; o <- onHeapMemUsed) yield m - o
+
+ /** Return the off-heap memory remaining in this block manager. */
+ def offHeapMemRemaining: Option[Long] =
+ for (m <- maxOffHeapMem; o <- offHeapMemUsed) yield m - o
+
+ /** Return the on-heap memory used by this block manager. */
+ def onHeapMemUsed: Option[Long] = onHeapCacheSize.map(_ + _nonRddStorageInfo.onHeapUsage)
- /** Return the off-heap space used by this block manager. */
- def offHeapUsed: Long = _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum
+ /** Return the off-heap memory used by this block manager. */
+ def offHeapMemUsed: Option[Long] = offHeapCacheSize.map(_ + _nonRddStorageInfo.offHeapUsage)
+
+ /** Return the memory used by on-heap caching RDDs */
+ def onHeapCacheSize: Option[Long] = maxOnHeapMem.map { _ =>
+ _rddStorageInfo.collect {
+ case (_, storageInfo) if !storageInfo.level.useOffHeap => storageInfo.memoryUsage
+ }.sum
+ }
+
+ /** Return the memory used by off-heap caching RDDs */
+ def offHeapCacheSize: Option[Long] = maxOffHeapMem.map { _ =>
+ _rddStorageInfo.collect {
+ case (_, storageInfo) if storageInfo.level.useOffHeap => storageInfo.memoryUsage
+ }.sum
+ }
+
+ /** Return the disk space used by this block manager. */
+ def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum
/** Return the memory used by the given RDD in this block manager in O(1) time. */
- def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L)
+ def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L)
/** Return the disk space used by the given RDD in this block manager in O(1) time. */
- def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._2).getOrElse(0L)
-
- /** Return the off-heap space used by the given RDD in this block manager in O(1) time. */
- def offHeapUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._3).getOrElse(0L)
+ def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L)
/** Return the storage level, if any, used by the given RDD in this block manager. */
- def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._4)
+ def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level)
/**
* Update the relevant storage info, taking into account any existing status for this block.
@@ -199,41 +236,65 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
val oldBlockStatus = getBlock(blockId).getOrElse(BlockStatus.empty)
val changeInMem = newBlockStatus.memSize - oldBlockStatus.memSize
val changeInDisk = newBlockStatus.diskSize - oldBlockStatus.diskSize
- val changeInExternalBlockStore =
- newBlockStatus.externalBlockStoreSize - oldBlockStatus.externalBlockStoreSize
val level = newBlockStatus.storageLevel
// Compute new info from old info
- val (oldMem, oldDisk, oldExternalBlockStore) = blockId match {
+ val (oldMem, oldDisk) = blockId match {
case RDDBlockId(rddId, _) =>
_rddStorageInfo.get(rddId)
- .map { case (mem, disk, externalBlockStore, _) => (mem, disk, externalBlockStore) }
- .getOrElse((0L, 0L, 0L))
- case _ =>
- _nonRddStorageInfo
+ .map { case RddStorageInfo(mem, disk, _) => (mem, disk) }
+ .getOrElse((0L, 0L))
+ case _ if !level.useOffHeap =>
+ (_nonRddStorageInfo.onHeapUsage, _nonRddStorageInfo.diskUsage)
+ case _ if level.useOffHeap =>
+ (_nonRddStorageInfo.offHeapUsage, _nonRddStorageInfo.diskUsage)
}
val newMem = math.max(oldMem + changeInMem, 0L)
val newDisk = math.max(oldDisk + changeInDisk, 0L)
- val newExternalBlockStore = math.max(oldExternalBlockStore + changeInExternalBlockStore, 0L)
// Set the correct info
blockId match {
case RDDBlockId(rddId, _) =>
// If this RDD is no longer persisted, remove it
- if (newMem + newDisk + newExternalBlockStore == 0) {
+ if (newMem + newDisk == 0) {
_rddStorageInfo.remove(rddId)
} else {
- _rddStorageInfo(rddId) = (newMem, newDisk, newExternalBlockStore, level)
+ _rddStorageInfo(rddId) = RddStorageInfo(newMem, newDisk, level)
}
case _ =>
- _nonRddStorageInfo = (newMem, newDisk, newExternalBlockStore)
+ if (!level.useOffHeap) {
+ _nonRddStorageInfo.onHeapUsage = newMem
+ } else {
+ _nonRddStorageInfo.offHeapUsage = newMem
+ }
+ _nonRddStorageInfo.diskUsage = newDisk
}
}
-
}
/** Helper methods for storage-related objects. */
-private[spark] object StorageUtils {
+private[spark] object StorageUtils extends Logging {
+ /**
+ * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun
+ * API that will cause errors if one attempts to read from the disposed buffer. However, neither
+ * the bytes allocated to direct buffers nor file descriptors opened for memory-mapped buffers put
+ * pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of
+ * off-heap memory or huge numbers of open files. There's unfortunately no standard API to
+ * manually dispose of these kinds of buffers.
+ */
+ def dispose(buffer: ByteBuffer): Unit = {
+ if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
+ logTrace(s"Disposing of $buffer")
+ cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer])
+ }
+ }
+
+ private def cleanDirectBuffer(buffer: DirectBuffer) = {
+ val cleaner = buffer.cleaner()
+ if (cleaner != null) {
+ cleaner.clean()
+ }
+ }
/**
* Update the given list of RDDInfo with the given list of storage statuses.
@@ -248,13 +309,11 @@ private[spark] object StorageUtils {
val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum
val memSize = statuses.map(_.memUsedByRdd(rddId)).sum
val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum
- val externalBlockStoreSize = statuses.map(_.offHeapUsedByRdd(rddId)).sum
rddInfo.storageLevel = storageLevel
rddInfo.numCachedPartitions = numCachedPartitions
rddInfo.memSize = memSize
rddInfo.diskSize = diskSize
- rddInfo.externalBlockStoreSize = externalBlockStoreSize
}
}
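
A short sketch of the dispose helper added above, used to unmap a memory-mapped buffer eagerly instead of waiting for GC. Note that StorageUtils is private[spark], so this assumes Spark-internal calling code; the file path is illustrative.

import java.io.RandomAccessFile
import java.nio.channels.FileChannel.MapMode

import org.apache.spark.storage.StorageUtils

val raf = new RandomAccessFile("/tmp/example.data", "r")
try {
  val mapped = raf.getChannel.map(MapMode.READ_ONLY, 0, raf.length())
  // ... read from `mapped` ...
  StorageUtils.dispose(mapped) // unmap now; `mapped` must not be touched afterwards
} finally {
  raf.close()
}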
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
deleted file mode 100644
index 22878783fca6..000000000000
--- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
+++ /dev/null
@@ -1,253 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.io.IOException
-import java.nio.ByteBuffer
-import java.text.SimpleDateFormat
-import java.util.{Date, Random}
-
-import scala.util.control.NonFatal
-
-import com.google.common.io.ByteStreams
-
-import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile}
-import tachyon.conf.TachyonConf
-import tachyon.TachyonURI
-
-import org.apache.spark.Logging
-import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.util.{ShutdownHookManager, Utils}
-
-
-/**
- * Creates and maintains the logical mapping between logical blocks and tachyon fs locations. By
- * default, one block is mapped to one file with a name given by its BlockId.
- *
- */
-private[spark] class TachyonBlockManager() extends ExternalBlockManager with Logging {
-
- var rootDirs: String = _
- var master: String = _
- var client: tachyon.client.TachyonFS = _
- private var subDirsPerTachyonDir: Int = _
-
- // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName;
- // then, inside this directory, create multiple subdirectories that we will hash files into,
- // in order to avoid having really large inodes at the top level in Tachyon.
- private var tachyonDirs: Array[TachyonFile] = _
- private var subDirs: Array[Array[tachyon.client.TachyonFile]] = _
-
-
- override def init(blockManager: BlockManager, executorId: String): Unit = {
- super.init(blockManager, executorId)
- val storeDir = blockManager.conf.get(ExternalBlockStore.BASE_DIR, "/tmp_spark_tachyon")
- val appFolderName = blockManager.conf.get(ExternalBlockStore.FOLD_NAME)
-
- rootDirs = s"$storeDir/$appFolderName/$executorId"
- master = blockManager.conf.get(ExternalBlockStore.MASTER_URL, "tachyon://localhost:19998")
- client = if (master != null && master != "") {
- TachyonFS.get(new TachyonURI(master), new TachyonConf())
- } else {
- null
- }
- // original implementation call System.exit, we change it to run without extblkstore support
- if (client == null) {
- logError("Failed to connect to the Tachyon as the master address is not configured")
- throw new IOException("Failed to connect to the Tachyon as the master " +
- "address is not configured")
- }
- subDirsPerTachyonDir = blockManager.conf.get("spark.externalBlockStore.subDirectories",
- ExternalBlockStore.SUB_DIRS_PER_DIR).toInt
-
- // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName;
- // then, inside this directory, create multiple subdirectories that we will hash files into,
- // in order to avoid having really large inodes at the top level in Tachyon.
- tachyonDirs = createTachyonDirs()
- subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir))
- tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir))
- }
-
- override def toString: String = {"ExternalBlockStore-Tachyon"}
-
- override def removeBlock(blockId: BlockId): Boolean = {
- val file = getFile(blockId)
- if (fileExists(file)) {
- removeFile(file)
- } else {
- false
- }
- }
-
- override def blockExists(blockId: BlockId): Boolean = {
- val file = getFile(blockId)
- fileExists(file)
- }
-
- override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = {
- val file = getFile(blockId)
- val os = file.getOutStream(WriteType.TRY_CACHE)
- try {
- os.write(bytes.array())
- } catch {
- case NonFatal(e) =>
- logWarning(s"Failed to put bytes of block $blockId into Tachyon", e)
- os.cancel()
- } finally {
- os.close()
- }
- }
-
- override def putValues(blockId: BlockId, values: Iterator[_]): Unit = {
- val file = getFile(blockId)
- val os = file.getOutStream(WriteType.TRY_CACHE)
- try {
- blockManager.dataSerializeStream(blockId, os, values)
- } catch {
- case NonFatal(e) =>
- logWarning(s"Failed to put values of block $blockId into Tachyon", e)
- os.cancel()
- } finally {
- os.close()
- }
- }
-
- override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
- val file = getFile(blockId)
- if (file == null || file.getLocationHosts.size == 0) {
- return None
- }
- val is = file.getInStream(ReadType.CACHE)
- try {
- val size = file.length
- val bs = new Array[Byte](size.asInstanceOf[Int])
- ByteStreams.readFully(is, bs)
- Some(ByteBuffer.wrap(bs))
- } catch {
- case NonFatal(e) =>
- logWarning(s"Failed to get bytes of block $blockId from Tachyon", e)
- None
- } finally {
- is.close()
- }
- }
-
- override def getValues(blockId: BlockId): Option[Iterator[_]] = {
- val file = getFile(blockId)
- if (file == null || file.getLocationHosts().size() == 0) {
- return None
- }
- val is = file.getInStream(ReadType.CACHE)
- Option(is).map { is =>
- blockManager.dataDeserializeStream(blockId, is)
- }
- }
-
- override def getSize(blockId: BlockId): Long = {
- getFile(blockId.name).length
- }
-
- def removeFile(file: TachyonFile): Boolean = {
- client.delete(new TachyonURI(file.getPath()), false)
- }
-
- def fileExists(file: TachyonFile): Boolean = {
- client.exist(new TachyonURI(file.getPath()))
- }
-
- def getFile(filename: String): TachyonFile = {
- // Figure out which tachyon directory it hashes to, and which subdirectory in that
- val hash = Utils.nonNegativeHash(filename)
- val dirId = hash % tachyonDirs.length
- val subDirId = (hash / tachyonDirs.length) % subDirsPerTachyonDir
-
- // Create the subdirectory if it doesn't already exist
- var subDir = subDirs(dirId)(subDirId)
- if (subDir == null) {
- subDir = subDirs(dirId).synchronized {
- val old = subDirs(dirId)(subDirId)
- if (old != null) {
- old
- } else {
- val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}")
- client.mkdir(path)
- val newDir = client.getFile(path)
- subDirs(dirId)(subDirId) = newDir
- newDir
- }
- }
- }
- val filePath = new TachyonURI(s"$subDir/$filename")
- if(!client.exist(filePath)) {
- client.createFile(filePath)
- }
- val file = client.getFile(filePath)
- file
- }
-
- def getFile(blockId: BlockId): TachyonFile = getFile(blockId.name)
-
- // TODO: Some of the logic here could be consolidated/de-duplicated with that in the DiskStore.
- private def createTachyonDirs(): Array[TachyonFile] = {
- logDebug("Creating tachyon directories at root dirs '" + rootDirs + "'")
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map { rootDir =>
- var foundLocalDir = false
- var tachyonDir: TachyonFile = null
- var tachyonDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < ExternalBlockStore.MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId")
- if (!client.exist(path)) {
- foundLocalDir = client.mkdir(path)
- tachyonDir = client.getFile(path)
- }
- } catch {
- case NonFatal(e) =>
- logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed " + ExternalBlockStore.MAX_DIR_CREATION_ATTEMPTS
- + " attempts to create tachyon dir in " + rootDir)
- System.exit(ExecutorExitCode.EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR)
- }
- logInfo("Created tachyon directory at " + tachyonDir)
- tachyonDir
- }
- }
-
- override def shutdown() {
- logDebug("Shutdown hook called")
- tachyonDirs.foreach { tachyonDir =>
- try {
- if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) {
- Utils.deleteRecursively(tachyonDir, client)
- }
- } catch {
- case NonFatal(e) =>
- logError("Exception while deleting tachyon spark dir: " + tachyonDir, e)
- }
- }
- client.close()
- }
-}
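The getFile logic above hashes each block file into one of `tachyonDirs.length * subDirsPerTachyonDir` buckets so that no single Tachyon directory accumulates a huge number of inodes. A minimal, standalone sketch of just that placement arithmetic (the hash function here is a simple stand-in for Spark's `Utils.nonNegativeHash`, and the directory counts are made-up values):

    object DirPlacementSketch {
      // Stand-in for Utils.nonNegativeHash: clear the sign bit so the hash is never negative.
      private def nonNegativeHash(s: String): Int = s.hashCode & Int.MaxValue

      /** Mirrors the dirId/subDirId computation in getFile above. */
      def place(filename: String, numDirs: Int, subDirsPerDir: Int): (Int, Int) = {
        val hash = nonNegativeHash(filename)
        val dirId = hash % numDirs
        val subDirId = (hash / numDirs) % subDirsPerDir
        (dirId, subDirId)
      }

      def main(args: Array[String]): Unit = {
        // With 4 top-level dirs and 64 subdirs each, every block name maps to one of 256 buckets.
        println(place("rdd_2_3", numDirs = 4, subDirsPerDir = 64))
      }
    }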
diff --git a/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala
new file mode 100644
index 000000000000..a150a8e3636e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.SparkConf
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * ::DeveloperApi::
+ * TopologyMapper provides topology information for a given host
+ * @param conf SparkConf to get required properties, if needed
+ */
+@DeveloperApi
+abstract class TopologyMapper(conf: SparkConf) {
+ /**
+ * Gets the topology information given the host name
+ *
+ * @param hostname Hostname
+ * @return topology information for the given hostname. One can use a 'topology delimiter'
+ * to make this topology information nested.
+ * For example : ‘/myrack/myhost’, where ‘/’ is the topology delimiter,
+ * ‘myrack’ is the topology identifier, and ‘myhost’ is the individual host.
+ * This function only returns the topology information without the hostname.
+ * This information can be used when choosing executors for block replication
+ * to discern executors from a different rack than a candidate executor, for example.
+ *
+ * An implementation can choose to use empty strings or None in case topology info
+ * is not available. This would imply that all such executors belong to the same rack.
+ */
+ def getTopologyForHost(hostname: String): Option[String]
+}
+
+/**
+ * A TopologyMapper that assumes all nodes are in the same rack
+ */
+@DeveloperApi
+class DefaultTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging {
+ override def getTopologyForHost(hostname: String): Option[String] = {
+ logDebug(s"Got a request for $hostname")
+ None
+ }
+}
+
+/**
+ * A simple file based topology mapper. This expects topology information provided as a
+ * `java.util.Properties` file. The name of the file is obtained from SparkConf property
+ * `spark.storage.replication.topologyFile`. To use this topology mapper, set the
+ * `spark.storage.replication.topologyMapper` property to
+ * [[org.apache.spark.storage.FileBasedTopologyMapper]]
+ * @param conf SparkConf object
+ */
+@DeveloperApi
+class FileBasedTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging {
+ val topologyFile = conf.getOption("spark.storage.replication.topologyFile")
+ require(topologyFile.isDefined, "Please specify topology file via " +
+ "spark.storage.replication.topologyFile for FileBasedTopologyMapper.")
+ val topologyMap = Utils.getPropertiesFromFile(topologyFile.get)
+
+ override def getTopologyForHost(hostname: String): Option[String] = {
+ val topology = topologyMap.get(hostname)
+ if (topology.isDefined) {
+ logDebug(s"$hostname -> ${topology.get}")
+ } else {
+ logWarning(s"$hostname does not have any topology information")
+ }
+ topology
+ }
+}
+
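As a usage sketch for the new mapper (the file path, hostnames and rack names below are made-up), `FileBasedTopologyMapper` reads a plain `java.util.Properties` file that maps each host to its topology string, for example in a spark-shell session:

    import org.apache.spark.SparkConf
    import org.apache.spark.storage.FileBasedTopologyMapper

    // Hypothetical /tmp/topology.properties:
    //   host-1=/rack-1
    //   host-2=/rack-2
    val conf = new SparkConf()
      .set("spark.storage.replication.topologyMapper",
        "org.apache.spark.storage.FileBasedTopologyMapper")
      .set("spark.storage.replication.topologyFile", "/tmp/topology.properties")

    val mapper = new FileBasedTopologyMapper(conf)
    mapper.getTopologyForHost("host-1")   // Some("/rack-1")
    mapper.getTopologyForHost("host-42")  // None, plus a warning in the logs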
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
new file mode 100644
index 000000000000..1b8b4db2e45d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -0,0 +1,879 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage.memory
+
+import java.io.OutputStream
+import java.nio.ByteBuffer
+import java.util.LinkedHashMap
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.{MemoryManager, MemoryMode}
+import org.apache.spark.serializer.{SerializationStream, SerializerManager}
+import org.apache.spark.storage._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.{SizeEstimator, Utils}
+import org.apache.spark.util.collection.SizeTrackingVector
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+private sealed trait MemoryEntry[T] {
+ def size: Long
+ def memoryMode: MemoryMode
+ def classTag: ClassTag[T]
+}
+private case class DeserializedMemoryEntry[T](
+ value: Array[T],
+ size: Long,
+ classTag: ClassTag[T]) extends MemoryEntry[T] {
+ val memoryMode: MemoryMode = MemoryMode.ON_HEAP
+}
+private case class SerializedMemoryEntry[T](
+ buffer: ChunkedByteBuffer,
+ memoryMode: MemoryMode,
+ classTag: ClassTag[T]) extends MemoryEntry[T] {
+ def size: Long = buffer.size
+}
+
+private[storage] trait BlockEvictionHandler {
+ /**
+ * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
+ * store reaches its limit and needs to free up space.
+ *
+ * If `data` is not put on disk, it won't be created.
+ *
+ * The caller of this method must hold a write lock on the block before calling this method.
+ * This method does not release the write lock.
+ *
+ * @return the block's new effective StorageLevel.
+ */
+ private[storage] def dropFromMemory[T: ClassTag](
+ blockId: BlockId,
+ data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel
+}
+
+/**
+ * Stores blocks in memory, either as Arrays of deserialized Java objects or as
+ * serialized ByteBuffers.
+ */
+private[spark] class MemoryStore(
+ conf: SparkConf,
+ blockInfoManager: BlockInfoManager,
+ serializerManager: SerializerManager,
+ memoryManager: MemoryManager,
+ blockEvictionHandler: BlockEvictionHandler)
+ extends Logging {
+
+ // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and
+ // acquiring or releasing unroll memory, must be synchronized on `memoryManager`!
+
+ private val entries = new LinkedHashMap[BlockId, MemoryEntry[_]](32, 0.75f, true)
+
+ // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes)
+ // All accesses of this map are assumed to have manually synchronized on `memoryManager`
+ private val onHeapUnrollMemoryMap = mutable.HashMap[Long, Long]()
+ // Note: off-heap unroll memory is only used in putIteratorAsBytes() because off-heap caching
+ // always stores serialized values.
+ private val offHeapUnrollMemoryMap = mutable.HashMap[Long, Long]()
+
+ // Initial memory to request before unrolling any block
+ private val unrollMemoryThreshold: Long =
+ conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024)
+
+ /** Total amount of memory available for storage, in bytes. */
+ private def maxMemory: Long = {
+ memoryManager.maxOnHeapStorageMemory + memoryManager.maxOffHeapStorageMemory
+ }
+
+ if (maxMemory < unrollMemoryThreshold) {
+ logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " +
+ s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " +
+ s"memory. Please configure Spark with more memory.")
+ }
+
+ logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory)))
+
+ /** Total storage memory used including unroll memory, in bytes. */
+ private def memoryUsed: Long = memoryManager.storageMemoryUsed
+
+ /**
+ * Amount of storage memory, in bytes, used for caching blocks.
+ * This does not include memory used for unrolling.
+ */
+ private def blocksMemoryUsed: Long = memoryManager.synchronized {
+ memoryUsed - currentUnrollMemory
+ }
+
+ def getSize(blockId: BlockId): Long = {
+ entries.synchronized {
+ entries.get(blockId).size
+ }
+ }
+
+ /**
+ * Use `size` to test if there is enough space in MemoryStore. If so, create the ByteBuffer and
+ * put it into MemoryStore. Otherwise, the ByteBuffer won't be created.
+ *
+ * The caller should guarantee that `size` is correct.
+ *
+ * @return true if the put() succeeded, false otherwise.
+ */
+ def putBytes[T: ClassTag](
+ blockId: BlockId,
+ size: Long,
+ memoryMode: MemoryMode,
+ _bytes: () => ChunkedByteBuffer): Boolean = {
+ require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
+ if (memoryManager.acquireStorageMemory(blockId, size, memoryMode)) {
+ // We acquired enough memory for the block, so go ahead and put it
+ val bytes = _bytes()
+ assert(bytes.size == size)
+ val entry = new SerializedMemoryEntry[T](bytes, memoryMode, implicitly[ClassTag[T]])
+ entries.synchronized {
+ entries.put(blockId, entry)
+ }
+ logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
+ blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
+ true
+ } else {
+ false
+ }
+ }
+
+ /**
+ * Attempt to put the given block in memory store as values.
+ *
+ * It's possible that the iterator is too large to materialize and store in memory. To avoid
+ * OOM exceptions, this method will gradually unroll the iterator while periodically checking
+ * whether there is enough free memory. If the block is successfully materialized, then the
+ * temporary unroll memory used during the materialization is "transferred" to storage memory,
+ * so we won't acquire more memory than is actually needed to store the block.
+ *
+ * @return in case of success, the estimated size of the stored data. In case of failure, return
+ * an iterator containing the values of the block. The returned iterator will be backed
+ * by the combination of the partially-unrolled block and the remaining elements of the
+ * original input iterator. The caller must either fully consume this iterator or call
+ * `close()` on it in order to free the storage memory consumed by the partially-unrolled
+ * block.
+ */
+ private[storage] def putIteratorAsValues[T](
+ blockId: BlockId,
+ values: Iterator[T],
+ classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {
+
+ require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
+
+ // Number of elements unrolled so far
+ var elementsUnrolled = 0
+ // Whether there is still enough memory for us to continue unrolling this block
+ var keepUnrolling = true
+ // Initial per-task memory to request for unrolling blocks (bytes).
+ val initialMemoryThreshold = unrollMemoryThreshold
+ // How often to check whether we need to request more memory
+ val memoryCheckPeriod = 16
+ // Memory currently reserved by this task for this particular unrolling operation
+ var memoryThreshold = initialMemoryThreshold
+ // Memory to request as a multiple of current vector size
+ val memoryGrowthFactor = 1.5
+ // Keep track of unroll memory used by this particular block / putIterator() operation
+ var unrollMemoryUsedByThisBlock = 0L
+ // Underlying vector for unrolling the block
+ var vector = new SizeTrackingVector[T]()(classTag)
+
+ // Request enough memory to begin unrolling
+ keepUnrolling =
+ reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP)
+
+ if (!keepUnrolling) {
+ logWarning(s"Failed to reserve initial memory threshold of " +
+ s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+ } else {
+ unrollMemoryUsedByThisBlock += initialMemoryThreshold
+ }
+
+ // Unroll this block safely, checking whether we have exceeded our threshold periodically
+ while (values.hasNext && keepUnrolling) {
+ vector += values.next()
+ if (elementsUnrolled % memoryCheckPeriod == 0) {
+ // If our vector's size has exceeded the threshold, request more memory
+ val currentSize = vector.estimateSize()
+ if (currentSize >= memoryThreshold) {
+ val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
+ keepUnrolling =
+ reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP)
+ if (keepUnrolling) {
+ unrollMemoryUsedByThisBlock += amountToRequest
+ }
+ // New threshold is currentSize * memoryGrowthFactor
+ memoryThreshold += amountToRequest
+ }
+ }
+ elementsUnrolled += 1
+ }
+
+ if (keepUnrolling) {
+ // We successfully unrolled the entirety of this block
+ val arrayValues = vector.toArray
+ vector = null
+ val entry =
+ new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag)
+ val size = entry.size
+ def transferUnrollToStorage(amount: Long): Unit = {
+ // Synchronize so that transfer is atomic
+ memoryManager.synchronized {
+ releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount)
+ val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP)
+ assert(success, "transferring unroll memory to storage memory failed")
+ }
+ }
+ // Acquire storage memory if necessary to store this block in memory.
+ val enoughStorageMemory = {
+ if (unrollMemoryUsedByThisBlock <= size) {
+ val acquiredExtra =
+ memoryManager.acquireStorageMemory(
+ blockId, size - unrollMemoryUsedByThisBlock, MemoryMode.ON_HEAP)
+ if (acquiredExtra) {
+ transferUnrollToStorage(unrollMemoryUsedByThisBlock)
+ }
+ acquiredExtra
+ } else { // unrollMemoryUsedByThisBlock > size
+ // If this task attempt already owns more unroll memory than is necessary to store the
+ // block, then release the extra memory that will not be used.
+ val excessUnrollMemory = unrollMemoryUsedByThisBlock - size
+ releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory)
+ transferUnrollToStorage(size)
+ true
+ }
+ }
+ if (enoughStorageMemory) {
+ entries.synchronized {
+ entries.put(blockId, entry)
+ }
+ logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(
+ blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
+ Right(size)
+ } else {
+ assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock,
+ "released too much unroll memory")
+ Left(new PartiallyUnrolledIterator(
+ this,
+ MemoryMode.ON_HEAP,
+ unrollMemoryUsedByThisBlock,
+ unrolled = arrayValues.toIterator,
+ rest = Iterator.empty))
+ }
+ } else {
+ // We ran out of space while unrolling the values for this block
+ logUnrollFailureMessage(blockId, vector.estimateSize())
+ Left(new PartiallyUnrolledIterator(
+ this,
+ MemoryMode.ON_HEAP,
+ unrollMemoryUsedByThisBlock,
+ unrolled = vector.iterator,
+ rest = values))
+ }
+ }
+
+ /**
+ * Attempt to put the given block in memory store as bytes.
+ *
+ * It's possible that the iterator is too large to materialize and store in memory. To avoid
+ * OOM exceptions, this method will gradually unroll the iterator while periodically checking
+ * whether there is enough free memory. If the block is successfully materialized, then the
+ * temporary unroll memory used during the materialization is "transferred" to storage memory,
+ * so we won't acquire more memory than is actually needed to store the block.
+ *
+ * @return in case of success, the estimated size of the stored data. In case of failure,
+ * return a handle which allows the caller to either finish the serialization by
+ * spilling to disk or to deserialize the partially-serialized block and reconstruct
+ * the original input iterator. The caller must either fully consume this result
+ * iterator or call `discard()` on it in order to free the storage memory consumed by the
+ * partially-unrolled block.
+ */
+ private[storage] def putIteratorAsBytes[T](
+ blockId: BlockId,
+ values: Iterator[T],
+ classTag: ClassTag[T],
+ memoryMode: MemoryMode): Either[PartiallySerializedBlock[T], Long] = {
+
+ require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
+
+ val allocator = memoryMode match {
+ case MemoryMode.ON_HEAP => ByteBuffer.allocate _
+ case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _
+ }
+
+ // Whether there is still enough memory for us to continue unrolling this block
+ var keepUnrolling = true
+ // Initial per-task memory to request for unrolling blocks (bytes).
+ val initialMemoryThreshold = unrollMemoryThreshold
+ // Keep track of unroll memory used by this particular block / putIterator() operation
+ var unrollMemoryUsedByThisBlock = 0L
+ // Underlying buffer for unrolling the block
+ val redirectableStream = new RedirectableOutputStream
+ val chunkSize = if (initialMemoryThreshold > Int.MaxValue) {
+ logWarning(s"Initial memory threshold of ${Utils.bytesToString(initialMemoryThreshold)} " +
+ s"is too large to be set as chunk size. Chunk size has been capped to " +
+ s"${Utils.bytesToString(Int.MaxValue)}")
+ Int.MaxValue
+ } else {
+ initialMemoryThreshold.toInt
+ }
+ val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator)
+ redirectableStream.setOutputStream(bbos)
+ val serializationStream: SerializationStream = {
+ val autoPick = !blockId.isInstanceOf[StreamBlockId]
+ val ser = serializerManager.getSerializer(classTag, autoPick).newInstance()
+ ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
+ }
+
+ // Request enough memory to begin unrolling
+ keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode)
+
+ if (!keepUnrolling) {
+ logWarning(s"Failed to reserve initial memory threshold of " +
+ s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+ } else {
+ unrollMemoryUsedByThisBlock += initialMemoryThreshold
+ }
+
+ def reserveAdditionalMemoryIfNecessary(): Unit = {
+ if (bbos.size > unrollMemoryUsedByThisBlock) {
+ val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock
+ keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode)
+ if (keepUnrolling) {
+ unrollMemoryUsedByThisBlock += amountToRequest
+ }
+ }
+ }
+
+ // Unroll this block safely, checking whether we have exceeded our threshold
+ while (values.hasNext && keepUnrolling) {
+ serializationStream.writeObject(values.next())(classTag)
+ reserveAdditionalMemoryIfNecessary()
+ }
+
+ // Make sure that we have enough memory to store the block. By this point, it is possible that
+ // the block's actual memory usage has exceeded the unroll memory by a small amount, so we
+ // perform one final call to attempt to allocate additional memory if necessary.
+ if (keepUnrolling) {
+ serializationStream.close()
+ reserveAdditionalMemoryIfNecessary()
+ }
+
+ if (keepUnrolling) {
+ val entry = SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag)
+ // Synchronize so that transfer is atomic
+ memoryManager.synchronized {
+ releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock)
+ val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode)
+ assert(success, "transferring unroll memory to storage memory failed")
+ }
+ entries.synchronized {
+ entries.put(blockId, entry)
+ }
+ logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
+ blockId, Utils.bytesToString(entry.size),
+ Utils.bytesToString(maxMemory - blocksMemoryUsed)))
+ Right(entry.size)
+ } else {
+ // We ran out of space while unrolling the values for this block
+ logUnrollFailureMessage(blockId, bbos.size)
+ Left(
+ new PartiallySerializedBlock(
+ this,
+ serializerManager,
+ blockId,
+ serializationStream,
+ redirectableStream,
+ unrollMemoryUsedByThisBlock,
+ memoryMode,
+ bbos,
+ values,
+ classTag))
+ }
+ }
+
+ def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
+ val entry = entries.synchronized { entries.get(blockId) }
+ entry match {
+ case null => None
+ case e: DeserializedMemoryEntry[_] =>
+ throw new IllegalArgumentException("should only call getBytes on serialized blocks")
+ case SerializedMemoryEntry(bytes, _, _) => Some(bytes)
+ }
+ }
+
+ def getValues(blockId: BlockId): Option[Iterator[_]] = {
+ val entry = entries.synchronized { entries.get(blockId) }
+ entry match {
+ case null => None
+ case e: SerializedMemoryEntry[_] =>
+ throw new IllegalArgumentException("should only call getValues on deserialized blocks")
+ case DeserializedMemoryEntry(values, _, _) =>
+ val x = Some(values)
+ x.map(_.iterator)
+ }
+ }
+
+ def remove(blockId: BlockId): Boolean = memoryManager.synchronized {
+ val entry = entries.synchronized {
+ entries.remove(blockId)
+ }
+ if (entry != null) {
+ entry match {
+ case SerializedMemoryEntry(buffer, _, _) => buffer.dispose()
+ case _ =>
+ }
+ memoryManager.releaseStorageMemory(entry.size, entry.memoryMode)
+ logDebug(s"Block $blockId of size ${entry.size} dropped " +
+ s"from memory (free ${maxMemory - blocksMemoryUsed})")
+ true
+ } else {
+ false
+ }
+ }
+
+ def clear(): Unit = memoryManager.synchronized {
+ entries.synchronized {
+ entries.clear()
+ }
+ onHeapUnrollMemoryMap.clear()
+ offHeapUnrollMemoryMap.clear()
+ memoryManager.releaseAllStorageMemory()
+ logInfo("MemoryStore cleared")
+ }
+
+ /**
+ * Return the RDD ID that a given block ID is from, or None if it is not an RDD block.
+ */
+ private def getRddId(blockId: BlockId): Option[Int] = {
+ blockId.asRDDId.map(_.rddId)
+ }
+
+ /**
+ * Try to evict blocks to free up a given amount of space to store a particular block.
+ * Can fail if either the block is bigger than our memory or it would require replacing
+ * another block from the same RDD (which leads to a wasteful cyclic replacement pattern
+ * for RDDs that don't fit into memory, which we want to avoid).
+ *
+ * @param blockId the ID of the block we are freeing space for, if any
+ * @param space the size of this block
+ * @param memoryMode the type of memory to free (on- or off-heap)
+ * @return the amount of memory (in bytes) freed by eviction
+ */
+ private[spark] def evictBlocksToFreeSpace(
+ blockId: Option[BlockId],
+ space: Long,
+ memoryMode: MemoryMode): Long = {
+ assert(space > 0)
+ memoryManager.synchronized {
+ var freedMemory = 0L
+ val rddToAdd = blockId.flatMap(getRddId)
+ val selectedBlocks = new ArrayBuffer[BlockId]
+ def blockIsEvictable(blockId: BlockId, entry: MemoryEntry[_]): Boolean = {
+ entry.memoryMode == memoryMode && (rddToAdd.isEmpty || rddToAdd != getRddId(blockId))
+ }
+ // This is synchronized to ensure that the set of entries is not changed
+ // (because of getValues or getBytes) while traversing the iterator, as that
+ // can lead to exceptions.
+ entries.synchronized {
+ val iterator = entries.entrySet().iterator()
+ while (freedMemory < space && iterator.hasNext) {
+ val pair = iterator.next()
+ val blockId = pair.getKey
+ val entry = pair.getValue
+ if (blockIsEvictable(blockId, entry)) {
+ // We don't want to evict blocks which are currently being read, so we need to obtain
+ // an exclusive write lock on blocks which are candidates for eviction. We perform a
+ // non-blocking "tryLock" here in order to ignore blocks which are locked for reading:
+ if (blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) {
+ selectedBlocks += blockId
+ freedMemory += pair.getValue.size
+ }
+ }
+ }
+ }
+
+ def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = {
+ val data = entry match {
+ case DeserializedMemoryEntry(values, _, _) => Left(values)
+ case SerializedMemoryEntry(buffer, _, _) => Right(buffer)
+ }
+ val newEffectiveStorageLevel =
+ blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag)
+ if (newEffectiveStorageLevel.isValid) {
+ // The block is still present in at least one store, so release the lock
+ // but don't delete the block info
+ blockInfoManager.unlock(blockId)
+ } else {
+ // The block isn't present in any store, so delete the block info so that the
+ // block can be stored again
+ blockInfoManager.removeBlock(blockId)
+ }
+ }
+
+ if (freedMemory >= space) {
+ var lastSuccessfulBlock = -1
+ try {
+ logInfo(s"${selectedBlocks.size} blocks selected for dropping " +
+ s"(${Utils.bytesToString(freedMemory)} bytes)")
+ (0 until selectedBlocks.size).foreach { idx =>
+ val blockId = selectedBlocks(idx)
+ val entry = entries.synchronized {
+ entries.get(blockId)
+ }
+ // This should never be null as only one task should be dropping
+ // blocks and removing entries. However the check is still here for
+ // future safety.
+ if (entry != null) {
+ dropBlock(blockId, entry)
+ afterDropAction(blockId)
+ }
+ lastSuccessfulBlock = idx
+ }
+ logInfo(s"After dropping ${selectedBlocks.size} blocks, " +
+ s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}")
+ freedMemory
+ } finally {
+ // like BlockManager.doPut, we use a finally rather than a catch to avoid having to deal
+ // with InterruptedException
+ if (lastSuccessfulBlock != selectedBlocks.size - 1) {
+ // the blocks we didn't process successfully are still locked, so we have to unlock them
+ (lastSuccessfulBlock + 1 until selectedBlocks.size).foreach { idx =>
+ val blockId = selectedBlocks(idx)
+ blockInfoManager.unlock(blockId)
+ }
+ }
+ }
+ } else {
+ blockId.foreach { id =>
+ logInfo(s"Will not store $id")
+ }
+ selectedBlocks.foreach { id =>
+ blockInfoManager.unlock(id)
+ }
+ 0L
+ }
+ }
+ }
+
+ // hook for testing, so we can simulate a race
+ protected def afterDropAction(blockId: BlockId): Unit = {}
+
+ def contains(blockId: BlockId): Boolean = {
+ entries.synchronized { entries.containsKey(blockId) }
+ }
+
+ private def currentTaskAttemptId(): Long = {
+ // In case this is called on the driver, return an invalid task attempt id.
+ Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+ }
+
+ /**
+ * Reserve memory for unrolling the given block for this task.
+ *
+ * @return whether the request is granted.
+ */
+ def reserveUnrollMemoryForThisTask(
+ blockId: BlockId,
+ memory: Long,
+ memoryMode: MemoryMode): Boolean = {
+ memoryManager.synchronized {
+ val success = memoryManager.acquireUnrollMemory(blockId, memory, memoryMode)
+ if (success) {
+ val taskAttemptId = currentTaskAttemptId()
+ val unrollMemoryMap = memoryMode match {
+ case MemoryMode.ON_HEAP => onHeapUnrollMemoryMap
+ case MemoryMode.OFF_HEAP => offHeapUnrollMemoryMap
+ }
+ unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory
+ }
+ success
+ }
+ }
+
+ /**
+ * Release memory used by this task for unrolling blocks.
+ * If the amount is not specified, remove the current task's allocation altogether.
+ */
+ def releaseUnrollMemoryForThisTask(memoryMode: MemoryMode, memory: Long = Long.MaxValue): Unit = {
+ val taskAttemptId = currentTaskAttemptId()
+ memoryManager.synchronized {
+ val unrollMemoryMap = memoryMode match {
+ case MemoryMode.ON_HEAP => onHeapUnrollMemoryMap
+ case MemoryMode.OFF_HEAP => offHeapUnrollMemoryMap
+ }
+ if (unrollMemoryMap.contains(taskAttemptId)) {
+ val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId))
+ if (memoryToRelease > 0) {
+ unrollMemoryMap(taskAttemptId) -= memoryToRelease
+ memoryManager.releaseUnrollMemory(memoryToRelease, memoryMode)
+ }
+ if (unrollMemoryMap(taskAttemptId) == 0) {
+ unrollMemoryMap.remove(taskAttemptId)
+ }
+ }
+ }
+ }
+
+ /**
+ * Return the amount of memory currently occupied for unrolling blocks across all tasks.
+ */
+ def currentUnrollMemory: Long = memoryManager.synchronized {
+ onHeapUnrollMemoryMap.values.sum + offHeapUnrollMemoryMap.values.sum
+ }
+
+ /**
+ * Return the amount of memory currently occupied for unrolling blocks by this task.
+ */
+ def currentUnrollMemoryForThisTask: Long = memoryManager.synchronized {
+ onHeapUnrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) +
+ offHeapUnrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L)
+ }
+
+ /**
+ * Return the number of tasks currently unrolling blocks.
+ */
+ private def numTasksUnrolling: Int = memoryManager.synchronized {
+ (onHeapUnrollMemoryMap.keys ++ offHeapUnrollMemoryMap.keys).toSet.size
+ }
+
+ /**
+ * Log information about current memory usage.
+ */
+ private def logMemoryUsage(): Unit = {
+ logInfo(
+ s"Memory use = ${Utils.bytesToString(blocksMemoryUsed)} (blocks) + " +
+ s"${Utils.bytesToString(currentUnrollMemory)} (scratch space shared across " +
+ s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(memoryUsed)}. " +
+ s"Storage limit = ${Utils.bytesToString(maxMemory)}."
+ )
+ }
+
+ /**
+ * Log a warning for failing to unroll a block.
+ *
+ * @param blockId ID of the block we are trying to unroll.
+ * @param finalVectorSize Final size of the vector before unrolling failed.
+ */
+ private def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = {
+ logWarning(
+ s"Not enough space to cache $blockId in memory! " +
+ s"(computed ${Utils.bytesToString(finalVectorSize)} so far)"
+ )
+ logMemoryUsage()
+ }
+}
+
+/**
+ * The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
+ *
+ * @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryMode the memory mode (on- or off-heap).
+ * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
+ * @param unrolled an iterator for the partially-unrolled values.
+ * @param rest the rest of the original iterator passed to
+ * [[MemoryStore.putIteratorAsValues()]].
+ */
+private[storage] class PartiallyUnrolledIterator[T](
+ memoryStore: MemoryStore,
+ memoryMode: MemoryMode,
+ unrollMemory: Long,
+ private[this] var unrolled: Iterator[T],
+ rest: Iterator[T])
+ extends Iterator[T] {
+
+ private def releaseUnrollMemory(): Unit = {
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ // SPARK-17503: null out `unrolled` so that the unrolled values can be garbage collected
+ // before this PartiallyUnrolledIterator itself reaches the end of its life.
+ unrolled = null
+ }
+
+ override def hasNext: Boolean = {
+ if (unrolled == null) {
+ rest.hasNext
+ } else if (!unrolled.hasNext) {
+ releaseUnrollMemory()
+ rest.hasNext
+ } else {
+ true
+ }
+ }
+
+ override def next(): T = {
+ if (unrolled == null || !unrolled.hasNext) {
+ rest.next()
+ } else {
+ unrolled.next()
+ }
+ }
+
+ /**
+ * Called to dispose of this iterator and free its memory.
+ */
+ def close(): Unit = {
+ if (unrolled != null) {
+ releaseUnrollMemory()
+ }
+ }
+}
+
+/**
+ * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
+ */
+private[storage] class RedirectableOutputStream extends OutputStream {
+ private[this] var os: OutputStream = _
+ def setOutputStream(s: OutputStream): Unit = { os = s }
+ override def write(b: Int): Unit = os.write(b)
+ override def write(b: Array[Byte]): Unit = os.write(b)
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len)
+ override def flush(): Unit = os.flush()
+ override def close(): Unit = os.close()
+}
+
+/**
+ * The result of a failed [[MemoryStore.putIteratorAsBytes()]] call.
+ *
+ * @param memoryStore the MemoryStore, used for freeing memory.
+ * @param serializerManager the SerializerManager, used for deserializing values.
+ * @param blockId the block id.
+ * @param serializationStream a serialization stream which writes to [[redirectableOutputStream]].
+ * @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
+ * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
+ * @param memoryMode whether the unroll memory is on- or off-heap
+ * @param bbos byte buffer output stream containing the partially-serialized values.
+ * [[redirectableOutputStream]] initially points to this output stream.
+ * @param rest the rest of the original iterator passed to
+ * [[MemoryStore.putIteratorAsBytes()]].
+ * @param classTag the [[ClassTag]] for the block.
+ */
+private[storage] class PartiallySerializedBlock[T](
+ memoryStore: MemoryStore,
+ serializerManager: SerializerManager,
+ blockId: BlockId,
+ private val serializationStream: SerializationStream,
+ private val redirectableOutputStream: RedirectableOutputStream,
+ val unrollMemory: Long,
+ memoryMode: MemoryMode,
+ bbos: ChunkedByteBufferOutputStream,
+ rest: Iterator[T],
+ classTag: ClassTag[T]) {
+
+ private lazy val unrolledBuffer: ChunkedByteBuffer = {
+ bbos.close()
+ bbos.toChunkedByteBuffer
+ }
+
+ // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
+ // this PartiallySerializedBlock, then we risk leaking direct buffers, so we use a task
+ // completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
+ // The dispose() method is idempotent, so it's safe to call it unconditionally.
+ Option(TaskContext.get()).foreach { taskContext =>
+ taskContext.addTaskCompletionListener { _ =>
+ // When a task completes, its unroll memory will automatically be freed. Thus we do not call
+ // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
+ unrolledBuffer.dispose()
+ }
+ }
+
+ // Exposed for testing
+ private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer
+
+ private[this] var discarded = false
+ private[this] var consumed = false
+
+ private def verifyNotConsumedAndNotDiscarded(): Unit = {
+ if (consumed) {
+ throw new IllegalStateException(
+ "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
+ }
+ if (discarded) {
+ throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock")
+ }
+ }
+
+ /**
+ * Called to dispose of this block and free its memory.
+ */
+ def discard(): Unit = {
+ if (!discarded) {
+ try {
+ // We want to close the output stream in order to free any resources associated with the
+ // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+ // written, so redirect the output stream to discard that data.
+ redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
+ serializationStream.close()
+ } finally {
+ discarded = true
+ unrolledBuffer.dispose()
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ }
+ }
+ }
+
+ /**
+ * Finish writing this block to the given output stream by first writing the serialized values
+ * and then serializing the values from the original input iterator.
+ */
+ def finishWritingToStream(os: OutputStream): Unit = {
+ verifyNotConsumedAndNotDiscarded()
+ consumed = true
+ // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
+ ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+ redirectableOutputStream.setOutputStream(os)
+ while (rest.hasNext) {
+ serializationStream.writeObject(rest.next())(classTag)
+ }
+ serializationStream.close()
+ }
+
+ /**
+ * Returns an iterator over the values in this block by first deserializing the serialized
+ * values and then consuming the rest of the original input iterator.
+ *
+ * If the caller does not plan to fully consume the resulting iterator then they must call
+ * `close()` on it to free its resources.
+ */
+ def valuesIterator: PartiallyUnrolledIterator[T] = {
+ verifyNotConsumedAndNotDiscarded()
+ consumed = true
+ // Close the serialization stream so that the serializer's internal buffers are freed and any
+ // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream.
+ serializationStream.close()
+ // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
+ val unrolledIter = serializerManager.dataDeserializeStream(
+ blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
+ // The unroll memory will be freed once `unrolledIter` is fully consumed in
+ // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
+ // extra unroll memory will automatically be freed by a `finally` block in `Task`.
+ new PartiallyUnrolledIterator(
+ memoryStore,
+ memoryMode,
+ unrollMemory,
+ unrolled = unrolledIter,
+ rest = rest)
+ }
+}
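To make the contract of `putIteratorAsValues` above concrete (on failure the caller must fully consume the returned iterator or `close()` it), here is a hedged caller-side sketch, written as if from inside the `org.apache.spark.storage` package since the method is `private[storage]`; `memoryStore`, `blockId` and `values` are assumed to be in scope:

    import scala.reflect.ClassTag

    def cacheInMemoryOrIterate[T: ClassTag](
        memoryStore: MemoryStore,
        blockId: BlockId,
        values: Iterator[T]): Iterator[T] = {
      memoryStore.putIteratorAsValues(blockId, values, implicitly[ClassTag[T]]) match {
        case Right(size) =>
          // The whole block fit: its unroll memory was transferred to storage memory,
          // so read the cached values back out of the store.
          memoryStore.getValues(blockId).get.asInstanceOf[Iterator[T]]
        case Left(partiallyUnrolled) =>
          // Not enough memory: this iterator chains the partially-unrolled values with
          // the rest of the input. Consuming it (or calling close()) releases the
          // unroll memory that is still reserved for this task.
          partiallyUnrolled
      }
    }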
diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
index 77c0bc8b5360..3ae80ecfd22e 100644
--- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ui
import java.util.{Timer, TimerTask}
import org.apache.spark._
+import org.apache.spark.internal.Logging
/**
* ConsoleProgressBar shows the progress of stages in the next line of the console. It polls the
@@ -28,23 +29,24 @@ import org.apache.spark._
* of them will be combined together and shown in one line.
*/
private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
- // Carrige return
- val CR = '\r'
+ // Carriage return
+ private val CR = '\r'
// Update period of progress bar, in milliseconds
- val UPDATE_PERIOD = 200L
+ private val updatePeriodMSec =
+ sc.getConf.getTimeAsMs("spark.ui.consoleProgress.update.interval", "200")
// Delay to show up a progress bar, in milliseconds
- val FIRST_DELAY = 500L
+ private val firstDelayMSec = 500L
// The width of terminal
- val TerminalWidth = if (!sys.env.getOrElse("COLUMNS", "").isEmpty) {
+ private val TerminalWidth = if (!sys.env.getOrElse("COLUMNS", "").isEmpty) {
sys.env.get("COLUMNS").get.toInt
} else {
80
}
- var lastFinishTime = 0L
- var lastUpdateTime = 0L
- var lastProgressBar = ""
+ private var lastFinishTime = 0L
+ private var lastUpdateTime = 0L
+ private var lastProgressBar = ""
// Schedule a refresh thread to run periodically
private val timer = new Timer("refresh progress", true)
@@ -52,19 +54,19 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
override def run() {
refresh()
}
- }, FIRST_DELAY, UPDATE_PERIOD)
+ }, firstDelayMSec, updatePeriodMSec)
/**
* Try to refresh the progress bar in every cycle
*/
private def refresh(): Unit = synchronized {
val now = System.currentTimeMillis()
- if (now - lastFinishTime < FIRST_DELAY) {
+ if (now - lastFinishTime < firstDelayMSec) {
return
}
val stageIds = sc.statusTracker.getActiveStageIds()
- val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1)
- .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId())
+ val stages = stageIds.flatMap(sc.statusTracker.getStageInfo).filter(_.numTasks() > 1)
+ .filter(now - _.submissionTime() > firstDelayMSec).sortBy(_.stageId())
if (stages.length > 0) {
show(now, stages.take(3)) // display at most 3 stages in same time
}
@@ -93,7 +95,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
header + bar + tailer
}.mkString("")
- // only refresh if it's changed of after 1 minute (or the ssh connection will be closed
+ // only refresh if it's changed OR after 1 minute (or the ssh connection will be closed
// after idle some time)
if (bar != lastProgressBar || now - lastUpdateTime > 60 * 1000L) {
System.err.print(CR + bar)
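With this change the refresh period is no longer the hard-coded 200 ms; it is read from `spark.ui.consoleProgress.update.interval` via `getTimeAsMs`, so the usual time suffixes apply. A quick configuration sketch (the one-second value is just an example):

    import org.apache.spark.SparkConf

    // Refresh the console progress bar once per second instead of every 200 ms.
    val conf = new SparkConf()
      .set("spark.ui.consoleProgress.update.interval", "1s")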
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index b796a44fe01a..3e0b62dc8aba 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -17,21 +17,29 @@
package org.apache.spark.ui
-import java.net.{InetSocketAddress, URL}
+import java.net.{URI, URL}
import javax.servlet.DispatcherType
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
+import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.xml.Node
-import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.client.api.Response
+import org.eclipse.jetty.client.HttpClient
+import org.eclipse.jetty.client.http.HttpClientTransportOverHTTP
+import org.eclipse.jetty.proxy.ProxyServlet
+import org.eclipse.jetty.server._
import org.eclipse.jetty.server.handler._
+import org.eclipse.jetty.server.handler.gzip.GzipHandler
import org.eclipse.jetty.servlet._
-import org.eclipse.jetty.util.thread.QueuedThreadPool
+import org.eclipse.jetty.util.component.LifeCycle
+import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler}
import org.json4s.JValue
import org.json4s.jackson.JsonMethods.{pretty, render}
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, SSLOptions}
+import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
/**
@@ -39,6 +47,9 @@ import org.apache.spark.util.Utils
*/
private[spark] object JettyUtils extends Logging {
+ val SPARK_CONNECTOR_NAME = "Spark"
+ val REDIRECT_CONNECTOR_NAME = "HttpsRedirect"
+
// Base type for a function that returns something based on an HTTP request. Allows for
// implicit conversion from many types of functions to jetty Handlers.
type Responder[T] = HttpServletRequest => T
@@ -79,13 +90,11 @@ private[spark] object JettyUtils extends Logging {
val result = servletParams.responder(request)
response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
response.setHeader("X-Frame-Options", xFrameOptionsValue)
- // scalastyle:off println
- response.getWriter.println(servletParams.extractFn(result))
- // scalastyle:on println
+ response.getWriter.print(servletParams.extractFn(result))
} else {
- response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
+ response.setStatus(HttpServletResponse.SC_FORBIDDEN)
response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- response.sendError(HttpServletResponse.SC_UNAUTHORIZED,
+ response.sendError(HttpServletResponse.SC_FORBIDDEN,
"User is not authorized to access this page.")
}
} catch {
@@ -184,13 +193,64 @@ private[spark] object JettyUtils extends Logging {
contextHandler
}
+ /** Create a handler for proxying requests to Workers and Application Drivers */
+ def createProxyHandler(
+ prefix: String,
+ target: String): ServletContextHandler = {
+ val servlet = new ProxyServlet {
+ override def rewriteTarget(request: HttpServletRequest): String = {
+ val rewrittenURI = createProxyURI(
+ prefix, target, request.getRequestURI(), request.getQueryString())
+ if (rewrittenURI == null) {
+ return null
+ }
+ if (!validateDestination(rewrittenURI.getHost(), rewrittenURI.getPort())) {
+ return null
+ }
+ rewrittenURI.toString()
+ }
+
+ override def newHttpClient(): HttpClient = {
+ // SPARK-21176: Use the Jetty logic to calculate the number of selector threads (#CPUs/2),
+ // but limit it to 8 max.
+ // Otherwise, it might happen that we exhaust the threadpool since in reverse proxy mode
+ // a proxy is instantiated for each executor. If the head node has many processors, this
+ // can quickly add up to an unreasonably high number of threads.
+ val numSelectors = math.max(1, math.min(8, Runtime.getRuntime().availableProcessors() / 2))
+ new HttpClient(new HttpClientTransportOverHTTP(numSelectors), null)
+ }
+
+ override def filterServerResponseHeader(
+ clientRequest: HttpServletRequest,
+ serverResponse: Response,
+ headerName: String,
+ headerValue: String): String = {
+ if (headerName.equalsIgnoreCase("location")) {
+ val newHeader = createProxyLocationHeader(
+ prefix, headerValue, clientRequest, serverResponse.getRequest().getURI())
+ if (newHeader != null) {
+ return newHeader
+ }
+ }
+ super.filterServerResponseHeader(
+ clientRequest, serverResponse, headerName, headerValue)
+ }
+ }
+
+ val contextHandler = new ServletContextHandler
+ val holder = new ServletHolder(servlet)
+ contextHandler.setContextPath(prefix)
+ contextHandler.addServlet(holder, "/")
+ contextHandler
+ }
+
/** Add filters, if any, to the given list of ServletContextHandlers */
def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) {
val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim())
filters.foreach {
case filter : String =>
if (!filter.isEmpty) {
- logInfo("Adding filter: " + filter)
+ logInfo(s"Adding filter $filter to ${handlers.map(_.getContextPath).mkString(", ")}.")
val holder : FilterHolder = new FilterHolder()
holder.setClassName(filter)
// Get any parameters for each filter
@@ -224,47 +284,240 @@ private[spark] object JettyUtils extends Logging {
def startJettyServer(
hostName: String,
port: Int,
+ sslOptions: SSLOptions,
handlers: Seq[ServletContextHandler],
conf: SparkConf,
serverName: String = ""): ServerInfo = {
addFilters(handlers, conf)
- val collection = new ContextHandlerCollection
- val gzipHandlers = handlers.map { h =>
- val gzipHandler = new GzipHandler
- gzipHandler.setHandler(h)
- gzipHandler
+ // Start the server first, with no connectors.
+ val pool = new QueuedThreadPool
+ if (serverName.nonEmpty) {
+ pool.setName(serverName)
}
- collection.setHandlers(gzipHandlers.toArray)
-
- // Bind to the given port, or throw a java.net.BindException if the port is occupied
- def connect(currentPort: Int): (Server, Int) = {
- val server = new Server(new InetSocketAddress(hostName, currentPort))
- val pool = new QueuedThreadPool
- pool.setDaemon(true)
- server.setThreadPool(pool)
- val errorHandler = new ErrorHandler()
- errorHandler.setShowStacks(true)
- server.addBean(errorHandler)
- server.setHandler(collection)
- try {
- server.start()
- (server, server.getConnectors.head.getLocalPort)
- } catch {
- case e: Exception =>
- server.stop()
+ pool.setDaemon(true)
+
+ val server = new Server(pool)
+
+ val errorHandler = new ErrorHandler()
+ errorHandler.setShowStacks(true)
+ errorHandler.setServer(server)
+ server.addBean(errorHandler)
+
+ val collection = new ContextHandlerCollection
+ server.setHandler(collection)
+
+ // Executor used to create daemon threads for the Jetty connectors.
+ val serverExecutor = new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true)
+
+ try {
+ server.start()
+
+ // As each acceptor and each selector will use one thread, the number of threads should at
+ // least be the number of acceptors and selectors plus 1. (See SPARK-13776)
+ var minThreads = 1
+
+ def newConnector(
+ connectionFactories: Array[ConnectionFactory],
+ port: Int): (ServerConnector, Int) = {
+ val connector = new ServerConnector(
+ server,
+ null,
+ serverExecutor,
+ null,
+ -1,
+ -1,
+ connectionFactories: _*)
+ connector.setPort(port)
+ connector.setHost(hostName)
+
+ // Currently we only use "SelectChannelConnector"
+ // Limit the max acceptor number to 8 so that we don't waste a lot of threads
+ connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8))
+
+ connector.start()
+ // The number of selectors always equals the number of acceptors
+ minThreads += connector.getAcceptors * 2
+
+ (connector, connector.getLocalPort())
+ }
+
+ // If SSL is configured, create the secure connector first.
+ val securePort = sslOptions.createJettySslContextFactory().map { factory =>
+ val securePort = sslOptions.port.getOrElse(if (port > 0) Utils.userPort(port, 400) else 0)
+ val secureServerName = if (serverName.nonEmpty) s"$serverName (HTTPS)" else serverName
+ val connectionFactories = AbstractConnectionFactory.getFactories(factory,
+ new HttpConnectionFactory())
+
+ def sslConnect(currentPort: Int): (ServerConnector, Int) = {
+ newConnector(connectionFactories, currentPort)
+ }
+
+ val (connector, boundPort) = Utils.startServiceOnPort[ServerConnector](securePort,
+ sslConnect, conf, secureServerName)
+ connector.setName(SPARK_CONNECTOR_NAME)
+ server.addConnector(connector)
+ boundPort
+ }
+
+ // Bind the HTTP port.
+ def httpConnect(currentPort: Int): (ServerConnector, Int) = {
+ newConnector(Array(new HttpConnectionFactory()), currentPort)
+ }
+
+ val (httpConnector, httpPort) = Utils.startServiceOnPort[ServerConnector](port, httpConnect,
+ conf, serverName)
+
+ // If SSL is configured, then configure redirection in the HTTP connector.
+ securePort match {
+ case Some(p) =>
+ httpConnector.setName(REDIRECT_CONNECTOR_NAME)
+ val redirector = createRedirectHttpsHandler(p, "https")
+ collection.addHandler(redirector)
+ redirector.start()
+
+ case None =>
+ httpConnector.setName(SPARK_CONNECTOR_NAME)
+ }
+
+ server.addConnector(httpConnector)
+
+ // Add all the known handlers now that connectors are configured.
+ handlers.foreach { h =>
+ h.setVirtualHosts(toVirtualHosts(SPARK_CONNECTOR_NAME))
+ val gzipHandler = new GzipHandler()
+ gzipHandler.setHandler(h)
+ collection.addHandler(gzipHandler)
+ gzipHandler.start()
+ }
+
+ pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))
+ ServerInfo(server, httpPort, securePort, conf, collection)
+ } catch {
+ case e: Exception =>
+ server.stop()
+ if (serverExecutor.isStarted()) {
+ serverExecutor.stop()
+ }
+ if (pool.isStarted()) {
pool.stop()
- throw e
+ }
+ throw e
+ }
+ }
+
+ private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = {
+ val redirectHandler: ContextHandler = new ContextHandler
+ redirectHandler.setContextPath("/")
+ redirectHandler.setVirtualHosts(toVirtualHosts(REDIRECT_CONNECTOR_NAME))
+ redirectHandler.setHandler(new AbstractHandler {
+ override def handle(
+ target: String,
+ baseRequest: Request,
+ request: HttpServletRequest,
+ response: HttpServletResponse): Unit = {
+ if (baseRequest.isSecure) {
+ return
+ }
+ val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, securePort,
+ baseRequest.getRequestURI, baseRequest.getQueryString)
+ response.setContentLength(0)
+ response.sendRedirect(response.encodeRedirectURL(httpsURI))
+ baseRequest.setHandled(true)
+ }
+ })
+ redirectHandler
+ }
+
+ def createProxyURI(prefix: String, target: String, path: String, query: String): URI = {
+ if (!path.startsWith(prefix)) {
+ return null
+ }
+
+ val uri = new StringBuilder(target)
+ val rest = path.substring(prefix.length())
+
+ if (!rest.isEmpty()) {
+ if (!rest.startsWith("/")) {
+ uri.append("/")
}
+ uri.append(rest)
}
- val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
- ServerInfo(server, boundPort, collection)
+ val rewrittenURI = URI.create(uri.toString())
+ if (query != null) {
+ return new URI(
+ rewrittenURI.getScheme(),
+ rewrittenURI.getAuthority(),
+ rewrittenURI.getPath(),
+ query,
+ rewrittenURI.getFragment()
+ ).normalize()
+ }
+ rewrittenURI.normalize()
}
+
+ def createProxyLocationHeader(
+ prefix: String,
+ headerValue: String,
+ clientRequest: HttpServletRequest,
+ targetUri: URI): String = {
+ val toReplace = targetUri.getScheme() + "://" + targetUri.getAuthority()
+ if (headerValue.startsWith(toReplace)) {
+ clientRequest.getScheme() + "://" + clientRequest.getHeader("host") +
+ prefix + headerValue.substring(toReplace.length())
+ } else {
+ null
+ }
+ }
+
+ // Create a new URI from the arguments, handling IPv6 host encoding and default ports.
+ private def createRedirectURI(
+ scheme: String, server: String, port: Int, path: String, query: String) = {
+ val redirectServer = if (server.contains(":") && !server.startsWith("[")) {
+ s"[${server}]"
+ } else {
+ server
+ }
+ val authority = s"$redirectServer:$port"
+ new URI(scheme, authority, path, query, null).toString
+ }
+
+ def toVirtualHosts(connectors: String*): Array[String] = connectors.map("@" + _).toArray
+
}
private[spark] case class ServerInfo(
server: Server,
boundPort: Int,
- rootHandler: ContextHandlerCollection)
+ securePort: Option[Int],
+ conf: SparkConf,
+ private val rootHandler: ContextHandlerCollection) {
+
+ def addHandler(handler: ServletContextHandler): Unit = {
+ handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME))
+ JettyUtils.addFilters(Seq(handler), conf)
+ rootHandler.addHandler(handler)
+ if (!handler.isStarted()) {
+ handler.start()
+ }
+ }
+
+ def removeHandler(handler: ContextHandler): Unit = {
+ rootHandler.removeHandler(handler)
+ if (handler.isStarted) {
+ handler.stop()
+ }
+ }
+
+ def stop(): Unit = {
+ server.stop()
+ // Stop the ThreadPool if it supports the stop() method (through LifeCycle).
+ // This is needed because stopping the Server won't stop the ThreadPool it uses.
+ val threadPool = server.getThreadPool
+ if (threadPool != null && threadPool.isInstanceOf[LifeCycle]) {
+ threadPool.asInstanceOf[LifeCycle].stop
+ }
+ }
+}
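
The stop() above works around Jetty not stopping a ThreadPool that was handed to the Server. A minimal standalone sketch of that shutdown pattern, assuming Jetty 9.x on the classpath; the object and method names below are illustrative and not part of the patch:

import org.eclipse.jetty.server.Server
import org.eclipse.jetty.util.component.LifeCycle
import org.eclipse.jetty.util.thread.QueuedThreadPool

object JettyShutdownSketch {
  def stopServerAndPool(server: Server): Unit = {
    // Stopping the Server does not stop the ThreadPool that was passed to it,
    // so stop the pool explicitly when it is LifeCycle-managed.
    server.stop()
    server.getThreadPool match {
      case lifeCycle: LifeCycle => lifeCycle.stop()
      case _ => // nothing to do for pools that are not LifeCycle-managed
    }
  }

  def main(args: Array[String]): Unit = {
    val pool = new QueuedThreadPool(8)
    val server = new Server(pool)
    server.start()
    stopServerAndPool(server)
  }
}
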
diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
index 6e2375477a68..79974df2603f 100644
--- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
@@ -17,8 +17,15 @@
package org.apache.spark.ui
+import java.net.URLDecoder
+
+import scala.collection.JavaConverters._
import scala.xml.{Node, Unparsed}
+import com.google.common.base.Splitter
+
+import org.apache.spark.util.Utils
+
/**
* A data source that provides data for a page.
*
@@ -71,6 +78,12 @@ private[ui] trait PagedTable[T] {
def tableCssClass: String
+ def pageSizeFormField: String
+
+ def prevPageSizeFormField: String
+
+ def pageNumberFormField: String
+
def dataSource: PagedDataSource[T]
def headers: Seq[Node]
@@ -95,7 +108,12 @@ private[ui] trait PagedTable[T] {
val PageData(totalPages, _) = _dataSource.pageData(1)
{pageNavigation(1, _dataSource.pageSize, totalPages)}
- {e.getMessage}
+
+ Error while rendering table:
+
+ {Utils.exceptionString(e)}
+
+
}
}
@@ -151,36 +169,58 @@ private[ui] trait PagedTable[T] {
// The current page should be disabled so that it cannot be clicked.
{p}
} else {
- {p}
+ {p}
+ }
+ }
+
+ val hiddenFormFields = {
+ if (goButtonFormPath.contains('?')) {
+ val queryString = goButtonFormPath.split("\\?", 2)(1)
+ val search = queryString.split("#")(0)
+ Splitter
+ .on('&')
+ .trimResults()
+ .omitEmptyStrings()
+ .withKeyValueSeparator("=")
+ .split(search)
+ .asScala
+ .filterKeys(_ != pageSizeFormField)
+ .filterKeys(_ != prevPageSizeFormField)
+ .filterKeys(_ != pageNumberFormField)
+ .mapValues(URLDecoder.decode(_, "UTF-8"))
+ .map { case (k, v) =>
+
+ }
+ } else {
+ Seq.empty
}
}
- val (goButtonJsFuncName, goButtonJsFunc) = goButtonJavascriptFunction
- // When clicking the "Go" button, it will call this javascript method and then call
- // "goButtonJsFuncName"
- val formJs =
- s"""$$(function(){
- | $$( "#form-$tableId-page" ).submit(function(event) {
- | var page = $$("#form-$tableId-page-no").val()
- | var pageSize = $$("#form-$tableId-page-size").val()
- | pageSize = pageSize ? pageSize: 100;
- | if (page != "") {
- | ${goButtonJsFuncName}(page, pageSize);
- | }
- | event.preventDefault();
- | });
- |});
- """.stripMargin
@@ -189,7 +229,7 @@ private[ui] trait PagedTable[T] {
{if (currentGroup > firstGroup) {
-
-
+
@@ -198,7 +238,7 @@ private[ui] trait PagedTable[T] {
}}
{if (page > 1) {
-
-
+
@@ -208,14 +248,14 @@ private[ui] trait PagedTable[T] {
{pageTags}
{if (page < totalPages) {
-
-
+
}}
{if (currentGroup < lastGroup) {
-
-
+
@@ -224,11 +264,6 @@ private[ui] trait PagedTable[T] {
}}
-
}
}
@@ -239,10 +274,7 @@ private[ui] trait PagedTable[T] {
def pageLink(page: Int): String
/**
- * Only the implementation knows how to create the url with a page number and the page size, so we
- * leave this one to the implementation. The implementation should create a JavaScript method that
- * accepts a page number along with the page size and jumps to the page. The return value is this
- * method name and its JavaScript codes.
+ * Returns the submission path for the "go to page #" form.
*/
- def goButtonJavascriptFunction: (String, String)
+ def goButtonFormPath: String
}
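
The hiddenFormFields block added to PagedTable splits the form path's query string with Guava's Splitter and drops the paging control fields before rendering hidden inputs. A small self-contained sketch of the same parsing; the method and parameter names (hiddenFields, controlFields) are illustrative, not part of the patch:

import java.net.URLDecoder
import scala.collection.JavaConverters._
import com.google.common.base.Splitter

object QueryStringSketch {
  // Parse "a=1&b=2" into a Map, dropping the paging control fields, much like
  // the PagedTable change does before emitting hidden form fields.
  def hiddenFields(formPath: String, controlFields: Set[String]): Map[String, String] = {
    if (!formPath.contains('?')) return Map.empty
    val queryString = formPath.split("\\?", 2)(1)
    val search = queryString.split("#")(0)
    Splitter
      .on('&')
      .trimResults()
      .omitEmptyStrings()
      .withKeyValueSeparator("=")
      .split(search)
      .asScala
      .toMap
      .filterKeys(k => !controlFields.contains(k))
      .mapValues(URLDecoder.decode(_, "UTF-8"))
      .toMap
  }

  def main(args: Array[String]): Unit = {
    val fields = hiddenFields(
      "/jobs/?completedJob.sort=Duration&completedJob.page=2#completed",
      controlFields = Set("completedJob.page", "completedJob.pageSize", "completedJob.prevPageSize"))
    println(fields) // Map(completedJob.sort -> Duration)
  }
}
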
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 99085ada9f0a..bf4cf79e9faa 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -17,19 +17,23 @@
package org.apache.spark.ui
-import java.util.Date
+import java.util.{Date, ServiceLoader}
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler._
import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo,
UIRoot}
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
-import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.JettyUtils._
import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab}
import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab}
-import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab}
-import org.apache.spark.ui.storage.{StorageListener, StorageTab}
+import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab}
import org.apache.spark.ui.scope.RDDOperationGraphListener
+import org.apache.spark.ui.storage.{StorageListener, StorageTab}
+import org.apache.spark.util.Utils
/**
* Top level user interface for a Spark application.
@@ -47,20 +51,22 @@ private[spark] class SparkUI private (
var appName: String,
val basePath: String,
val startTime: Long)
- extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI")
+ extends WebUI(securityManager, securityManager.getSSLOptions("ui"), SparkUI.getUIPort(conf),
+ conf, basePath, "SparkUI")
with Logging
with UIRoot {
val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false)
-
- val stagesTab = new StagesTab(this)
-
var appId: String = _
+ private var streamingJobProgressListener: Option[SparkListener] = None
+
/** Initialize all components of the server. */
def initialize() {
- attachTab(new JobsTab(this))
+ val jobsTab = new JobsTab(this)
+ attachTab(jobsTab)
+ val stagesTab = new StagesTab(this)
attachTab(stagesTab)
attachTab(new StorageTab(this))
attachTab(new EnvironmentTab(this))
@@ -68,13 +74,19 @@ private[spark] class SparkUI private (
attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static"))
attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath))
attachHandler(ApiRootResource.getServletHandler(this))
- // This should be POST only, but, the YARN AM proxy won't proxy POSTs
+ // These should be POST only, but the YARN AM proxy won't proxy POSTs
+ attachHandler(createRedirectHandler(
+ "/jobs/job/kill", "/jobs/", jobsTab.handleKillRequest, httpMethods = Set("GET", "POST")))
attachHandler(createRedirectHandler(
"/stages/stage/kill", "/stages/", stagesTab.handleKillRequest,
httpMethods = Set("GET", "POST")))
}
initialize()
+ def getSparkUser: String = {
+ environmentListener.systemProperties.toMap.getOrElse("user.name", "")
+ }
+
def getAppName: String = appName
def setAppId(id: String): Unit = {
@@ -84,16 +96,9 @@ private[spark] class SparkUI private (
/** Stop the server behind this web interface. Only valid after bind(). */
override def stop() {
super.stop()
- logInfo("Stopped Spark web UI at %s".format(appUIAddress))
+ logInfo(s"Stopped Spark web UI at $webUrl")
}
- /**
- * Return the application UI host:port. This does not include the scheme (http://).
- */
- private[spark] def appUIHostPort = publicHostName + ":" + boundPort
-
- private[spark] def appUIAddress = s"http://$appUIHostPort"
-
def getSparkUI(appId: String): Option[SparkUI] = {
if (appId == this.appId) Some(this) else None
}
@@ -102,21 +107,37 @@ private[spark] class SparkUI private (
Iterator(new ApplicationInfo(
id = appId,
name = appName,
+ coresGranted = None,
+ maxCores = None,
+ coresPerExecutor = None,
+ memoryPerExecutorMB = None,
attempts = Seq(new ApplicationAttemptInfo(
attemptId = None,
startTime = new Date(startTime),
endTime = new Date(-1),
- sparkUser = "",
+ duration = 0,
+ lastUpdated = new Date(startTime),
+ sparkUser = getSparkUser,
completed = false
))
))
}
+
+ def getApplicationInfo(appId: String): Option[ApplicationInfo] = {
+ getApplicationInfoList.find(_.id == appId)
+ }
+
+ def getStreamingJobProgressListener: Option[SparkListener] = streamingJobProgressListener
+
+ def setStreamingJobProgressListener(sparkListener: SparkListener): Unit = {
+ streamingJobProgressListener = Option(sparkListener)
+ }
}
private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String)
extends WebUITab(parent, prefix) {
- def appName: String = parent.getAppName
+ def appName: String = parent.appName
}
@@ -150,7 +171,16 @@ private[spark] object SparkUI {
appName: String,
basePath: String,
startTime: Long): SparkUI = {
- create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime)
+ val sparkUI = create(
+ None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime)
+
+ val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory],
+ Utils.getContextOrSparkClassLoader).asScala
+ listenerFactories.foreach { listenerFactory =>
+ val listeners = listenerFactory.createListeners(conf, sparkUI)
+ listeners.foreach(listenerBus.addListener)
+ }
+ sparkUI
}
/**
@@ -177,8 +207,8 @@ private[spark] object SparkUI {
}
val environmentListener = new EnvironmentListener
- val storageStatusListener = new StorageStatusListener
- val executorsListener = new ExecutorsListener(storageStatusListener)
+ val storageStatusListener = new StorageStatusListener(conf)
+ val executorsListener = new ExecutorsListener(storageStatusListener, conf)
val storageListener = new StorageListener(storageStatusListener)
val operationGraphListener = new RDDOperationGraphListener(conf)
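
The SparkUI change above discovers extra listeners through java.util.ServiceLoader, so implementations of SparkHistoryListenerFactory can be registered via META-INF/services and attached to the listener bus. A hedged sketch of the general ServiceLoader pattern, using an illustrative trait in place of SparkHistoryListenerFactory:

import java.util.ServiceLoader
import scala.collection.JavaConverters._

// Illustrative factory trait; the patch itself uses SparkHistoryListenerFactory.
trait ListenerFactorySketch {
  def createListenerNames(): Seq[String]
}

object ServiceLoaderSketch {
  def main(args: Array[String]): Unit = {
    // Implementations are discovered from META-INF/services/<fully-qualified trait name>
    // on the classpath; with no registrations this simply yields an empty iterator.
    val factories = ServiceLoader
      .load(classOf[ListenerFactorySketch], Thread.currentThread().getContextClassLoader)
      .asScala

    factories.foreach { factory =>
      factory.createListenerNames().foreach(println)
    }
  }
}
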
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index cb122eaed83d..766cc65084f0 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -87,4 +87,16 @@ private[spark] object ToolTips {
multiple operations (e.g. two map() functions) if they can be pipelined. Some operations
also create multiple RDDs internally. Cached RDDs are shown in green.
"""
+
+ val TASK_TIME =
+ "Shaded red when garbage collection (GC) time is over 10% of task time"
+
+ val BLACKLISTED =
+ "Shows if this executor has been blacklisted by the scheduler due to task failures."
+
+ val APPLICATION_EXECUTOR_LIMIT =
+ """Maximum number of executors that this application will use. This limit is finite only when
+ dynamic allocation is enabled. The number of granted executors may exceed the limit
+ ephemerally when executors are being killed.
+ """
}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 25dcb604d9e5..4bc7fb6185e6 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -17,14 +17,17 @@
package org.apache.spark.ui
+import java.net.URLDecoder
import java.text.SimpleDateFormat
-import java.util.{Date, Locale}
+import java.util.{Date, Locale, TimeZone}
import scala.util.control.NonFatal
import scala.xml._
import scala.xml.transform.{RewriteRule, RuleTransformer}
-import org.apache.spark.Logging
+import org.apache.commons.lang3.StringEscapeUtils
+
+import org.apache.spark.internal.Logging
import org.apache.spark.ui.scope.RDDOperationGraph
/** Utility functions for generating XML pages with spark content. */
@@ -33,9 +36,12 @@ private[spark] object UIUtils extends Logging {
val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable"
+ private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r
+
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
- override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ override def initialValue(): SimpleDateFormat =
+ new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US)
}
def formatDate(date: Date): String = dateFormat.get.format(date)
@@ -167,6 +173,9 @@ private[spark] object UIUtils extends Logging {
+
+
+
}
def vizHeaderNodes: Seq[Node] = {
@@ -177,6 +186,20 @@ private[spark] object UIUtils extends Logging {
}
+ def dataTablesHeaderNodes: Seq[Node] = {
+
+
+
+
+
+
+
+
+
+ }
+
/** Returns a spark page with correctly formatted headers */
def headerSparkPage(
title: String,
@@ -184,7 +207,8 @@ private[spark] object UIUtils extends Logging {
activeTab: SparkUITab,
refreshInterval: Option[Int] = None,
helpText: Option[String] = None,
- showVisualization: Boolean = false): Seq[Node] = {
+ showVisualization: Boolean = false,
+ useDataTables: Boolean = false): Seq[Node] = {
val appName = activeTab.appName
val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..."
@@ -199,6 +223,7 @@ private[spark] object UIUtils extends Logging {
{commonHeaderNodes}
{if (showVisualization) vizHeaderNodes else Seq.empty}
+ {if (useDataTables) dataTablesHeaderNodes else Seq.empty}
{appName} - {title}
@@ -210,10 +235,10 @@ private[spark] object UIUtils extends Logging {
{org.apache.spark.SPARK_VERSION}
-
+
@@ -232,10 +257,14 @@ private[spark] object UIUtils extends Logging {
}
/** Returns a page with the spark css/js and a simple format. Used for scheduler UI. */
- def basicSparkPage(content: => Seq[Node], title: String): Seq[Node] = {
+ def basicSparkPage(
+ content: => Seq[Node],
+ title: String,
+ useDataTables: Boolean = false): Seq[Node] = {
{commonHeaderNodes}
+ {if (useDataTables) dataTablesHeaderNodes else Seq.empty}
{title}
@@ -317,15 +346,22 @@ private[spark] object UIUtils extends Logging {
completed: Int,
failed: Int,
skipped: Int,
+ reasonToNumKilled: Map[String, Int],
total: Int): Seq[Node] = {
val completeWidth = "width: %s%%".format((completed.toDouble/total)*100)
- val startWidth = "width: %s%%".format((started.toDouble/total)*100)
+ // started + completed can be > total when there are speculative tasks
+ val boundedStarted = math.min(started, total - completed)
+ val startWidth = "width: %s%%".format((boundedStarted.toDouble/total)*100)
{completed}/{total}
{ if (failed > 0) s"($failed failed)" }
{ if (skipped > 0) s"($skipped skipped)" }
+ { reasonToNumKilled.toSeq.sortBy(-_._2).map {
+ case (reason, count) => s"($count killed: $reason)"
+ }
+ }
@@ -387,23 +423,24 @@ private[spark] object UIUtils extends Logging {
}
- /** Return a script element that automatically expands the DAG visualization on page load. */
- def expandDagVizOnLoad(forJob: Boolean): Seq[Node] = {
-
- }
-
/**
* Returns HTML rendering of a job or stage description. It will try to parse the string as HTML
* and make sure that it only contains anchors with root-relative links. Otherwise,
* the whole string will be rendered as simple escaped text.
*
* Note: In terms of security, only anchor tags with root-relative links are supported. So any
- * attempts to embed links outside Spark UI, or other tags like ++
+ ++
+
+ }
-
-
- {execTable}
-
- ;
- UIUtils.headerSparkPage("Executors (" + execInfo.size + ")", content, parent)
+ UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true)
}
-
- /** Render an HTML row representing an executor */
- private def execRow(info: ExecutorSummary, logsExist: Boolean): Seq[Node] = {
- val maximumMemory = info.maxMemory
- val memoryUsed = info.memoryUsed
- val diskUsed = info.diskUsed
-
- {info.id}
- {info.hostPort}
- {info.rddBlocks}
-
- {Utils.bytesToString(memoryUsed)} /
- {Utils.bytesToString(maximumMemory)}
-
-
- {Utils.bytesToString(diskUsed)}
-
- {info.activeTasks}
- {info.failedTasks}
- {info.completedTasks}
- {info.totalTasks}
-
- {Utils.msDurationToString(info.totalDuration)}
-
-
- {Utils.bytesToString(info.totalInputBytes)}
-
-
- {Utils.bytesToString(info.totalShuffleRead)}
-
-
- {Utils.bytesToString(info.totalShuffleWrite)}
-
- {
- if (logsExist) {
-
- {
- info.executorLogs.map { case (logName, logUrl) =>
-
-
- {logName}
-
-
- }
- }
-
- }
- }
- {
- if (threadDumpEnabled) {
- val encodedId = URLEncoder.encode(info.id, "UTF-8")
-
- Thread Dump
-
- } else {
- Seq.empty
- }
- }
-
- }
-
}
private[spark] object ExecutorsPage {
+ private val ON_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for on heap " +
+ "storage of data like RDD partitions cached in memory."
+ private val OFF_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for off heap " +
+ "storage of data like RDD partitions cached in memory."
+
/** Represent an executor's info as a map given a storage status index */
- def getExecInfo(listener: ExecutorsListener, statusId: Int): ExecutorSummary = {
- val status = listener.storageStatusList(statusId)
+ def getExecInfo(
+ listener: ExecutorsListener,
+ statusId: Int,
+ isActive: Boolean): ExecutorSummary = {
+ val status = if (isActive) {
+ listener.activeStorageStatusList(statusId)
+ } else {
+ listener.deadStorageStatusList(statusId)
+ }
val execId = status.blockManagerId.executorId
val hostPort = status.blockManagerId.hostPort
val rddBlocks = status.numBlocks
val memUsed = status.memUsed
val maxMem = status.maxMem
+ val memoryMetrics = for {
+ onHeapUsed <- status.onHeapMemUsed
+ offHeapUsed <- status.offHeapMemUsed
+ maxOnHeap <- status.maxOnHeapMem
+ maxOffHeap <- status.maxOffHeapMem
+ } yield {
+ new MemoryMetrics(onHeapUsed, offHeapUsed, maxOnHeap, maxOffHeap)
+ }
+
+
val diskUsed = status.diskUsed
- val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0)
- val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
- val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)
- val totalTasks = activeTasks + failedTasks + completedTasks
- val totalDuration = listener.executorToDuration.getOrElse(execId, 0L)
- val totalInputBytes = listener.executorToInputBytes.getOrElse(execId, 0L)
- val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0L)
- val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0L)
- val executorLogs = listener.executorToLogUrls.getOrElse(execId, Map.empty)
+ val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId))
new ExecutorSummary(
execId,
hostPort,
+ isActive,
rddBlocks,
memUsed,
diskUsed,
- activeTasks,
- failedTasks,
- completedTasks,
- totalTasks,
- totalDuration,
- totalInputBytes,
- totalShuffleRead,
- totalShuffleWrite,
+ taskSummary.totalCores,
+ taskSummary.tasksMax,
+ taskSummary.tasksActive,
+ taskSummary.tasksFailed,
+ taskSummary.tasksComplete,
+ taskSummary.tasksActive + taskSummary.tasksFailed + taskSummary.tasksComplete,
+ taskSummary.duration,
+ taskSummary.jvmGCTime,
+ taskSummary.inputBytes,
+ taskSummary.shuffleRead,
+ taskSummary.shuffleWrite,
+ taskSummary.isBlacklisted,
maxMem,
- executorLogs
+ taskSummary.executorLogs,
+ memoryMetrics
)
}
}
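
getExecInfo builds the memoryMetrics value with a for-comprehension over the four optional memory fields, so metrics are produced only when every value has been reported. A small sketch of that Option pattern, with an illustrative case class standing in for MemoryMetrics:

// Illustrative stand-in; the patch uses StorageStatus and MemoryMetrics from Spark.
case class MemoryMetricsSketch(
    usedOnHeap: Long, usedOffHeap: Long, totalOnHeap: Long, totalOffHeap: Long)

object MemoryMetricsForComprehension {
  // The for-comprehension yields Some(metrics) only when every field is present;
  // if any Option is None the whole result is None.
  def build(
      onHeapUsed: Option[Long],
      offHeapUsed: Option[Long],
      maxOnHeap: Option[Long],
      maxOffHeap: Option[Long]): Option[MemoryMetricsSketch] = {
    for {
      on <- onHeapUsed
      off <- offHeapUsed
      maxOn <- maxOnHeap
      maxOff <- maxOffHeap
    } yield MemoryMetricsSketch(on, off, maxOn, maxOff)
  }

  def main(args: Array[String]): Unit = {
    println(build(Some(10L), Some(0L), Some(100L), Some(0L))) // Some(MemoryMetricsSketch(10,0,100,0))
    println(build(Some(10L), None, Some(100L), Some(0L)))     // None
  }
}
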
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index a88fc4c37d3c..770da2226fe0 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ui.exec
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{LinkedHashMap, ListBuffer}
-import org.apache.spark.{Resubmitted, ExceptionFailure, SparkContext}
+import org.apache.spark.{Resubmitted, SparkConf, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageStatus, StorageStatusListener}
import org.apache.spark.ui.{SparkUI, SparkUITab}
-import org.apache.spark.ui.jobs.UIData.ExecutorUIData
private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") {
val listener = parent.executorsListener
@@ -38,101 +37,172 @@ private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "exec
}
}
+private[ui] case class ExecutorTaskSummary(
+ var executorId: String,
+ var totalCores: Int = 0,
+ var tasksMax: Int = 0,
+ var tasksActive: Int = 0,
+ var tasksFailed: Int = 0,
+ var tasksComplete: Int = 0,
+ var duration: Long = 0L,
+ var jvmGCTime: Long = 0L,
+ var inputBytes: Long = 0L,
+ var inputRecords: Long = 0L,
+ var outputBytes: Long = 0L,
+ var outputRecords: Long = 0L,
+ var shuffleRead: Long = 0L,
+ var shuffleWrite: Long = 0L,
+ var executorLogs: Map[String, String] = Map.empty,
+ var isAlive: Boolean = true,
+ var isBlacklisted: Boolean = false
+)
+
/**
* :: DeveloperApi ::
* A SparkListener that prepares information to be displayed on the ExecutorsTab
*/
@DeveloperApi
-class ExecutorsListener(storageStatusListener: StorageStatusListener) extends SparkListener {
- val executorToTasksActive = HashMap[String, Int]()
- val executorToTasksComplete = HashMap[String, Int]()
- val executorToTasksFailed = HashMap[String, Int]()
- val executorToDuration = HashMap[String, Long]()
- val executorToInputBytes = HashMap[String, Long]()
- val executorToInputRecords = HashMap[String, Long]()
- val executorToOutputBytes = HashMap[String, Long]()
- val executorToOutputRecords = HashMap[String, Long]()
- val executorToShuffleRead = HashMap[String, Long]()
- val executorToShuffleWrite = HashMap[String, Long]()
- val executorToLogUrls = HashMap[String, Map[String, String]]()
- val executorIdToData = HashMap[String, ExecutorUIData]()
-
- def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList
-
- override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized {
+@deprecated("This class will be removed in a future release.", "2.2.0")
+class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf)
+ extends SparkListener {
+ val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]()
+ var executorEvents = new ListBuffer[SparkListenerEvent]()
+
+ private val maxTimelineExecutors = conf.getInt("spark.ui.timeline.executors.maximum", 1000)
+ private val retainedDeadExecutors = conf.getInt("spark.ui.retainedDeadExecutors", 100)
+
+ def activeStorageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList
+
+ def deadStorageStatusList: Seq[StorageStatus] = storageStatusListener.deadStorageStatusList
+
+ override def onExecutorAdded(
+ executorAdded: SparkListenerExecutorAdded): Unit = synchronized {
val eid = executorAdded.executorId
- executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap
- executorIdToData(eid) = ExecutorUIData(executorAdded.time)
+ val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid))
+ taskSummary.executorLogs = executorAdded.executorInfo.logUrlMap
+ taskSummary.totalCores = executorAdded.executorInfo.totalCores
+ taskSummary.tasksMax = taskSummary.totalCores / conf.getInt("spark.task.cpus", 1)
+ executorEvents += executorAdded
+ if (executorEvents.size > maxTimelineExecutors) {
+ executorEvents.remove(0)
+ }
+
+ val deadExecutors = executorToTaskSummary.filter(e => !e._2.isAlive)
+ if (deadExecutors.size > retainedDeadExecutors) {
+ val head = deadExecutors.head
+ executorToTaskSummary.remove(head._1)
+ }
}
override def onExecutorRemoved(
executorRemoved: SparkListenerExecutorRemoved): Unit = synchronized {
- val eid = executorRemoved.executorId
- val uiData = executorIdToData(eid)
- uiData.finishTime = Some(executorRemoved.time)
- uiData.finishReason = Some(executorRemoved.reason)
+ executorEvents += executorRemoved
+ if (executorEvents.size > maxTimelineExecutors) {
+ executorEvents.remove(0)
+ }
+ executorToTaskSummary.get(executorRemoved.executorId).foreach(e => e.isAlive = false)
}
- override def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit = {
+ override def onApplicationStart(
+ applicationStart: SparkListenerApplicationStart): Unit = {
applicationStart.driverLogs.foreach { logs =>
- val storageStatus = storageStatusList.find { s =>
+ val storageStatus = activeStorageStatusList.find { s =>
s.blockManagerId.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER ||
s.blockManagerId.executorId == SparkContext.DRIVER_IDENTIFIER
}
- storageStatus.foreach { s => executorToLogUrls(s.blockManagerId.executorId) = logs.toMap }
+ storageStatus.foreach { s =>
+ val eid = s.blockManagerId.executorId
+ val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid))
+ taskSummary.executorLogs = logs.toMap
+ }
}
}
- override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
+ override def onTaskStart(
+ taskStart: SparkListenerTaskStart): Unit = synchronized {
val eid = taskStart.taskInfo.executorId
- executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1
+ val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid))
+ taskSummary.tasksActive += 1
}
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+ override def onTaskEnd(
+ taskEnd: SparkListenerTaskEnd): Unit = synchronized {
val info = taskEnd.taskInfo
if (info != null) {
val eid = info.executorId
- taskEnd.reason match {
- case Resubmitted =>
- // Note: For resubmitted tasks, we continue to use the metrics that belong to the
- // first attempt of this task. This may not be 100% accurate because the first attempt
- // could have failed half-way through. The correct fix would be to keep track of the
- // metrics added by each attempt, but this is much more complicated.
- return
- case e: ExceptionFailure =>
- executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1
- case _ =>
- executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1
+ val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid))
+ // Note: For resubmitted tasks, we continue to use the metrics that belong to the
+ // first attempt of this task. This may not be 100% accurate because the first attempt
+ // could have failed half-way through. The correct fix would be to keep track of the
+ // metrics added by each attempt, but this is much more complicated.
+ if (taskEnd.reason == Resubmitted) {
+ return
}
-
- executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1
- executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration
+ if (info.successful) {
+ taskSummary.tasksComplete += 1
+ } else {
+ taskSummary.tasksFailed += 1
+ }
+ if (taskSummary.tasksActive >= 1) {
+ taskSummary.tasksActive -= 1
+ }
+ taskSummary.duration += info.duration
// Update shuffle read/write
val metrics = taskEnd.taskMetrics
if (metrics != null) {
- metrics.inputMetrics.foreach { inputMetrics =>
- executorToInputBytes(eid) =
- executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead
- executorToInputRecords(eid) =
- executorToInputRecords.getOrElse(eid, 0L) + inputMetrics.recordsRead
- }
- metrics.outputMetrics.foreach { outputMetrics =>
- executorToOutputBytes(eid) =
- executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten
- executorToOutputRecords(eid) =
- executorToOutputRecords.getOrElse(eid, 0L) + outputMetrics.recordsWritten
- }
- metrics.shuffleReadMetrics.foreach { shuffleRead =>
- executorToShuffleRead(eid) =
- executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead
- }
- metrics.shuffleWriteMetrics.foreach { shuffleWrite =>
- executorToShuffleWrite(eid) =
- executorToShuffleWrite.getOrElse(eid, 0L) + shuffleWrite.shuffleBytesWritten
- }
+ taskSummary.inputBytes += metrics.inputMetrics.bytesRead
+ taskSummary.inputRecords += metrics.inputMetrics.recordsRead
+ taskSummary.outputBytes += metrics.outputMetrics.bytesWritten
+ taskSummary.outputRecords += metrics.outputMetrics.recordsWritten
+
+ taskSummary.shuffleRead += metrics.shuffleReadMetrics.remoteBytesRead
+ taskSummary.shuffleWrite += metrics.shuffleWriteMetrics.bytesWritten
+ taskSummary.jvmGCTime += metrics.jvmGCTime
}
}
}
+ private def updateExecutorBlacklist(
+ eid: String,
+ isBlacklisted: Boolean): Unit = {
+ val execTaskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid))
+ execTaskSummary.isBlacklisted = isBlacklisted
+ }
+
+ override def onExecutorBlacklisted(
+ executorBlacklisted: SparkListenerExecutorBlacklisted)
+ : Unit = synchronized {
+ updateExecutorBlacklist(executorBlacklisted.executorId, true)
+ }
+
+ override def onExecutorUnblacklisted(
+ executorUnblacklisted: SparkListenerExecutorUnblacklisted)
+ : Unit = synchronized {
+ updateExecutorBlacklist(executorUnblacklisted.executorId, false)
+ }
+
+ override def onNodeBlacklisted(
+ nodeBlacklisted: SparkListenerNodeBlacklisted)
+ : Unit = synchronized {
+ // Implicitly blacklist every executor associated with this node, and show this in the UI.
+ activeStorageStatusList.foreach { status =>
+ if (status.blockManagerId.host == nodeBlacklisted.hostId) {
+ updateExecutorBlacklist(status.blockManagerId.executorId, true)
+ }
+ }
+ }
+
+ override def onNodeUnblacklisted(
+ nodeUnblacklisted: SparkListenerNodeUnblacklisted)
+ : Unit = synchronized {
+ // Implicitly unblacklist every executor associated with this node, regardless of how
+ // they may have been blacklisted initially (either explicitly through executor blacklisting
+ // or implicitly through node blacklisting). Show this in the UI.
+ activeStorageStatusList.foreach { status =>
+ if (status.blockManagerId.host == nodeUnblacklisted.hostId) {
+ updateExecutorBlacklist(status.blockManagerId.executorId, false)
+ }
+ }
+ }
}
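
ExecutorsListener now keeps a bounded timeline (spark.ui.timeline.executors.maximum) and evicts the oldest dead executor summary once more than spark.ui.retainedDeadExecutors have accumulated. A compact sketch of that bookkeeping, with illustrative names and a trimmed-down summary record in place of ExecutorTaskSummary:

import scala.collection.mutable.{LinkedHashMap, ListBuffer}

object ExecutorBookkeepingSketch {
  case class Summary(executorId: String, var isAlive: Boolean = true)

  val maxTimelineEvents = 3      // stands in for spark.ui.timeline.executors.maximum
  val retainedDeadExecutors = 2  // stands in for spark.ui.retainedDeadExecutors

  val summaries = LinkedHashMap[String, Summary]()
  val events = new ListBuffer[String]()

  def onAdded(eid: String): Unit = {
    summaries.getOrElseUpdate(eid, Summary(eid))
    events += s"added $eid"
    if (events.size > maxTimelineEvents) {
      events.remove(0)  // drop the oldest timeline event
    }
    // Evict the oldest dead executor once too many have accumulated.
    val dead = summaries.filter { case (_, s) => !s.isAlive }
    if (dead.size > retainedDeadExecutors) {
      summaries.remove(dead.head._1)
    }
  }

  def onRemoved(eid: String): Unit = {
    events += s"removed $eid"
    if (events.size > maxTimelineEvents) {
      events.remove(0)
    }
    summaries.get(eid).foreach(_.isAlive = false)
  }

  def main(args: Array[String]): Unit = {
    Seq("1", "2", "3", "4").foreach(onAdded)
    Seq("1", "2", "3").foreach(onRemoved)
    onAdded("5")  // triggers eviction of the oldest dead executor ("1")
    println(summaries.keys.mkString(", "))  // 2, 3, 4, 5
  }
}
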
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index d467dd9e1f29..f2491cb07d68 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -17,15 +17,21 @@
package org.apache.spark.ui.jobs
+import java.net.URLEncoder
import java.util.Date
import javax.servlet.http.HttpServletRequest
+import scala.collection.JavaConverters._
import scala.collection.mutable.{HashMap, ListBuffer}
import scala.xml._
+import org.apache.commons.lang3.StringEscapeUtils
+
import org.apache.spark.JobExecutionStatus
-import org.apache.spark.ui.jobs.UIData.{ExecutorUIData, JobUIData}
-import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage}
+import org.apache.spark.scheduler._
+import org.apache.spark.ui._
+import org.apache.spark.ui.jobs.UIData.{JobUIData, StageUIData}
+import org.apache.spark.util.Utils
/** Page showing list of all ongoing and recently finished jobs */
private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
@@ -71,7 +77,12 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
val jobId = jobUIData.jobId
val status = jobUIData.status
val (jobName, jobDescription) = getLastStageNameAndDescription(jobUIData)
- val displayJobDescription = if (jobDescription.isEmpty) jobName else jobDescription
+ val displayJobDescription =
+ if (jobDescription.isEmpty) {
+ jobName
+ } else {
+ UIUtils.makeDescription(jobDescription, "", plainText = true).text
+ }
val submissionTime = jobUIData.submissionTime.get
val completionTimeOpt = jobUIData.completionTime
val completionTime = completionTimeOpt.getOrElse(System.currentTimeMillis())
@@ -82,9 +93,10 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
case JobExecutionStatus.UNKNOWN => "unknown"
}
- // The timeline library treats contents as HTML, so we have to escape them; for the
- // data-title attribute string we have to escape them twice since that's in a string.
+ // The timeline library treats contents as HTML, so we have to escape them. We need to add
+ // extra layers of escaping in order to embed this in a JavaScript string literal.
val escapedDesc = Utility.escape(displayJobDescription)
+ val jsEscapedDesc = StringEscapeUtils.escapeEcmaScript(escapedDesc)
val jobEventJsonAsStr =
s"""
|{
@@ -94,7 +106,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
| 'end': new Date(${completionTime}),
| 'content': '' +
+ | 'data-title="${jsEscapedDesc} (Job ${jobId})
' +
| 'Status: ${status}
' +
| 'Submitted: ${UIUtils.formatDate(new Date(submissionTime))}' +
| '${
@@ -104,62 +116,62 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
""
}
}">' +
- | '${escapedDesc} (Job ${jobId})'
+ | '${jsEscapedDesc} (Job ${jobId})'
|}
""".stripMargin
jobEventJsonAsStr
}
}
- private def makeExecutorEvent(executorUIDatas: HashMap[String, ExecutorUIData]): Seq[String] = {
+ private def makeExecutorEvent(executorUIDatas: Seq[SparkListenerEvent]):
+ Seq[String] = {
val events = ListBuffer[String]()
executorUIDatas.foreach {
- case (executorId, event) =>
+ case a: SparkListenerExecutorAdded =>
val addedEvent =
s"""
|{
| 'className': 'executor added',
| 'group': 'executors',
- | 'start': new Date(${event.startTime}),
+ | 'start': new Date(${a.time}),
| 'content': '' +
- | 'Added at ${UIUtils.formatDate(new Date(event.startTime))}"' +
- | 'data-html="true">Executor ${executorId} added'
+ | 'data-title="Executor ${a.executorId}
' +
+ | 'Added at ${UIUtils.formatDate(new Date(a.time))}"' +
+ | 'data-html="true">Executor ${a.executorId} added'
|}
""".stripMargin
events += addedEvent
+ case e: SparkListenerExecutorRemoved =>
+ val removedEvent =
+ s"""
+ |{
+ | 'className': 'executor removed',
+ | 'group': 'executors',
+ | 'start': new Date(${e.time}),
+ | 'content': '' +
+ | 'Removed at ${UIUtils.formatDate(new Date(e.time))}' +
+ | '${
+ if (e.reason != null) {
+ s"""
Reason: ${e.reason.replace("\n", " ")}"""
+ } else {
+ ""
+ }
+ }"' +
+ | 'data-html="true">Executor ${e.executorId} removed'
+ |}
+ """.stripMargin
+ events += removedEvent
- if (event.finishTime.isDefined) {
- val removedEvent =
- s"""
- |{
- | 'className': 'executor removed',
- | 'group': 'executors',
- | 'start': new Date(${event.finishTime.get}),
- | 'content': '' +
- | 'Removed at ${UIUtils.formatDate(new Date(event.finishTime.get))}' +
- | '${
- if (event.finishReason.isDefined) {
- s"""
Reason: ${event.finishReason.get}"""
- } else {
- ""
- }
- }"' +
- | 'data-html="true">Executor ${executorId} removed'
- |}
- """.stripMargin
- events += removedEvent
- }
}
events.toSeq
}
private def makeTimeline(
jobs: Seq[JobUIData],
- executors: HashMap[String, ExecutorUIData],
+ executors: Seq[SparkListenerEvent],
startTime: Long): Seq[Node] = {
val jobEventJsonAsStrSeq = makeJobEvent(jobs)
@@ -198,67 +210,81 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
++
}
- private def jobsTable(jobs: Seq[JobUIData]): Seq[Node] = {
+ private def jobsTable(
+ request: HttpServletRequest,
+ tableHeaderId: String,
+ jobTag: String,
+ jobs: Seq[JobUIData],
+ killEnabled: Boolean): Seq[Node] = {
+ // stripXSS is called to remove suspicious characters used in XSS attacks
+ val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) =>
+ UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq
+ }
+ val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag))
+ .map(para => para._1 + "=" + para._2(0))
+
val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
+ val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"
- val columns: Seq[Node] = {
- {if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"}
- Description
- Submitted
- Duration
- Stages: Succeeded/Total
- Tasks (for all stages): Succeeded/Total
- }
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page"))
+ val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort"))
+ val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc"))
+ val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize"))
+ val parameterJobPrevPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".prevPageSize"))
- def makeRow(job: JobUIData): Seq[Node] = {
- val (lastStageName, lastStageDescription) = getLastStageNameAndDescription(job)
- val duration: Option[Long] = {
- job.submissionTime.map { start =>
- val end = job.completionTime.getOrElse(System.currentTimeMillis())
- end - start
- }
+ val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1)
+ val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn =>
+ UIUtils.decodeURLParameter(sortColumn)
+ }.getOrElse(jobIdTitle)
+ val jobSortDesc = Option(parameterJobSortDesc).map(_.toBoolean).getOrElse(
+ // New jobs should be shown above old jobs by default.
+ if (jobSortColumn == jobIdTitle) true else false
+ )
+ val jobPageSize = Option(parameterJobPageSize).map(_.toInt).getOrElse(100)
+ val jobPrevPageSize = Option(parameterJobPrevPageSize).map(_.toInt).getOrElse(jobPageSize)
+
+ val page: Int = {
+ // If the user has changed to a larger page size, then go to page 1 in order to avoid
+ // IndexOutOfBoundsException.
+ if (jobPageSize <= jobPrevPageSize) {
+ jobPage
+ } else {
+ 1
}
- val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown")
- val formattedSubmissionTime = job.submissionTime.map(UIUtils.formatDate).getOrElse("Unknown")
- val jobDescription = UIUtils.makeDescription(lastStageDescription, parent.basePath)
-
- val detailUrl =
- "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId)
-
-
- {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")}
-
-
- {jobDescription}
- {lastStageName}
-
-
- {formattedSubmissionTime}
-
- {formattedDuration}
-
- {job.completedStageIndices.size}/{job.stageIds.size - job.numSkippedStages}
- {if (job.numFailedStages > 0) s"(${job.numFailedStages} failed)"}
- {if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"}
-
-
- {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks,
- failed = job.numFailedTasks, skipped = job.numSkippedTasks,
- total = job.numTasks - job.numSkippedTasks)}
-
-
}
+ val currentTime = System.currentTimeMillis()
-
- {columns}
-
- {jobs.map(makeRow)}
-
-
+ try {
+ new JobPagedTable(
+ jobs,
+ tableHeaderId,
+ jobTag,
+ UIUtils.prependBaseUri(parent.basePath),
+ "jobs", // subPath
+ parameterOtherTable,
+ parent.jobProgresslistener.stageIdToInfo,
+ parent.jobProgresslistener.stageIdToData,
+ killEnabled,
+ currentTime,
+ jobIdTitle,
+ pageSize = jobPageSize,
+ sortColumn = jobSortColumn,
+ desc = jobSortDesc
+ ).table(page)
+ } catch {
+ case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
+
+ Error while rendering job table:
+
+ {Utils.exceptionString(e)}
+
+
+ }
}
def render(request: HttpServletRequest): Seq[Node] = {
@@ -267,15 +293,15 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
val startTime = listener.startTime
val endTime = listener.endTime
val activeJobs = listener.activeJobs.values.toSeq
- val completedJobs = listener.completedJobs.reverse.toSeq
- val failedJobs = listener.failedJobs.reverse.toSeq
+ val completedJobs = listener.completedJobs.reverse
+ val failedJobs = listener.failedJobs.reverse
val activeJobsTable =
- jobsTable(activeJobs.sortBy(_.submissionTime.getOrElse(-1L)).reverse)
+ jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled)
val completedJobsTable =
- jobsTable(completedJobs.sortBy(_.completionTime.getOrElse(-1L)).reverse)
+ jobsTable(request, "completed", "completedJob", completedJobs, killEnabled = false)
val failedJobsTable =
- jobsTable(failedJobs.sortBy(_.completionTime.getOrElse(-1L)).reverse)
+ jobsTable(request, "failed", "failedJob", failedJobs, killEnabled = false)
val shouldShowActiveJobs = activeJobs.nonEmpty
val shouldShowCompletedJobs = completedJobs.nonEmpty
@@ -290,6 +316,10 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
val summary: NodeSeq =
+ -
+ User:
+ {parent.getSparkUser}
+
-
Total Uptime:
{
@@ -334,7 +364,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
var content = summary
val executorListener = parent.executorListener
content ++= makeTimeline(activeJobs ++ completedJobs ++ failedJobs,
- executorListener.executorIdToData, startTime)
+ executorListener.executorEvents, startTime)
if (shouldShowActiveJobs) {
content ++=
Active Jobs ({activeJobs.size})
++
@@ -356,3 +386,257 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
}
}
}
+
+private[ui] class JobTableRowData(
+ val jobData: JobUIData,
+ val lastStageName: String,
+ val lastStageDescription: String,
+ val duration: Long,
+ val formattedDuration: String,
+ val submissionTime: Long,
+ val formattedSubmissionTime: String,
+ val jobDescription: NodeSeq,
+ val detailUrl: String)
+
+private[ui] class JobDataSource(
+ jobs: Seq[JobUIData],
+ stageIdToInfo: HashMap[Int, StageInfo],
+ stageIdToData: HashMap[(Int, Int), StageUIData],
+ basePath: String,
+ currentTime: Long,
+ pageSize: Int,
+ sortColumn: String,
+ desc: Boolean) extends PagedDataSource[JobTableRowData](pageSize) {
+
+ // Convert JobUIData to JobTableRowData, which contains the final contents to show in the table,
+ // so that we can avoid creating duplicate contents when sorting the data
+ private val data = jobs.map(jobRow).sorted(ordering(sortColumn, desc))
+
+ private var _slicedJobIds: Set[Int] = null
+
+ override def dataSize: Int = data.size
+
+ override def sliceData(from: Int, to: Int): Seq[JobTableRowData] = {
+ val r = data.slice(from, to)
+ _slicedJobIds = r.map(_.jobData.jobId).toSet
+ r
+ }
+
+ private def getLastStageNameAndDescription(job: JobUIData): (String, String) = {
+ val lastStageInfo = Option(job.stageIds)
+ .filter(_.nonEmpty)
+ .flatMap { ids => stageIdToInfo.get(ids.max)}
+ val lastStageData = lastStageInfo.flatMap { s =>
+ stageIdToData.get((s.stageId, s.attemptId))
+ }
+ val name = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)")
+ val description = lastStageData.flatMap(_.description).getOrElse("")
+ (name, description)
+ }
+
+ private def jobRow(jobData: JobUIData): JobTableRowData = {
+ val (lastStageName, lastStageDescription) = getLastStageNameAndDescription(jobData)
+ val duration: Option[Long] = {
+ jobData.submissionTime.map { start =>
+ val end = jobData.completionTime.getOrElse(System.currentTimeMillis())
+ end - start
+ }
+ }
+ val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown")
+ val submissionTime = jobData.submissionTime
+ val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown")
+ val jobDescription = UIUtils.makeDescription(lastStageDescription, basePath, plainText = false)
+
+ val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId)
+
+ new JobTableRowData (
+ jobData,
+ lastStageName,
+ lastStageDescription,
+ duration.getOrElse(-1),
+ formattedDuration,
+ submissionTime.getOrElse(-1),
+ formattedSubmissionTime,
+ jobDescription,
+ detailUrl
+ )
+ }
+
+ /**
+ * Return Ordering according to sortColumn and desc
+ */
+ private def ordering(sortColumn: String, desc: Boolean): Ordering[JobTableRowData] = {
+ val ordering: Ordering[JobTableRowData] = sortColumn match {
+ case "Job Id" | "Job Id (Job Group)" => Ordering.by(_.jobData.jobId)
+ case "Description" => Ordering.by(x => (x.lastStageDescription, x.lastStageName))
+ case "Submitted" => Ordering.by(_.submissionTime)
+ case "Duration" => Ordering.by(_.duration)
+ case "Stages: Succeeded/Total" | "Tasks (for all stages): Succeeded/Total" =>
+ throw new IllegalArgumentException(s"Unsortable column: $sortColumn")
+ case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn")
+ }
+ if (desc) {
+ ordering.reverse
+ } else {
+ ordering
+ }
+ }
+
+}
+private[ui] class JobPagedTable(
+ data: Seq[JobUIData],
+ tableHeaderId: String,
+ jobTag: String,
+ basePath: String,
+ subPath: String,
+ parameterOtherTable: Iterable[String],
+ stageIdToInfo: HashMap[Int, StageInfo],
+ stageIdToData: HashMap[(Int, Int), StageUIData],
+ killEnabled: Boolean,
+ currentTime: Long,
+ jobIdTitle: String,
+ pageSize: Int,
+ sortColumn: String,
+ desc: Boolean
+ ) extends PagedTable[JobTableRowData] {
+ val parameterPath = basePath + s"/$subPath/?" + parameterOtherTable.mkString("&")
+
+ override def tableId: String = jobTag + "-table"
+
+ override def tableCssClass: String =
+ "table table-bordered table-condensed table-striped " +
+ "table-head-clickable table-cell-width-limited"
+
+ override def pageSizeFormField: String = jobTag + ".pageSize"
+
+ override def prevPageSizeFormField: String = jobTag + ".prevPageSize"
+
+ override def pageNumberFormField: String = jobTag + ".page"
+
+ override val dataSource = new JobDataSource(
+ data,
+ stageIdToInfo,
+ stageIdToData,
+ basePath,
+ currentTime,
+ pageSize,
+ sortColumn,
+ desc)
+
+ override def pageLink(page: Int): String = {
+ val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
+ parameterPath +
+ s"&$pageNumberFormField=$page" +
+ s"&$jobTag.sort=$encodedSortColumn" +
+ s"&$jobTag.desc=$desc" +
+ s"&$pageSizeFormField=$pageSize" +
+ s"#$tableHeaderId"
+ }
+
+ override def goButtonFormPath: String = {
+ val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
+ s"$parameterPath&$jobTag.sort=$encodedSortColumn&$jobTag.desc=$desc#$tableHeaderId"
+ }
+
+ override def headers: Seq[Node] = {
+ // Information for each header: title, cssClass, and sortable
+ val jobHeadersAndCssClasses: Seq[(String, String, Boolean)] =
+ Seq(
+ (jobIdTitle, "", true),
+ ("Description", "", true), ("Submitted", "", true), ("Duration", "", true),
+ ("Stages: Succeeded/Total", "", false),
+ ("Tasks (for all stages): Succeeded/Total", "", false)
+ )
+
+ if (!jobHeadersAndCssClasses.filter(_._3).map(_._1).contains(sortColumn)) {
+ throw new IllegalArgumentException(s"Unknown column: $sortColumn")
+ }
+
+ val headerRow: Seq[Node] = {
+ jobHeadersAndCssClasses.map { case (header, cssClass, sortable) =>
+ if (header == sortColumn) {
+ val headerLink = Unparsed(
+ parameterPath +
+ s"&$jobTag.sort=${URLEncoder.encode(header, "UTF-8")}" +
+ s"&$jobTag.desc=${!desc}" +
+ s"&$jobTag.pageSize=$pageSize" +
+ s"#$tableHeaderId")
+ val arrow = if (desc) "▾" else "▴" // UP or DOWN
+
+
+
+ {header}
+ {Unparsed(arrow)}
+
+
+
+ } else {
+ if (sortable) {
+ val headerLink = Unparsed(
+ parameterPath +
+ s"&$jobTag.sort=${URLEncoder.encode(header, "UTF-8")}" +
+ s"&$jobTag.pageSize=$pageSize" +
+ s"#$tableHeaderId")
+
+
+
+ {header}
+
+
+ } else {
+
+ {header}
+
+ }
+ }
+ }
+ }
+ {headerRow}
+ }
+
+ override def row(jobTableRow: JobTableRowData): Seq[Node] = {
+ val job = jobTableRow.jobData
+
+ val killLink = if (killEnabled) {
+ val confirm =
+ s"if (window.confirm('Are you sure you want to kill job ${job.jobId} ?')) " +
+ "{ this.parentNode.submit(); return true; } else { return false; }"
+ // SPARK-6846: this should be POST-only, but the YARN AM proxy won't proxy POSTs
+ /*
+ val killLinkUri = s"$basePathUri/jobs/job/kill/"
+
+ */
+ val killLinkUri = s"$basePath/jobs/job/kill/?id=${job.jobId}"
+ (kill)
+ } else {
+ Seq.empty
+ }
+
+
+
+ {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")}
+
+
+ {jobTableRow.jobDescription} {killLink}
+ {jobTableRow.lastStageName}
+
+
+ {jobTableRow.formattedSubmissionTime}
+
+ {jobTableRow.formattedDuration}
+
+ {job.completedStageIndices.size}/{job.stageIds.size - job.numSkippedStages}
+ {if (job.numFailedStages > 0) s"(${job.numFailedStages} failed)"}
+ {if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"}
+
+
+ {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks,
+ failed = job.numFailedTasks, skipped = job.numSkippedTasks,
+ reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)}
+
+
+ }
+}
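
JobDataSource maps each sortable column name to an Ordering and reverses it for descending sorts. A minimal sketch of that column-driven sorting, with an illustrative row type in place of JobTableRowData:

object SortColumnSketch {
  case class Row(jobId: Int, duration: Long, description: String)

  def ordering(sortColumn: String, desc: Boolean): Ordering[Row] = {
    val base: Ordering[Row] = sortColumn match {
      case "Job Id" => Ordering.by(_.jobId)
      case "Duration" => Ordering.by(_.duration)
      case "Description" => Ordering.by(_.description)
      case unknown => throw new IllegalArgumentException(s"Unknown column: $unknown")
    }
    if (desc) base.reverse else base
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq(Row(1, 300L, "b"), Row(2, 100L, "a"), Row(3, 200L, "c"))
    // Descending by job id, the default the jobs table uses so new jobs come first.
    println(rows.sorted(ordering("Job Id", desc = true)).map(_.jobId))       // List(3, 2, 1)
    println(rows.sorted(ordering("Duration", desc = false)).map(_.duration)) // List(100, 200, 300)
  }
}
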
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
index 5e52942b64f3..2b0816e35747 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
@@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.{Node, NodeSeq}
import org.apache.spark.scheduler.Schedulable
-import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.ui.{UIUtils, WebUIPage}
/** Page showing list of all ongoing and recently finished stages and pools */
private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") {
@@ -34,26 +34,28 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") {
listener.synchronized {
val activeStages = listener.activeStages.values.toSeq
val pendingStages = listener.pendingStages.values.toSeq
- val completedStages = listener.completedStages.reverse.toSeq
+ val completedStages = listener.completedStages.reverse
val numCompletedStages = listener.numCompletedStages
- val failedStages = listener.failedStages.reverse.toSeq
+ val failedStages = listener.failedStages.reverse
val numFailedStages = listener.numFailedStages
- val now = System.currentTimeMillis
+ val subPath = "stages"
val activeStagesTable =
- new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
- parent.basePath, parent.progressListener, isFairScheduler = parent.isFairScheduler,
- killEnabled = parent.killEnabled)
+ new StageTableBase(request, activeStages, "active", "activeStage", parent.basePath, subPath,
+ parent.progressListener, parent.isFairScheduler,
+ killEnabled = parent.killEnabled, isFailedStage = false)
val pendingStagesTable =
- new StageTableBase(pendingStages.sortBy(_.submissionTime).reverse,
- parent.basePath, parent.progressListener, isFairScheduler = parent.isFairScheduler,
- killEnabled = false)
+ new StageTableBase(request, pendingStages, "pending", "pendingStage", parent.basePath,
+ subPath, parent.progressListener, parent.isFairScheduler,
+ killEnabled = false, isFailedStage = false)
val completedStagesTable =
- new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath,
- parent.progressListener, isFairScheduler = parent.isFairScheduler, killEnabled = false)
+ new StageTableBase(request, completedStages, "completed", "completedStage", parent.basePath,
+ subPath, parent.progressListener, parent.isFairScheduler,
+ killEnabled = false, isFailedStage = false)
val failedStagesTable =
- new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath,
- parent.progressListener, isFairScheduler = parent.isFairScheduler)
+ new StageTableBase(request, failedStages, "failed", "failedStage", parent.basePath, subPath,
+ parent.progressListener, parent.isFairScheduler,
+ killEnabled = false, isFailedStage = true)
// For now, pool information is only accessible in live UIs
val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable])
@@ -136,3 +138,4 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") {
}
}
}
+
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
index be144f6065ba..382a6f979f2e 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ui.jobs
import scala.collection.mutable
-import scala.xml.Node
+import scala.xml.{Node, Unparsed}
import org.apache.spark.ui.{ToolTips, UIUtils}
import org.apache.spark.ui.jobs.UIData.StageUIData
@@ -42,21 +42,22 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
var hasShuffleWrite = false
var hasShuffleRead = false
var hasBytesSpilled = false
- stageData.foreach(data => {
+ stageData.foreach { data =>
hasInput = data.hasInput
hasOutput = data.hasOutput
hasShuffleRead = data.hasShuffleRead
hasShuffleWrite = data.hasShuffleWrite
hasBytesSpilled = data.hasBytesSpilled
- })
+ }
- Executor ID
+ Executor ID
Address
Task Time
Total Tasks
Failed Tasks
+ Killed Tasks
Succeeded Tasks
{if (hasInput) {
@@ -84,11 +85,25 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
Shuffle Spill (Memory)
Shuffle Spill (Disk)
}}
+
+
+ Blacklisted
+
+
{createExecutorTable()}
+
}
private def createExecutorTable() : Seq[Node] = {
@@ -104,11 +119,23 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
case Some(stageData: StageUIData) =>
stageData.executorSummary.toSeq.sortBy(_._1).map { case (k, v) =>
- {k}
+
+ {k}
+
+ {
+ val logs = parent.executorsListener.executorToTaskSummary.get(k)
+ .map(_.executorLogs).getOrElse(Map.empty)
+ logs.map {
+ case (logName, logUrl) =>
+ }
+ }
+
+
{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}
{UIUtils.formatDuration(v.taskTime)}
- {v.failedTasks + v.succeededTasks}
+ {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.values.sum}
{v.failedTasks}
+ {v.reasonToNumKilled.values.sum}
{v.succeededTasks}
{if (stageData.hasInput) {
@@ -138,6 +165,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
{Utils.bytesToString(v.diskBytesSpilled)}
}}
+ {v.isBlacklisted}
}
case None =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
index 2cad0a796913..9fb011a049b7 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
@@ -17,17 +17,17 @@
package org.apache.spark.ui.jobs
-import java.util.Date
+import java.util.{Date, Locale}
+import javax.servlet.http.HttpServletRequest
-import scala.collection.mutable.{Buffer, HashMap, ListBuffer}
-import scala.xml.{NodeSeq, Node, Unparsed, Utility}
+import scala.collection.mutable.{Buffer, ListBuffer}
+import scala.xml.{Node, NodeSeq, Unparsed, Utility}
-import javax.servlet.http.HttpServletRequest
+import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.JobExecutionStatus
-import org.apache.spark.scheduler.StageInfo
+import org.apache.spark.scheduler._
import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage}
-import org.apache.spark.ui.jobs.UIData.ExecutorUIData
/** Page showing statistics and stage list for a given job */
private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
@@ -64,9 +64,10 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
val submissionTime = stage.submissionTime.get
val completionTime = stage.completionTime.getOrElse(System.currentTimeMillis())
- // The timeline library treats contents as HTML, so we have to escape them; for the
- // data-title attribute string we have to escape them twice since that's in a string.
+ // The timeline library treats contents as HTML, so we have to escape them. We need to add
+ // extra layers of escaping in order to embed this in a JavaScript string literal.
val escapedName = Utility.escape(name)
+ val jsEscapedName = StringEscapeUtils.escapeEcmaScript(escapedName)
s"""
|{
| 'className': 'stage job-timeline-object ${status}',
@@ -75,8 +76,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
| 'end': new Date(${completionTime}),
| 'content': '' +
- | 'Status: ${status.toUpperCase}
' +
+ | 'data-title="${jsEscapedName} (Stage ${stageId}.${attemptId})
' +
+ | 'Status: ${status.toUpperCase(Locale.ROOT)}
' +
| 'Submitted: ${UIUtils.formatDate(new Date(submissionTime))}' +
| '${
if (status != "running") {
@@ -85,61 +86,61 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
""
}
}">' +
- | '${escapedName} (Stage ${stageId}.${attemptId})',
+ | '${jsEscapedName} (Stage ${stageId}.${attemptId})
',
|}
""".stripMargin
}
}
- def makeExecutorEvent(executorUIDatas: HashMap[String, ExecutorUIData]): Seq[String] = {
+ def makeExecutorEvent(executorUIDatas: Seq[SparkListenerEvent]): Seq[String] = {
val events = ListBuffer[String]()
executorUIDatas.foreach {
- case (executorId, event) =>
+ case a: SparkListenerExecutorAdded =>
val addedEvent =
s"""
|{
| 'className': 'executor added',
| 'group': 'executors',
- | 'start': new Date(${event.startTime}),
+ | 'start': new Date(${a.time}),
| 'content': '' +
- | 'Added at ${UIUtils.formatDate(new Date(event.startTime))}"' +
- | 'data-html="true">Executor ${executorId} added'
+ | 'data-title="Executor ${a.executorId}
' +
+ | 'Added at ${UIUtils.formatDate(new Date(a.time))}"' +
+ | 'data-html="true">Executor ${a.executorId} added'
|}
""".stripMargin
events += addedEvent
- if (event.finishTime.isDefined) {
- val removedEvent =
- s"""
- |{
- | 'className': 'executor removed',
- | 'group': 'executors',
- | 'start': new Date(${event.finishTime.get}),
- | 'content': '' +
- | 'Removed at ${UIUtils.formatDate(new Date(event.finishTime.get))}' +
- | '${
- if (event.finishReason.isDefined) {
- s"""
Reason: ${event.finishReason.get}"""
- } else {
- ""
- }
- }"' +
- | 'data-html="true">Executor ${executorId} removed'
- |}
- """.stripMargin
- events += removedEvent
- }
+ case e: SparkListenerExecutorRemoved =>
+ val removedEvent =
+ s"""
+ |{
+ | 'className': 'executor removed',
+ | 'group': 'executors',
+ | 'start': new Date(${e.time}),
+ | 'content': '' +
+ | 'Removed at ${UIUtils.formatDate(new Date(e.time))}' +
+ | '${
+ if (e.reason != null) {
+ s"""
Reason: ${e.reason.replace("\n", " ")}"""
+ } else {
+ ""
+ }
+ }"' +
+ | 'data-html="true">Executor ${e.executorId} removed'
+ |}
+ """.stripMargin
+ events += removedEvent
+
}
events.toSeq
}
private def makeTimeline(
stages: Seq[StageInfo],
- executors: HashMap[String, ExecutorUIData],
+ executors: Seq[SparkListenerEvent],
appStartTime: Long): Seq[Node] = {
val stageEventJsonAsStrSeq = makeStageEvent(stages)
@@ -177,7 +178,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
++
}
@@ -185,7 +187,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
val listener = parent.jobProgresslistener
listener.synchronized {
- val parameterId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
val jobId = parameterId.toInt
@@ -226,20 +229,31 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
}
}
+ val basePath = "jobs/job"
+
+ val pendingOrSkippedTableId =
+ if (isComplete) {
+ "pending"
+ } else {
+ "skipped"
+ }
+
val activeStagesTable =
- new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
- parent.basePath, parent.jobProgresslistener, isFairScheduler = parent.isFairScheduler,
- killEnabled = parent.killEnabled)
+ new StageTableBase(request, activeStages, "active", "activeStage", parent.basePath,
+ basePath, parent.jobProgresslistener, parent.isFairScheduler,
+ killEnabled = parent.killEnabled, isFailedStage = false)
val pendingOrSkippedStagesTable =
- new StageTableBase(pendingOrSkippedStages.sortBy(_.stageId).reverse,
- parent.basePath, parent.jobProgresslistener, isFairScheduler = parent.isFairScheduler,
- killEnabled = false)
+ new StageTableBase(request, pendingOrSkippedStages, pendingOrSkippedTableId, "pendingStage",
+ parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler,
+ killEnabled = false, isFailedStage = false)
val completedStagesTable =
- new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath,
- parent.jobProgresslistener, isFairScheduler = parent.isFairScheduler, killEnabled = false)
+ new StageTableBase(request, completedStages, "completed", "completedStage", parent.basePath,
+ basePath, parent.jobProgresslistener, parent.isFairScheduler,
+ killEnabled = false, isFailedStage = false)
val failedStagesTable =
- new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath,
- parent.jobProgresslistener, isFairScheduler = parent.isFairScheduler)
+ new StageTableBase(request, failedStages, "failed", "failedStage", parent.basePath,
+ basePath, parent.jobProgresslistener, parent.isFairScheduler,
+ killEnabled = false, isFailedStage = true)
val shouldShowActiveStages = activeStages.nonEmpty
val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty
@@ -312,7 +326,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
val operationGraphListener = parent.operationGraphListener
content ++= makeTimeline(activeStages ++ completedStages ++ failedStages,
- executorListener.executorIdToData, appStartTime)
+ executorListener.executorEvents, appStartTime)
content ++= UIUtils.showDagVizForJob(
jobId, operationGraphListener.getOperationGraphForJob(jobId))
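
For context on the escaping change above: the stage name is HTML-escaped for the tooltip and then ECMAScript-escaped before being interpolated into the timeline's JavaScript string literal. A minimal sketch of that double escaping, using a made-up stage name:

```scala
import org.apache.commons.lang3.StringEscapeUtils
import scala.xml.Utility

// Hypothetical stage name containing both markup and a quote that would
// otherwise terminate the surrounding JavaScript string literal.
val name = """count at <console>:25 with 'quotes'"""

val escapedName = Utility.escape(name)                              // HTML escape for the tooltip
val jsEscapedName = StringEscapeUtils.escapeEcmaScript(escapedName) // second escape for the JS literal

// Safe to splice into the timeline's object literal now.
val entry = s"'content': '<div data-title=\"$jsEscapedName\">...</div>'"
println(entry)
```
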
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 77d034fa5ba2..7370f9feb68c 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -19,13 +19,13 @@ package org.apache.spark.ui.jobs
import java.util.concurrent.TimeoutException
-import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
-
-import com.google.common.annotations.VisibleForTesting
+import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap, ListBuffer}
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.BlockManagerId
@@ -41,6 +41,7 @@ import org.apache.spark.ui.jobs.UIData._
* updating the internal data structures concurrently.
*/
@DeveloperApi
+@deprecated("This class will be removed in a future release.", "2.2.0")
class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
// Define a handful of type aliases so that data structures' types can serve as documentation.
@@ -94,6 +95,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val retainedStages = conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES)
val retainedJobs = conf.getInt("spark.ui.retainedJobs", SparkUI.DEFAULT_RETAINED_JOBS)
+ val retainedTasks = conf.get(UI_RETAINED_TASKS)
// We can test for memory leaks by ensuring that collections that track non-active jobs and
// stages do not grow without bound and that collections for active jobs/stages eventually become
@@ -141,7 +143,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
/** If stages is too large, remove and garbage collect old stages */
private def trimStagesIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
if (stages.size > retainedStages) {
- val toRemove = math.max(retainedStages / 10, 1)
+ val toRemove = calculateNumberToRemove(stages.size, retainedStages)
stages.take(toRemove).foreach { s =>
stageIdToData.remove((s.stageId, s.attemptId))
stageIdToInfo.remove(s.stageId)
@@ -153,7 +155,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
/** If jobs is too large, remove and garbage collect old jobs */
private def trimJobsIfNecessary(jobs: ListBuffer[JobUIData]) = synchronized {
if (jobs.size > retainedJobs) {
- val toRemove = math.max(retainedJobs / 10, 1)
+ val toRemove = calculateNumberToRemove(jobs.size, retainedJobs)
jobs.take(toRemove).foreach { job =>
// Remove the job's UI data, if it exists
jobIdToData.remove(job.jobId).foreach { removedJob =>
@@ -225,7 +227,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
trimJobsIfNecessary(completedJobs)
jobData.status = JobExecutionStatus.SUCCEEDED
numCompletedJobs += 1
- case JobFailed(exception) =>
+ case JobFailed(_) =>
failedJobs += jobData
trimJobsIfNecessary(failedJobs)
jobData.status = JobExecutionStatus.FAILED
@@ -283,7 +285,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
) {
jobData.numActiveStages -= 1
if (stage.failureReason.isEmpty) {
- if (!stage.submissionTime.isEmpty) {
+ if (stage.submissionTime.isDefined) {
jobData.completedStageIndices.add(stage.stageId)
}
} else {
@@ -332,7 +334,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
new StageUIData
})
stageData.numActiveTasks += 1
- stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo))
+ stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo))
}
for (
activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId);
@@ -369,36 +371,50 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
taskEnd.reason match {
case Success =>
execSummary.succeededTasks += 1
+ case kill: TaskKilled =>
+ execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated(
+ kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1)
case _ =>
execSummary.failedTasks += 1
}
execSummary.taskTime += info.duration
stageData.numActiveTasks -= 1
- val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) =
+ val errorMessage: Option[String] =
taskEnd.reason match {
case org.apache.spark.Success =>
stageData.completedIndices.add(info.index)
stageData.numCompleteTasks += 1
- (None, Option(taskEnd.taskMetrics))
- case e: ExceptionFailure => // Handle ExceptionFailure because we might have metrics
+ None
+ case kill: TaskKilled =>
+ stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated(
+ kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1)
+ Some(kill.toErrorString)
+ case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates
stageData.numFailedTasks += 1
- (Some(e.toErrorString), e.metrics)
- case e: TaskFailedReason => // All other failure cases
+ Some(e.toErrorString)
+ case e: TaskFailedReason => // All other failure cases
stageData.numFailedTasks += 1
- (Some(e.toErrorString), None)
+ Some(e.toErrorString)
}
- if (!metrics.isEmpty) {
- val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics)
- updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics)
+ val taskMetrics = Option(taskEnd.taskMetrics)
+ taskMetrics.foreach { m =>
+ val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.metrics)
+ updateAggregateMetrics(stageData, info.executorId, m, oldMetrics)
}
- val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info))
- taskData.taskInfo = info
- taskData.taskMetrics = metrics
+ val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info))
+ taskData.updateTaskInfo(info)
+ taskData.updateTaskMetrics(taskMetrics)
taskData.errorMessage = errorMessage
+ // If Tasks is too large, remove and garbage collect old tasks
+ if (stageData.taskData.size > retainedTasks) {
+ stageData.taskData = stageData.taskData.drop(
+ calculateNumberToRemove(stageData.taskData.size, retainedTasks))
+ }
+
for (
activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId);
jobId <- activeJobsDependentOnStage;
@@ -408,6 +424,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
taskEnd.reason match {
case Success =>
jobData.numCompletedTasks += 1
+ case kill: TaskKilled =>
+ jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated(
+ kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1)
case _ =>
jobData.numFailedTasks += 1
}
@@ -415,6 +434,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
}
}
+ /**
+ * Remove at least (maxRetained / 10) items to reduce friction.
+ */
+ private def calculateNumberToRemove(dataSize: Int, retainedSize: Int): Int = {
+ math.max(retainedSize / 10, dataSize - retainedSize)
+ }
+
/**
* Upon receiving new metrics for a task, updates the per-stage and per-executor-per-stage
* aggregate metrics by calculating deltas between the currently recorded metrics and the new
@@ -424,54 +450,54 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageData: StageUIData,
execId: String,
taskMetrics: TaskMetrics,
- oldMetrics: Option[TaskMetrics]) {
+ oldMetrics: Option[TaskMetricsUIData]) {
val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary)
val shuffleWriteDelta =
- (taskMetrics.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L)
- - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L))
+ taskMetrics.shuffleWriteMetrics.bytesWritten -
+ oldMetrics.map(_.shuffleWriteMetrics.bytesWritten).getOrElse(0L)
stageData.shuffleWriteBytes += shuffleWriteDelta
execSummary.shuffleWrite += shuffleWriteDelta
val shuffleWriteRecordsDelta =
- (taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L)
- - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleRecordsWritten).getOrElse(0L))
+ taskMetrics.shuffleWriteMetrics.recordsWritten -
+ oldMetrics.map(_.shuffleWriteMetrics.recordsWritten).getOrElse(0L)
stageData.shuffleWriteRecords += shuffleWriteRecordsDelta
execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta
val shuffleReadDelta =
- (taskMetrics.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L)
- - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.totalBytesRead).getOrElse(0L))
+ taskMetrics.shuffleReadMetrics.totalBytesRead -
+ oldMetrics.map(_.shuffleReadMetrics.totalBytesRead).getOrElse(0L)
stageData.shuffleReadTotalBytes += shuffleReadDelta
execSummary.shuffleRead += shuffleReadDelta
val shuffleReadRecordsDelta =
- (taskMetrics.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L)
- - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.recordsRead).getOrElse(0L))
+ taskMetrics.shuffleReadMetrics.recordsRead -
+ oldMetrics.map(_.shuffleReadMetrics.recordsRead).getOrElse(0L)
stageData.shuffleReadRecords += shuffleReadRecordsDelta
execSummary.shuffleReadRecords += shuffleReadRecordsDelta
val inputBytesDelta =
- (taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L)
- - oldMetrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L))
+ taskMetrics.inputMetrics.bytesRead -
+ oldMetrics.map(_.inputMetrics.bytesRead).getOrElse(0L)
stageData.inputBytes += inputBytesDelta
execSummary.inputBytes += inputBytesDelta
val inputRecordsDelta =
- (taskMetrics.inputMetrics.map(_.recordsRead).getOrElse(0L)
- - oldMetrics.flatMap(_.inputMetrics).map(_.recordsRead).getOrElse(0L))
+ taskMetrics.inputMetrics.recordsRead -
+ oldMetrics.map(_.inputMetrics.recordsRead).getOrElse(0L)
stageData.inputRecords += inputRecordsDelta
execSummary.inputRecords += inputRecordsDelta
val outputBytesDelta =
- (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L)
- - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L))
+ taskMetrics.outputMetrics.bytesWritten -
+ oldMetrics.map(_.outputMetrics.bytesWritten).getOrElse(0L)
stageData.outputBytes += outputBytesDelta
execSummary.outputBytes += outputBytesDelta
val outputRecordsDelta =
- (taskMetrics.outputMetrics.map(_.recordsWritten).getOrElse(0L)
- - oldMetrics.flatMap(_.outputMetrics).map(_.recordsWritten).getOrElse(0L))
+ taskMetrics.outputMetrics.recordsWritten -
+ oldMetrics.map(_.outputMetrics.recordsWritten).getOrElse(0L)
stageData.outputRecords += outputRecordsDelta
execSummary.outputRecords += outputRecordsDelta
@@ -488,22 +514,25 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val timeDelta =
taskMetrics.executorRunTime - oldMetrics.map(_.executorRunTime).getOrElse(0L)
stageData.executorRunTime += timeDelta
+
+ val cpuTimeDelta =
+ taskMetrics.executorCpuTime - oldMetrics.map(_.executorCpuTime).getOrElse(0L)
+ stageData.executorCpuTime += cpuTimeDelta
}
override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) {
- for ((taskId, sid, sAttempt, taskMetrics) <- executorMetricsUpdate.taskMetrics) {
+ for ((taskId, sid, sAttempt, accumUpdates) <- executorMetricsUpdate.accumUpdates) {
val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), {
logWarning("Metrics update for task in unknown stage " + sid)
new StageUIData
})
val taskData = stageData.taskData.get(taskId)
- taskData.map { t =>
+ val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates)
+ taskData.foreach { t =>
if (!t.taskInfo.finished) {
- updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics,
- t.taskMetrics)
-
+ updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics)
// Overwrite task metrics
- t.taskMetrics = Some(taskMetrics)
+ t.updateTaskMetrics(Some(metrics))
}
}
}
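
A small standalone sketch of the new trimming heuristic, with a toy task map and retention limit standing in for stageData.taskData and spark.ui.retainedTasks:

```scala
import scala.collection.mutable.LinkedHashMap

// Mirrors the heuristic added above: drop at least a tenth of the
// retention limit, or enough entries to get back under the limit.
def calculateNumberToRemove(dataSize: Int, retainedSize: Int): Int =
  math.max(retainedSize / 10, dataSize - retainedSize)

// Toy stand-in for stageData.taskData (taskId -> some per-task record).
val retainedTasks = 100
val taskData = LinkedHashMap((1L to 130L).map(id => id -> s"task-$id"): _*)

val trimmed =
  if (taskData.size > retainedTasks) {
    taskData.drop(calculateNumberToRemove(taskData.size, retainedTasks))
  } else {
    taskData
  }

// 130 tasks with a limit of 100 removes max(10, 30) = 30 of the oldest entries.
assert(trimmed.size == 100)
```
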
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
index 77ca60b000a9..cc173381879a 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
@@ -17,8 +17,10 @@
package org.apache.spark.ui.jobs
+import javax.servlet.http.HttpServletRequest
+
import org.apache.spark.scheduler.SchedulingMode
-import org.apache.spark.ui.{SparkUI, SparkUITab}
+import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}
/** Web UI showing progress status of all jobs in the given SparkContext. */
private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
@@ -29,8 +31,26 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
val operationGraphListener = parent.operationGraphListener
def isFairScheduler: Boolean =
- jobProgresslistener.schedulingMode.exists(_ == SchedulingMode.FAIR)
+ jobProgresslistener.schedulingMode == Some(SchedulingMode.FAIR)
+
+ def getSparkUser: String = parent.getSparkUser
attachPage(new AllJobsPage(this))
attachPage(new JobPage(this))
+
+ def handleKillRequest(request: HttpServletRequest): Unit = {
+ if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
+ jobId.foreach { id =>
+ if (jobProgresslistener.activeJobs.contains(id)) {
+ sc.foreach(_.cancelJob(id))
+ // Do a quick pause here to give Spark time to kill the job so it shows up as
+ // killed after the refresh. Note that this will block the serving thread so the
+ // time should be limited in duration.
+ Thread.sleep(100)
+ }
+ }
+ }
+ }
}
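
The kill handler follows a sanitize, parse, guard pattern. A hedged sketch of that shape, with activeJobs and cancel as stand-ins for the listener state and the SparkContext call:

```scala
import scala.util.Try

// Illustrative only: the real handler reads the id from the request after
// stripXSS and cancels via the SparkContext held by the tab.
def handleKill(rawId: String, activeJobs: Set[Int])(cancel: Int => Unit): Unit = {
  val jobId = Option(rawId).flatMap(id => Try(id.toInt).toOption)
  jobId.foreach { id =>
    if (activeJobs.contains(id)) {
      cancel(id)
      // Brief pause so the UI refresh observes the job as killed,
      // mirroring the comment in the handler above.
      Thread.sleep(100)
    }
  }
}

handleKill("7", activeJobs = Set(7, 9))(id => println(s"cancelling job $id"))
```
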
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
index f3e0b38523f3..b164f32b62e9 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.ui.{UIUtils, WebUIPage}
/** Page showing specific pool details */
private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {
@@ -31,25 +31,34 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {
def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
- val poolName = request.getParameter("poolname")
- require(poolName != null && poolName.nonEmpty, "Missing poolname parameter")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname =>
+ UIUtils.decodeURLParameter(poolname)
+ }.getOrElse {
+ throw new IllegalArgumentException(s"Missing poolname parameter")
+ }
val poolToActiveStages = listener.poolToActiveStages
val activeStages = poolToActiveStages.get(poolName) match {
case Some(s) => s.values.toSeq
case None => Seq[StageInfo]()
}
- val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
- parent.basePath, parent.progressListener, isFairScheduler = parent.isFairScheduler,
- killEnabled = parent.killEnabled)
+ val shouldShowActiveStages = activeStages.nonEmpty
+ val activeStagesTable =
+ new StageTableBase(request, activeStages, "", "activeStage", parent.basePath, "stages/pool",
+ parent.progressListener, parent.isFairScheduler, parent.killEnabled,
+ isFailedStage = false)
// For now, pool information is only accessible in live UIs
- val pools = sc.map(_.getPoolForName(poolName).get).toSeq
+ val pools = sc.map(_.getPoolForName(poolName).getOrElse {
+ throw new IllegalArgumentException(s"Unknown poolname: $poolName")
+ }).toSeq
val poolTable = new PoolTable(pools, parent)
- val content =
- <h4>Summary</h4> ++ poolTable.toNodeSeq ++
- <h4>{activeStages.size} Active Stages</h4> ++ activeStagesTable.toNodeSeq
+ var content = <h4>Summary</h4> ++ poolTable.toNodeSeq
+ if (shouldShowActiveStages) {
+ content ++= <h4>{activeStages.size} Active Stages</h4> ++ activeStagesTable.toNodeSeq
+ }
UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent)
}
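
The pool name is now decoded and a missing parameter becomes an IllegalArgumentException. A rough equivalent using java.net.URLDecoder in place of the UI helper:

```scala
import java.net.URLDecoder

// Stand-in for UIUtils.decodeURLParameter; URLDecoder illustrates the same
// decoding step without pulling in the Spark UI utilities.
def requirePoolName(raw: String): String =
  Option(raw)
    .filter(_.nonEmpty)
    .map(name => URLDecoder.decode(name, "UTF-8"))
    .getOrElse(throw new IllegalArgumentException("Missing poolname parameter"))

println(requirePoolName("my%20pool"))  // "my pool"
```
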
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 9ba2af54dacf..ea02968733ca 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ui.jobs
+import java.net.URLEncoder
+
import scala.collection.mutable.HashMap
import scala.xml.Node
@@ -59,7 +61,7 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) {
case None => 0
}
val href = "%s/stages/pool?poolname=%s"
- .format(UIUtils.prependBaseUri(parent.basePath), p.name)
+ .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8"))
{p.name}
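
Encoding the pool name keeps special characters from breaking the query string; a tiny illustration with a made-up pool name and base path:

```scala
import java.net.URLEncoder

// A pool name with characters that must not appear raw in a query string.
val poolName = "my pool & friends"
val base = "/stages/pool"

val href = "%s?poolname=%s".format(base, URLEncoder.encode(poolName, "UTF-8"))
println(href)  // /stages/pool?poolname=my+pool+%26+friends
```
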
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 51425e599e74..6b3dadc33331 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -26,12 +26,13 @@ import scala.xml.{Elem, Node, Unparsed}
import org.apache.commons.lang3.StringEscapeUtils
-import org.apache.spark.{InternalAccumulator, SparkConf}
+import org.apache.spark.SparkConf
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
+import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality}
import org.apache.spark.ui._
+import org.apache.spark.ui.exec.ExecutorsListener
import org.apache.spark.ui.jobs.UIData._
-import org.apache.spark.util.{Utils, Distribution}
+import org.apache.spark.util.{Distribution, Utils}
/** Page showing statistics and task list for a given stage */
private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
@@ -39,6 +40,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
private val progressListener = parent.progressListener
private val operationGraphListener = parent.operationGraphListener
+ private val executorsListener = parent.executorsListener
private val TIMELINE_LEGEND = {
@@ -68,29 +70,43 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
// if we find that it's okay.
private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000)
- private val displayPeakExecutionMemory = parent.conf.getBoolean("spark.sql.unsafe.enabled", true)
+ private def getLocalitySummaryString(stageData: StageUIData): String = {
+ val localities = stageData.taskData.values.map(_.taskInfo.taskLocality)
+ val localityCounts = localities.groupBy(identity).mapValues(_.size)
+ val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) =>
+ val localityName = locality match {
+ case TaskLocality.PROCESS_LOCAL => "Process local"
+ case TaskLocality.NODE_LOCAL => "Node local"
+ case TaskLocality.RACK_LOCAL => "Rack local"
+ case TaskLocality.ANY => "Any"
+ }
+ s"$localityName: $count"
+ }
+ localityNamesAndCounts.sorted.mkString("; ")
+ }
def render(request: HttpServletRequest): Seq[Node] = {
progressListener.synchronized {
- val parameterId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
- val parameterAttempt = request.getParameter("attempt")
+ val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt"))
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")
- val parameterTaskPage = request.getParameter("task.page")
- val parameterTaskSortColumn = request.getParameter("task.sort")
- val parameterTaskSortDesc = request.getParameter("task.desc")
- val parameterTaskPageSize = request.getParameter("task.pageSize")
+ val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page"))
+ val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort"))
+ val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc"))
+ val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize"))
+ val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize"))
val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
- val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index")
+ val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn =>
+ UIUtils.decodeURLParameter(sortColumn)
+ }.getOrElse("Index")
val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false)
val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100)
-
- // If this is set, expand the dag visualization by default
- val expandDagVizParam = request.getParameter("expandDagViz")
- val expandDagViz = expandDagVizParam != null && expandDagVizParam.toBoolean
+ val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize)
val stageId = parameterId.toInt
val stageAttemptId = parameterAttempt.toInt
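
getLocalitySummaryString boils down to grouping and counting locality levels. The same idea on a hypothetical list of localities:

```scala
import org.apache.spark.scheduler.TaskLocality

// Hypothetical per-task localities standing in for live StageUIData.
val localities = Seq(
  TaskLocality.PROCESS_LOCAL, TaskLocality.PROCESS_LOCAL,
  TaskLocality.NODE_LOCAL, TaskLocality.ANY)

val summary = localities
  .groupBy(identity)
  .mapValues(_.size)
  .toSeq
  .map { case (locality, count) => s"$locality: $count" }
  .sorted
  .mkString("; ")

println(summary)  // ANY: 1; NODE_LOCAL: 1; PROCESS_LOCAL: 2
```
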
@@ -116,11 +132,18 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val stageData = stageDataOption.get
val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime)
- val numCompleted = tasks.count(_.taskInfo.finished)
+ val numCompleted = stageData.numCompleteTasks
+ val totalTasks = stageData.numActiveTasks +
+ stageData.numCompleteTasks + stageData.numFailedTasks
+ val totalTasksNumStr = if (totalTasks == tasks.size) {
+ s"$totalTasks"
+ } else {
+ s"$totalTasks, showing ${tasks.size}"
+ }
val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables
val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal }
- val hasAccumulators = externalAccumulables.size > 0
+ val hasAccumulators = externalAccumulables.nonEmpty
val summary =
@@ -129,6 +152,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
Total Time Across All Tasks:
{UIUtils.formatDuration(stageData.executorRunTime)}
+ <li>
+ <strong>Locality Level Summary: </strong>
+ {getLocalitySummaryString(stageData)}
+ </li>
{if (stageData.hasInput) {
Input Size / Records:
@@ -224,15 +251,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
Getting Result Time
- {if (displayPeakExecutionMemory) {
-
-
-
- Peak Execution Memory
-
-
- }}
+
+
+
+ Peak Execution Memory
+
+
@@ -240,21 +265,27 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val dagViz = UIUtils.showDagVizForStage(
stageId, operationGraphListener.getOperationGraphForStage(stageId))
- val maybeExpandDagViz: Seq[Node] =
- if (expandDagViz) {
- UIUtils.expandDagVizOnLoad(forJob = false)
- } else {
- Seq.empty
- }
-
val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value")
- def accumulableRow(acc: AccumulableInfo): Elem =
- <tr><td>{acc.name}</td><td>{acc.value}</td></tr>
+ def accumulableRow(acc: AccumulableInfo): Seq[Node] = {
+ (acc.name, acc.value) match {
+ case (Some(name), Some(value)) => <tr><td>{name}</td><td>{value}</td></tr>
+ case _ => Seq.empty[Node]
+ }
+ }
val accumulableTable = UIUtils.listingTable(
accumulableHeaders,
accumulableRow,
externalAccumulables.toSeq)
+ val page: Int = {
+ // If the user has changed to a larger page size, then go to page 1 in order to avoid
+ // IndexOutOfBoundsException.
+ if (taskPageSize <= taskPrevPageSize) {
+ taskPage
+ } else {
+ 1
+ }
+ }
val currentTime = System.currentTimeMillis()
val (taskTable, taskTableHTML) = try {
val _taskTable = new TaskPagedTable(
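
The page-selection guard above can be read in isolation: keep the requested page only when the page size has not grown. A minimal sketch:

```scala
// Same guard as above: if the user enlarges the page size, jump back to
// page 1 so the requested slice cannot run past the end of the data.
def effectivePage(requestedPage: Int, pageSize: Int, prevPageSize: Int): Int =
  if (pageSize <= prevPageSize) requestedPage else 1

assert(effectivePage(requestedPage = 3, pageSize = 100, prevPageSize = 100) == 3)
assert(effectivePage(requestedPage = 3, pageSize = 500, prevPageSize = 100) == 1)
```
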
@@ -271,12 +302,20 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
currentTime,
pageSize = taskPageSize,
sortColumn = taskSortColumn,
- desc = taskSortDesc
+ desc = taskSortDesc,
+ executorsListener = executorsListener
)
- (_taskTable, _taskTable.table(taskPage))
+ (_taskTable, _taskTable.table(page))
} catch {
case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
- (null, <div class="alert alert-error">{e.getMessage}</div>)
+ val errorMessage =
+ <div class="alert alert-error">
+ <p>Error while rendering stage table:</p>
+ <pre>
+ {Utils.exceptionString(e)}
+ </pre>
+ </div>
+ (null, errorMessage)
}
val jsForScrollingDownToTaskTable =
@@ -298,10 +337,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
else taskTable.dataSource.slicedTaskIds
// Excludes tasks which failed and have incomplete metrics
- val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined)
+ val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.metrics.isDefined)
val summaryTable: Option[Seq[Node]] =
- if (validTasks.size == 0) {
+ if (validTasks.isEmpty) {
None
}
else {
@@ -316,8 +355,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)} )
}
- val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.executorDeserializeTime.toDouble
+ val deserializationTimes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.executorDeserializeTime.toDouble
}
val deserializationQuantiles =
@@ -327,13 +366,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+: getFormattedTimeQuantiles(deserializationTimes)
- val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.executorRunTime.toDouble
+ val serviceTimes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.executorRunTime.toDouble
}
val serviceQuantiles = Duration +: getFormattedTimeQuantiles(serviceTimes)
- val gcTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.jvmGCTime.toDouble
+ val gcTimes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.jvmGCTime.toDouble
}
val gcQuantiles =
@@ -342,8 +381,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+: getFormattedTimeQuantiles(gcTimes)
- val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.resultSerializationTime.toDouble
+ val serializationTimes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.resultSerializationTime.toDouble
}
val serializationQuantiles =
@@ -353,8 +392,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+: getFormattedTimeQuantiles(serializationTimes)
- val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) =>
- getGettingResultTime(info, currentTime).toDouble
+ val gettingResultTimes = validTasks.map { taskUIData: TaskUIData =>
+ getGettingResultTime(taskUIData.taskInfo, currentTime).toDouble
}
val gettingResultQuantiles =
@@ -365,12 +404,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+:
getFormattedTimeQuantiles(gettingResultTimes)
- val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) =>
- info.accumulables
- .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
- .map { acc => acc.update.getOrElse("0").toLong }
- .getOrElse(0L)
- .toDouble
+ val peakExecutionMemory = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.peakExecutionMemory.toDouble
}
val peakExecutionMemoryQuantiles = {
@@ -384,8 +419,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
- val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
- getSchedulerDelay(info, metrics.get, currentTime).toDouble
+ val schedulerDelays = validTasks.map { taskUIData: TaskUIData =>
+ getSchedulerDelay(taskUIData.taskInfo, taskUIData.metrics.get, currentTime).toDouble
}
val schedulerDelayTitle = Scheduler Delay
@@ -399,30 +434,30 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
)
}
- val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble
+ val inputSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.inputMetrics.bytesRead.toDouble
}
- val inputRecords = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble
+ val inputRecords = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.inputMetrics.recordsRead.toDouble
}
val inputQuantiles = Input Size / Records +:
getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords)
- val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble
+ val outputSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.outputMetrics.bytesWritten.toDouble
}
- val outputRecords = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble
+ val outputRecords = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.outputMetrics.recordsWritten.toDouble
}
val outputQuantiles = Output Size / Records +:
getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords)
- val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble
+ val shuffleReadBlockedTimes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleReadMetrics.fetchWaitTime.toDouble
}
val shuffleReadBlockedQuantiles =
@@ -433,11 +468,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+:
getFormattedTimeQuantiles(shuffleReadBlockedTimes)
- val shuffleReadTotalSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble
+ val shuffleReadTotalSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleReadMetrics.totalBytesRead.toDouble
}
- val shuffleReadTotalRecords = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble
+ val shuffleReadTotalRecords = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleReadMetrics.recordsRead.toDouble
}
val shuffleReadTotalQuantiles =
@@ -448,8 +483,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+:
getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords)
- val shuffleReadRemoteSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
+ val shuffleReadRemoteSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleReadMetrics.remoteBytesRead.toDouble
}
val shuffleReadRemoteQuantiles =
@@ -460,25 +495,25 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+:
getFormattedSizeQuantiles(shuffleReadRemoteSizes)
- val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble
+ val shuffleWriteSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleWriteMetrics.bytesWritten.toDouble
}
- val shuffleWriteRecords = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L).toDouble
+ val shuffleWriteRecords = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleWriteMetrics.recordsWritten.toDouble
}
val shuffleWriteQuantiles = Shuffle Write Size / Records +:
getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords)
- val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.memoryBytesSpilled.toDouble
+ val memoryBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.memoryBytesSpilled.toDouble
}
val memoryBytesSpilledQuantiles = Shuffle spill (memory) +:
getFormattedSizeQuantiles(memoryBytesSpilledSizes)
- val diskBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.diskBytesSpilled.toDouble
+ val diskBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.diskBytesSpilled.toDouble
}
val diskBytesSpilledQuantiles = Shuffle spill (disk) +:
getFormattedSizeQuantiles(diskBytesSpilledSizes)
@@ -494,13 +529,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
{serializationQuantiles}
,
{gettingResultQuantiles} ,
- if (displayPeakExecutionMemory) {
-
- {peakExecutionMemoryQuantiles}
-
- } else {
- Nil
- },
+
+ {peakExecutionMemoryQuantiles}
+ ,
if (stageData.hasInput) {inputQuantiles} else Nil,
if (stageData.hasOutput) {outputQuantiles} else Nil,
if (stageData.hasShuffleRead) {
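
The summary table reports 0th/25th/50th/75th/100th percentiles per metric. A rough stand-in for the quantile step (not the actual org.apache.spark.util.Distribution implementation):

```scala
// Minimal percentile helper over already-collected metric samples.
def quantiles(
    data: Seq[Double],
    probabilities: Seq[Double] = Seq(0, 0.25, 0.5, 0.75, 1.0)): Seq[Double] = {
  val sorted = data.sorted
  probabilities.map { p =>
    val idx = math.min((p * sorted.length).toInt, sorted.length - 1)
    sorted(idx)
  }
}

val serviceTimes = Seq(120.0, 80.0, 300.0, 95.0, 210.0)
println(quantiles(serviceTimes))  // List(80.0, 95.0, 120.0, 210.0, 300.0)
```
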
@@ -536,20 +567,32 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val maybeAccumulableTable: Seq[Node] =
if (hasAccumulators) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
+ val aggMetrics =
+
+
+
+ Aggregated Metrics by Executor
+
+
+
+ {executorTable.toNodeSeq}
+
+
val content =
summary ++
dagViz ++
- maybeExpandDagViz ++
showAdditionalMetrics ++
makeTimeline(
// Only show the tasks in the table
stageData.taskData.values.toSeq.filter(t => taskIdsInPage.contains(t.taskInfo.taskId)),
currentTime) ++
- Summary Metrics for {numCompleted} Completed Tasks
++
+ Summary Metrics for {numCompleted} Completed Tasks
++
{summaryTable.getOrElse("No tasks have reported metrics yet.")} ++
- Aggregated Metrics by Executor
++ executorTable.toNodeSeq ++
+ aggMetrics ++
maybeAccumulableTable ++
- Tasks
++ taskTableHTML ++ jsForScrollingDownToTaskTable
+ Tasks ({totalTasksNumStr})
++
+ taskTableHTML ++ jsForScrollingDownToTaskTable
UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true)
}
}
@@ -574,13 +617,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
def toProportion(time: Long) = time.toDouble / totalExecutionTime * 100
- val metricsOpt = taskUIData.taskMetrics
+ val metricsOpt = taskUIData.metrics
val shuffleReadTime =
- metricsOpt.flatMap(_.shuffleReadMetrics.map(_.fetchWaitTime)).getOrElse(0L)
+ metricsOpt.map(_.shuffleReadMetrics.fetchWaitTime).getOrElse(0L)
val shuffleReadTimeProportion = toProportion(shuffleReadTime)
val shuffleWriteTime =
- (metricsOpt.flatMap(_.shuffleWriteMetrics
- .map(_.shuffleWriteTime)).getOrElse(0L) / 1e6).toLong
+ (metricsOpt.map(_.shuffleWriteMetrics.writeTime).getOrElse(0L) / 1e6).toLong
val shuffleWriteTimeProportion = toProportion(shuffleWriteTime)
val serializationTime = metricsOpt.map(_.resultSerializationTime).getOrElse(0L)
@@ -602,9 +644,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
}
val executorComputingTime = executorRunTime - shuffleReadTime - shuffleWriteTime
val executorComputingTimeProportion =
- (100 - schedulerDelayProportion - shuffleReadTimeProportion -
+ math.max(100 - schedulerDelayProportion - shuffleReadTimeProportion -
shuffleWriteTimeProportion - serializationTimeProportion -
- deserializationTimeProportion - gettingResultTimeProportion)
+ deserializationTimeProportion - gettingResultTimeProportion, 0)
val schedulerDelayProportionPos = 0
val deserializationTimeProportionPos =
@@ -720,7 +762,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
++
}
@@ -741,11 +784,11 @@ private[ui] object StagePage {
}
private[ui] def getSchedulerDelay(
- info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = {
+ info: TaskInfo, metrics: TaskMetricsUIData, currentTime: Long): Long = {
if (info.finished) {
val totalExecutionTime = info.finishTime - info.launchTime
- val executorOverhead = (metrics.executorDeserializeTime +
- metrics.resultSerializationTime)
+ val executorOverhead = metrics.executorDeserializeTime +
+ metrics.resultSerializationTime
math.max(
0,
totalExecutionTime - metrics.executorRunTime - executorOverhead -
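
getSchedulerDelay subtracts everything the executor accounted for from the task's wall-clock time and clamps at zero. The same arithmetic on plain numbers, assuming the remaining term in the method is the getting-result time:

```scala
// Plain Longs stand in for the TaskInfo / TaskMetricsUIData fields.
def schedulerDelay(
    launchTime: Long,
    finishTime: Long,
    executorRunTime: Long,
    deserializeTime: Long,
    resultSerializationTime: Long,
    gettingResultTime: Long): Long = {
  val totalExecutionTime = finishTime - launchTime
  val executorOverhead = deserializeTime + resultSerializationTime
  math.max(0, totalExecutionTime - executorRunTime - executorOverhead - gettingResultTime)
}

// 1200 ms wall clock, 1000 ms running, 150 ms overhead, 20 ms fetching the
// result leaves 30 ms attributed to the scheduler.
assert(schedulerDelay(0L, 1200L, 1000L, 100L, 50L, 20L) == 30L)
```
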
@@ -808,7 +851,8 @@ private[ui] class TaskTableRowData(
val shuffleRead: Option[TaskTableRowShuffleReadData],
val shuffleWrite: Option[TaskTableRowShuffleWriteData],
val bytesSpilled: Option[TaskTableRowBytesSpilledData],
- val error: String)
+ val error: String,
+ val logs: Map[String, String])
private[ui] class TaskDataSource(
tasks: Seq[TaskUIData],
@@ -821,14 +865,15 @@ private[ui] class TaskDataSource(
currentTime: Long,
pageSize: Int,
sortColumn: String,
- desc: Boolean) extends PagedDataSource[TaskTableRowData](pageSize) {
+ desc: Boolean,
+ executorsListener: ExecutorsListener) extends PagedDataSource[TaskTableRowData](pageSize) {
import StagePage._
// Convert TaskUIData to TaskTableRowData which contains the final contents to show in the table
// so that we can avoid creating duplicate contents during sorting the data
private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc))
- private var _slicedTaskIds: Set[Long] = null
+ private var _slicedTaskIds: Set[Long] = _
override def dataSize: Int = data.size
@@ -841,42 +886,41 @@ private[ui] class TaskDataSource(
def slicedTaskIds: Set[Long] = _slicedTaskIds
private def taskRow(taskData: TaskUIData): TaskTableRowData = {
- val TaskUIData(info, metrics, errorMessage) = taskData
- val duration = if (info.status == "RUNNING") info.timeRunning(currentTime)
- else metrics.map(_.executorRunTime).getOrElse(1L)
- val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration)
- else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("")
+ val info = taskData.taskInfo
+ val metrics = taskData.metrics
+ val duration = taskData.taskDuration.getOrElse(1L)
+ val formatDuration = taskData.taskDuration.map(d => UIUtils.formatDuration(d)).getOrElse("")
val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L)
val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L)
val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L)
val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
val gettingResultTime = getGettingResultTime(info, currentTime)
- val (taskInternalAccumulables, taskExternalAccumulables) =
- info.accumulables.partition(_.internal)
- val externalAccumulableReadable = taskExternalAccumulables.map { acc =>
- StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")
- }
- val peakExecutionMemoryUsed = taskInternalAccumulables
- .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
- .map { acc => acc.update.getOrElse("0").toLong }
- .getOrElse(0L)
+ val externalAccumulableReadable = info.accumulables
+ .filterNot(_.internal)
+ .flatMap { a =>
+ (a.name, a.update) match {
+ case (Some(name), Some(update)) => Some(StringEscapeUtils.escapeHtml4(s"$name: $update"))
+ case _ => None
+ }
+ }
+ val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L)
- val maybeInput = metrics.flatMap(_.inputMetrics)
+ val maybeInput = metrics.map(_.inputMetrics)
val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L)
val inputReadable = maybeInput
- .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})")
+ .map(m => s"${Utils.bytesToString(m.bytesRead)}")
.getOrElse("")
val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("")
- val maybeOutput = metrics.flatMap(_.outputMetrics)
+ val maybeOutput = metrics.map(_.outputMetrics)
val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L)
val outputReadable = maybeOutput
.map(m => s"${Utils.bytesToString(m.bytesWritten)}")
.getOrElse("")
val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("")
- val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics)
+ val maybeShuffleRead = metrics.map(_.shuffleReadMetrics)
val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L)
val shuffleReadBlockedTimeReadable =
maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("")
@@ -890,14 +934,14 @@ private[ui] class TaskDataSource(
val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L)
val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("")
- val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics)
- val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten).getOrElse(0L)
+ val maybeShuffleWrite = metrics.map(_.shuffleWriteMetrics)
+ val shuffleWriteSortable = maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L)
val shuffleWriteReadable = maybeShuffleWrite
- .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("")
+ .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("")
val shuffleWriteRecords = maybeShuffleWrite
- .map(_.shuffleRecordsWritten.toString).getOrElse("")
+ .map(_.recordsWritten.toString).getOrElse("")
- val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime)
+ val maybeWriteTime = metrics.map(_.shuffleWriteMetrics.writeTime)
val writeTimeSortable = maybeWriteTime.getOrElse(0L)
val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms =>
if (ms == 0) "" else UIUtils.formatDuration(ms)
@@ -964,6 +1008,8 @@ private[ui] class TaskDataSource(
None
}
+ val logs = executorsListener.executorToTaskSummary.get(info.executorId)
+ .map(_.executorLogs).getOrElse(Map.empty)
new TaskTableRowData(
info.index,
info.taskId,
@@ -987,96 +1033,46 @@ private[ui] class TaskDataSource(
shuffleRead,
shuffleWrite,
bytesSpilled,
- errorMessage.getOrElse(""))
+ taskData.errorMessage.getOrElse(""),
+ logs)
}
/**
* Return Ordering according to sortColumn and desc
*/
private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = {
- val ordering = sortColumn match {
- case "Index" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Int.compare(x.index, y.index)
- }
- case "ID" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.taskId, y.taskId)
- }
- case "Attempt" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Int.compare(x.attempt, y.attempt)
- }
- case "Status" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.String.compare(x.status, y.status)
- }
- case "Locality Level" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.String.compare(x.taskLocality, y.taskLocality)
- }
- case "Executor ID / Host" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.String.compare(x.executorIdAndHost, y.executorIdAndHost)
- }
- case "Launch Time" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.launchTime, y.launchTime)
- }
- case "Duration" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.duration, y.duration)
- }
- case "Scheduler Delay" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.schedulerDelay, y.schedulerDelay)
- }
- case "Task Deserialization Time" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.taskDeserializationTime, y.taskDeserializationTime)
- }
- case "GC Time" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.gcTime, y.gcTime)
- }
- case "Result Serialization Time" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.serializationTime, y.serializationTime)
- }
- case "Getting Result Time" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime)
- }
- case "Peak Execution Memory" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed)
- }
+ val ordering: Ordering[TaskTableRowData] = sortColumn match {
+ case "Index" => Ordering.by(_.index)
+ case "ID" => Ordering.by(_.taskId)
+ case "Attempt" => Ordering.by(_.attempt)
+ case "Status" => Ordering.by(_.status)
+ case "Locality Level" => Ordering.by(_.taskLocality)
+ case "Executor ID / Host" => Ordering.by(_.executorIdAndHost)
+ case "Launch Time" => Ordering.by(_.launchTime)
+ case "Duration" => Ordering.by(_.duration)
+ case "Scheduler Delay" => Ordering.by(_.schedulerDelay)
+ case "Task Deserialization Time" => Ordering.by(_.taskDeserializationTime)
+ case "GC Time" => Ordering.by(_.gcTime)
+ case "Result Serialization Time" => Ordering.by(_.serializationTime)
+ case "Getting Result Time" => Ordering.by(_.gettingResultTime)
+ case "Peak Execution Memory" => Ordering.by(_.peakExecutionMemoryUsed)
case "Accumulators" =>
if (hasAccumulators) {
- new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.String.compare(x.accumulators.get, y.accumulators.get)
- }
+ Ordering.by(_.accumulators.get)
} else {
throw new IllegalArgumentException(
"Cannot sort by Accumulators because of no accumulators")
}
case "Input Size / Records" =>
if (hasInput) {
- new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.input.get.inputSortable, y.input.get.inputSortable)
- }
+ Ordering.by(_.input.get.inputSortable)
} else {
throw new IllegalArgumentException(
"Cannot sort by Input Size / Records because of no inputs")
}
case "Output Size / Records" =>
if (hasOutput) {
- new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.output.get.outputSortable, y.output.get.outputSortable)
- }
+ Ordering.by(_.output.get.outputSortable)
} else {
throw new IllegalArgumentException(
"Cannot sort by Output Size / Records because of no outputs")
@@ -1084,33 +1080,21 @@ private[ui] class TaskDataSource(
// ShuffleRead
case "Shuffle Read Blocked Time" =>
if (hasShuffleRead) {
- new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.shuffleRead.get.shuffleReadBlockedTimeSortable,
- y.shuffleRead.get.shuffleReadBlockedTimeSortable)
- }
+ Ordering.by(_.shuffleRead.get.shuffleReadBlockedTimeSortable)
} else {
throw new IllegalArgumentException(
"Cannot sort by Shuffle Read Blocked Time because of no shuffle reads")
}
case "Shuffle Read Size / Records" =>
if (hasShuffleRead) {
- new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.shuffleRead.get.shuffleReadSortable,
- y.shuffleRead.get.shuffleReadSortable)
- }
+ Ordering.by(_.shuffleRead.get.shuffleReadSortable)
} else {
throw new IllegalArgumentException(
"Cannot sort by Shuffle Read Size / Records because of no shuffle reads")
}
case "Shuffle Remote Reads" =>
if (hasShuffleRead) {
- new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.shuffleRead.get.shuffleReadRemoteSortable,
- y.shuffleRead.get.shuffleReadRemoteSortable)
- }
+ Ordering.by(_.shuffleRead.get.shuffleReadRemoteSortable)
} else {
throw new IllegalArgumentException(
"Cannot sort by Shuffle Remote Reads because of no shuffle reads")
@@ -1118,22 +1102,14 @@ private[ui] class TaskDataSource(
// ShuffleWrite
case "Write Time" =>
if (hasShuffleWrite) {
- new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.shuffleWrite.get.writeTimeSortable,
- y.shuffleWrite.get.writeTimeSortable)
- }
+ Ordering.by(_.shuffleWrite.get.writeTimeSortable)
} else {
throw new IllegalArgumentException(
"Cannot sort by Write Time because of no shuffle writes")
}
case "Shuffle Write Size / Records" =>
if (hasShuffleWrite) {
- new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.shuffleWrite.get.shuffleWriteSortable,
- y.shuffleWrite.get.shuffleWriteSortable)
- }
+ Ordering.by(_.shuffleWrite.get.shuffleWriteSortable)
} else {
throw new IllegalArgumentException(
"Cannot sort by Shuffle Write Size / Records because of no shuffle writes")
@@ -1141,30 +1117,19 @@ private[ui] class TaskDataSource(
// BytesSpilled
case "Shuffle Spill (Memory)" =>
if (hasBytesSpilled) {
- new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.bytesSpilled.get.memoryBytesSpilledSortable,
- y.bytesSpilled.get.memoryBytesSpilledSortable)
- }
+ Ordering.by(_.bytesSpilled.get.memoryBytesSpilledSortable)
} else {
throw new IllegalArgumentException(
"Cannot sort by Shuffle Spill (Memory) because of no spills")
}
case "Shuffle Spill (Disk)" =>
if (hasBytesSpilled) {
- new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.Long.compare(x.bytesSpilled.get.diskBytesSpilledSortable,
- y.bytesSpilled.get.diskBytesSpilledSortable)
- }
+ Ordering.by(_.bytesSpilled.get.diskBytesSpilledSortable)
} else {
throw new IllegalArgumentException(
"Cannot sort by Shuffle Spill (Disk) because of no spills")
}
- case "Errors" => new Ordering[TaskTableRowData] {
- override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
- Ordering.String.compare(x.error, y.error)
- }
+ case "Errors" => Ordering.by(_.error)
case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn")
}
if (desc) {
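
The sort-column refactor replaces anonymous Ordering classes with Ordering.by plus a reverse for descending order. The same pattern on a stand-in row type:

```scala
// Tiny stand-in for TaskTableRowData with three sortable columns.
case class Row(index: Int, duration: Long, error: String)

def ordering(sortColumn: String, desc: Boolean): Ordering[Row] = {
  val asc: Ordering[Row] = sortColumn match {
    case "Index"    => Ordering.by(_.index)
    case "Duration" => Ordering.by(_.duration)
    case "Errors"   => Ordering.by(_.error)
    case unknown    => throw new IllegalArgumentException(s"Unknown column: $unknown")
  }
  if (desc) asc.reverse else asc
}

val rows = Seq(Row(1, 30L, ""), Row(2, 10L, ""), Row(3, 20L, ""))
assert(rows.sorted(ordering("Duration", desc = true)).map(_.index) == Seq(1, 3, 2))
```
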
@@ -1189,14 +1154,19 @@ private[ui] class TaskPagedTable(
currentTime: Long,
pageSize: Int,
sortColumn: String,
- desc: Boolean) extends PagedTable[TaskTableRowData] {
-
- // We only track peak memory used for unsafe operators
- private val displayPeakExecutionMemory = conf.getBoolean("spark.sql.unsafe.enabled", true)
+ desc: Boolean,
+ executorsListener: ExecutorsListener) extends PagedTable[TaskTableRowData] {
override def tableId: String = "task-table"
- override def tableCssClass: String = "table table-bordered table-condensed table-striped"
+ override def tableCssClass: String =
+ "table table-bordered table-condensed table-striped table-head-clickable"
+
+ override def pageSizeFormField: String = "task.pageSize"
+
+ override def prevPageSizeFormField: String = "task.prevPageSize"
+
+ override def pageNumberFormField: String = "task.page"
override val dataSource: TaskDataSource = new TaskDataSource(
data,
@@ -1209,28 +1179,21 @@ private[ui] class TaskPagedTable(
currentTime,
pageSize,
sortColumn,
- desc)
+ desc,
+ executorsListener)
override def pageLink(page: Int): String = {
val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
- s"${basePath}&task.page=$page&task.sort=${encodedSortColumn}&task.desc=${desc}" +
- s"&task.pageSize=${pageSize}"
+ basePath +
+ s"&$pageNumberFormField=$page" +
+ s"&task.sort=$encodedSortColumn" +
+ s"&task.desc=$desc" +
+ s"&$pageSizeFormField=$pageSize"
}
- override def goButtonJavascriptFunction: (String, String) = {
- val jsFuncName = "goToTaskPage"
+ override def goButtonFormPath: String = {
val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
- val jsFunc = s"""
- |currentTaskPageSize = ${pageSize}
- |function goToTaskPage(page, pageSize) {
- | // Set page to 1 if the page size changes
- | page = pageSize == currentTaskPageSize ? page : 1;
- | var url = "${basePath}&task.sort=${encodedSortColumn}&task.desc=${desc}" +
- | "&task.page=" + page + "&task.pageSize=" + pageSize;
- | window.location.href = url;
- |}
- """.stripMargin
- (jsFuncName, jsFunc)
+ s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc"
}
def headers: Seq[Node] = {
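
Pagination links are now built by plain string concatenation with the sort column URL-encoded. A sketch with a placeholder base path and values:

```scala
import java.net.URLEncoder

// Placeholder request values; the real basePath comes from the page itself.
val basePath = "/stages/stage/?id=3&attempt=0"
val sortColumn = "Shuffle Read Size / Records"
val desc = true
val pageSize = 100
val page = 2

val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
val pageLink =
  basePath +
    s"&task.page=$page" +
    s"&task.sort=$encodedSortColumn" +
    s"&task.desc=$desc" +
    s"&task.pageSize=$pageSize"

println(pageLink)
```
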
@@ -1242,14 +1205,8 @@ private[ui] class TaskPagedTable(
("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME),
("GC Time", ""),
("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
- ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++
- {
- if (displayPeakExecutionMemory) {
- Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY))
- } else {
- Nil
- }
- } ++
+ ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME),
+ ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++
{if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
{if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++
{if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++
@@ -1279,21 +1236,27 @@ private[ui] class TaskPagedTable(
val headerRow: Seq[Node] = {
taskHeadersAndCssClasses.map { case (header, cssClass) =>
if (header == sortColumn) {
- val headerLink =
- s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.desc=${!desc}" +
- s"&task.pageSize=${pageSize}"
- val js = Unparsed(s"window.location.href='${headerLink}'")
+ val headerLink = Unparsed(
+ basePath +
+ s"&task.sort=${URLEncoder.encode(header, "UTF-8")}" +
+ s"&task.desc=${!desc}" +
+ s"&task.pageSize=$pageSize")
val arrow = if (desc) "▾" else "▴" // UP or DOWN
-
- {header}
- {Unparsed(arrow)}
+
+
+ {header}
+ {Unparsed(arrow)}
+
} else {
- val headerLink =
- s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.pageSize=${pageSize}"
- val js = Unparsed(s"window.location.href='${headerLink}'")
-
- {header}
+ val headerLink = Unparsed(
+ basePath +
+ s"&task.sort=${URLEncoder.encode(header, "UTF-8")}" +
+ s"&task.pageSize=$pageSize")
+
+
+ {header}
+
}
}
@@ -1308,7 +1271,16 @@ private[ui] class TaskPagedTable(
{if (task.speculative) s"${task.attempt} (speculative)" else task.attempt.toString}
{task.status}
{task.taskLocality}
- {task.executorIdAndHost}
+      <td>
+        <div style="float: left">{task.executorIdAndHost}</div>
+        <div style="float: right">
+        {
+          task.logs.map {
+            case (logName, logUrl) => <div><a href={logUrl}>{logName}</a></div>
+          }
+        }
+        </div>
+      </td>
{UIUtils.formatDate(new Date(task.launchTime))}
{task.formatDuration}
@@ -1326,11 +1298,9 @@ private[ui] class TaskPagedTable(
{UIUtils.formatDuration(task.gettingResultTime)}
- {if (displayPeakExecutionMemory) {
-
- {Utils.bytesToString(task.peakExecutionMemoryUsed)}
-
- }}
+
+ {Utils.bytesToString(task.peakExecutionMemoryUsed)}
+
{if (task.accumulators.nonEmpty) {
{Unparsed(task.accumulators.get)}
}}
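One more point about the TaskPagedTable changes above: the per-table `goButtonJavascriptFunction`, which assembled a navigation URL in inline JavaScript, is replaced by declarative form fields (`pageNumberFormField`, `pageSizeFormField`, `prevPageSizeFormField`) plus a `goButtonFormPath` that only carries the sort state. A rough sketch of how such a link/form-path pair can be assembled; the class and method names below are illustrative, not the actual PagedTable contract:

    import java.net.URLEncoder

    // Sketch: build pagination links the way the new pageLink/goButtonFormPath
    // overrides do, appending page-specific parameters as plain query fields.
    class TaskTableLinks(basePath: String, sortColumn: String, desc: Boolean, pageSize: Int) {
      private val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")

      def pageLink(page: Int): String =
        basePath +
          s"&task.page=$page" +
          s"&task.sort=$encodedSortColumn" +
          s"&task.desc=$desc" +
          s"&task.pageSize=$pageSize"

      // The "Go" button submits a form; only the non-form parameters
      // (sort column and direction) are baked into the action path.
      def goButtonFormPath: String =
        s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc"
    }

    object TaskTableLinksExample extends App {
      val links = new TaskTableLinks("/stages/stage/?id=1&attempt=0", "Duration", desc = true, 100)
      println(links.pageLink(2))
      println(links.goButtonFormPath)
    }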
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index ea806d09b600..741f95ae2642 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -17,61 +17,336 @@
package org.apache.spark.ui.jobs
+import java.net.URLEncoder
import java.util.Date
+import javax.servlet.http.HttpServletRequest
-import scala.xml.{Node, Text}
+import scala.collection.JavaConverters._
+import scala.xml._
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.ui.{ToolTips, UIUtils}
+import org.apache.spark.ui._
+import org.apache.spark.ui.jobs.UIData.StageUIData
import org.apache.spark.util.Utils
-/** Page showing list of all ongoing and recently finished stages */
private[ui] class StageTableBase(
+ request: HttpServletRequest,
stages: Seq[StageInfo],
+ tableHeaderID: String,
+ stageTag: String,
basePath: String,
+ subPath: String,
+ progressListener: JobProgressListener,
+ isFairScheduler: Boolean,
+ killEnabled: Boolean,
+ isFailedStage: Boolean) {
+ // stripXSS is called to remove suspicious characters used in XSS attacks
+ val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) =>
+ UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq
+ }
+ val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag))
+ .map(para => para._1 + "=" + para._2(0))
+
+ val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page"))
+ val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort"))
+ val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc"))
+ val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize"))
+ val parameterStagePrevPageSize =
+ UIUtils.stripXSS(request.getParameter(stageTag + ".prevPageSize"))
+
+ val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1)
+ val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn =>
+ UIUtils.decodeURLParameter(sortColumn)
+ }.getOrElse("Stage Id")
+ val stageSortDesc = Option(parameterStageSortDesc).map(_.toBoolean).getOrElse(
+ // New stages should be shown above old stages by default.
+ stageSortColumn == "Stage Id"
+ )
+ val stagePageSize = Option(parameterStagePageSize).map(_.toInt).getOrElse(100)
+ val stagePrevPageSize = Option(parameterStagePrevPageSize).map(_.toInt)
+ .getOrElse(stagePageSize)
+
+ val page: Int = {
+ // If the user has changed to a larger page size, then go to page 1 in order to avoid
+ // IndexOutOfBoundsException.
+ if (stagePageSize <= stagePrevPageSize) {
+ stagePage
+ } else {
+ 1
+ }
+ }
+ val currentTime = System.currentTimeMillis()
+
+ val toNodeSeq = try {
+ new StagePagedTable(
+ stages,
+ tableHeaderID,
+ stageTag,
+ basePath,
+ subPath,
+ progressListener,
+ isFairScheduler,
+ killEnabled,
+ currentTime,
+ stagePageSize,
+ stageSortColumn,
+ stageSortDesc,
+ isFailedStage,
+ parameterOtherTable
+ ).table(page)
+ } catch {
+ case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
+        <div class="alert alert-error">
+          <p>Error while rendering stage table:</p>
+          <pre>
+            {Utils.exceptionString(e)}
+          </pre>
+        </div>
+ }
+}
+
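The constructor logic above does three things: sanitize every incoming query parameter with `UIUtils.stripXSS`, fall back to sensible defaults (page 1, sort by "Stage Id" descending, page size 100), and reset to page 1 whenever the page size grows so the slice cannot run past the data. A standalone sketch of the same parameter handling, with a plain `Map[String, String]` standing in for the `HttpServletRequest` and a toy sanitizer standing in for the real `stripXSS` (both are assumptions for illustration):

    // Sketch only: a Map replaces HttpServletRequest and a trivial character
    // filter replaces UIUtils.stripXSS.
    object StageTableParams {
      private def stripXSS(v: String): String = v.replaceAll("[<>\"'%;()&+]", "")

      def parse(params: Map[String, String], stageTag: String): (Int, String, Boolean, Int) = {
        def get(name: String): Option[String] = params.get(s"$stageTag.$name").map(stripXSS)

        val sortColumn = get("sort").getOrElse("Stage Id")
        // New stages are shown first by default, so Stage Id sorts descending.
        val sortDesc = get("desc").map(_.toBoolean).getOrElse(sortColumn == "Stage Id")
        val pageSize = get("pageSize").map(_.toInt).getOrElse(100)
        val prevPageSize = get("prevPageSize").map(_.toInt).getOrElse(pageSize)
        // If the page size grew, jump back to page 1 to avoid slicing past the end.
        val page = if (pageSize <= prevPageSize) get("page").map(_.toInt).getOrElse(1) else 1
        (page, sortColumn, sortDesc, pageSize)
      }
    }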
+private[ui] class StageTableRowData(
+ val stageInfo: StageInfo,
+ val stageData: Option[StageUIData],
+ val stageId: Int,
+ val attemptId: Int,
+ val schedulingPool: String,
+ val descriptionOption: Option[String],
+ val submissionTime: Long,
+ val formattedSubmissionTime: String,
+ val duration: Long,
+ val formattedDuration: String,
+ val inputRead: Long,
+ val inputReadWithUnit: String,
+ val outputWrite: Long,
+ val outputWriteWithUnit: String,
+ val shuffleRead: Long,
+ val shuffleReadWithUnit: String,
+ val shuffleWrite: Long,
+ val shuffleWriteWithUnit: String)
+
+private[ui] class MissingStageTableRowData(
+ stageInfo: StageInfo,
+ stageId: Int,
+ attemptId: Int) extends StageTableRowData(
+ stageInfo, None, stageId, attemptId, "", None, 0, "", -1, "", 0, "", 0, "", 0, "", 0, "")
+
+/** Page showing list of all ongoing and recently finished stages */
+private[ui] class StagePagedTable(
+ stages: Seq[StageInfo],
+ tableHeaderId: String,
+ stageTag: String,
+ basePath: String,
+ subPath: String,
listener: JobProgressListener,
isFairScheduler: Boolean,
- killEnabled: Boolean) {
-
- protected def columns: Seq[Node] = {
- Stage Id ++
- {if (isFairScheduler) {Pool Name } else Seq.empty} ++
- Description
- Submitted
- Duration
- Tasks: Succeeded/Total
- Input
- Output
- Shuffle Read
-
-
-
- Shuffle Write
-
-
+ killEnabled: Boolean,
+ currentTime: Long,
+ pageSize: Int,
+ sortColumn: String,
+ desc: Boolean,
+ isFailedStage: Boolean,
+ parameterOtherTable: Iterable[String]) extends PagedTable[StageTableRowData] {
+
+ override def tableId: String = stageTag + "-table"
+
+ override def tableCssClass: String =
+ "table table-bordered table-condensed table-striped " +
+ "table-head-clickable table-cell-width-limited"
+
+ override def pageSizeFormField: String = stageTag + ".pageSize"
+
+ override def prevPageSizeFormField: String = stageTag + ".prevPageSize"
+
+ override def pageNumberFormField: String = stageTag + ".page"
+
+ val parameterPath = UIUtils.prependBaseUri(basePath) + s"/$subPath/?" +
+ parameterOtherTable.mkString("&")
+
+ override val dataSource = new StageDataSource(
+ stages,
+ listener,
+ currentTime,
+ pageSize,
+ sortColumn,
+ desc
+ )
+
+ override def pageLink(page: Int): String = {
+ val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
+ parameterPath +
+ s"&$pageNumberFormField=$page" +
+ s"&$stageTag.sort=$encodedSortColumn" +
+ s"&$stageTag.desc=$desc" +
+ s"&$pageSizeFormField=$pageSize" +
+ s"#$tableHeaderId"
+ }
+
+ override def goButtonFormPath: String = {
+ val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
+ s"$parameterPath&$stageTag.sort=$encodedSortColumn&$stageTag.desc=$desc#$tableHeaderId"
}
- def toNodeSeq: Seq[Node] = {
- listener.synchronized {
- stageTable(renderStageRow, stages)
+ override def headers: Seq[Node] = {
+ // stageHeadersAndCssClasses has three parts: header title, tooltip information, and sortable.
+ // The tooltip information could be None, which indicates it does not have a tooltip.
+ // Otherwise, it has two parts: tooltip text, and position (true for left, false for default).
+ val stageHeadersAndCssClasses: Seq[(String, Option[(String, Boolean)], Boolean)] =
+ Seq(("Stage Id", None, true)) ++
+ {if (isFairScheduler) {Seq(("Pool Name", None, true))} else Seq.empty} ++
+ Seq(
+ ("Description", None, true), ("Submitted", None, true), ("Duration", None, true),
+ ("Tasks: Succeeded/Total", None, false),
+ ("Input", Some((ToolTips.INPUT, false)), true),
+ ("Output", Some((ToolTips.OUTPUT, false)), true),
+ ("Shuffle Read", Some((ToolTips.SHUFFLE_READ, false)), true),
+ ("Shuffle Write", Some((ToolTips.SHUFFLE_WRITE, true)), true)
+ ) ++
+ {if (isFailedStage) {Seq(("Failure Reason", None, false))} else Seq.empty}
+
+ if (!stageHeadersAndCssClasses.filter(_._3).map(_._1).contains(sortColumn)) {
+ throw new IllegalArgumentException(s"Unknown column: $sortColumn")
+ }
+
+ val headerRow: Seq[Node] = {
+ stageHeadersAndCssClasses.map { case (header, tooltip, sortable) =>
+ val headerSpan = tooltip.map { case (title, left) =>
+ if (left) {
+ /* Place the shuffle write tooltip on the left (rather than the default position
+ of on top) because the shuffle write column is the last column on the right side and
+ the tooltip is wider than the column, so it doesn't fit on top. */
+
+ {header}
+
+ } else {
+
+ {header}
+
+ }
+ }.getOrElse(
+ {header}
+ )
+
+ if (header == sortColumn) {
+ val headerLink = Unparsed(
+ parameterPath +
+ s"&$stageTag.sort=${URLEncoder.encode(header, "UTF-8")}" +
+ s"&$stageTag.desc=${!desc}" +
+ s"&$stageTag.pageSize=$pageSize") +
+ s"#$tableHeaderId"
+ val arrow = if (desc) "▾" else "▴" // UP or DOWN
+
+
+
+ {headerSpan}
+ {Unparsed(arrow)}
+
+
+
+ } else {
+ if (sortable) {
+ val headerLink = Unparsed(
+ parameterPath +
+ s"&$stageTag.sort=${URLEncoder.encode(header, "UTF-8")}" +
+ s"&$stageTag.pageSize=$pageSize") +
+ s"#$tableHeaderId"
+
+
+
+ {headerSpan}
+
+
+ } else {
+
+ {headerSpan}
+
+ }
+ }
+ }
+ }
+ {headerRow}
+ }
+
+ override def row(data: StageTableRowData): Seq[Node] = {
+
+ {rowContent(data)}
+
+ }
+
+ private def rowContent(data: StageTableRowData): Seq[Node] = {
+ data.stageData match {
+ case None => missingStageRow(data.stageId)
+ case Some(stageData) =>
+ val info = data.stageInfo
+
+ {if (data.attemptId > 0) {
+ {data.stageId} (retry {data.attemptId})
+ } else {
+ {data.stageId}
+ }} ++
+ {if (isFairScheduler) {
+
+
+ {data.schedulingPool}
+
+
+ } else {
+ Seq.empty
+ }} ++
+ {makeDescription(info, data.descriptionOption)}
+
+ {data.formattedSubmissionTime}
+
+ {data.formattedDuration}
+
+ {UIUtils.makeProgressBar(started = stageData.numActiveTasks,
+ completed = stageData.completedIndices.size, failed = stageData.numFailedTasks,
+ skipped = 0, reasonToNumKilled = stageData.reasonToNumKilled, total = info.numTasks)}
+
+ {data.inputReadWithUnit}
+ {data.outputWriteWithUnit}
+ {data.shuffleReadWithUnit}
+ {data.shuffleWriteWithUnit} ++
+ {
+ if (isFailedStage) {
+ failureReasonHtml(info)
+ } else {
+ Seq.empty
+ }
+ }
}
}
- /** Special table that merges two header cells. */
- protected def stageTable[T](makeRow: T => Seq[Node], rows: Seq[T]): Seq[Node] = {
-
- {columns}
-
- {rows.map(r => makeRow(r))}
-
-
+ private def failureReasonHtml(s: StageInfo): Seq[Node] = {
+ val failureReason = s.failureReason.getOrElse("")
+ val isMultiline = failureReason.indexOf('\n') >= 0
+ // Display the first line by default
+ val failureReasonSummary = StringEscapeUtils.escapeHtml4(
+ if (isMultiline) {
+ failureReason.substring(0, failureReason.indexOf('\n'))
+ } else {
+ failureReason
+ })
+ val details = if (isMultiline) {
+ // scalastyle:off
+ ++
+
+ {failureReason}
+
+ // scalastyle:on
+ } else {
+ ""
+ }
+ {failureReasonSummary}{details}
}
- private def makeDescription(s: StageInfo): Seq[Node] = {
+ private def makeDescription(s: StageInfo, descriptionOption: Option[String]): Seq[Node] = {
val basePathUri = UIUtils.prependBaseUri(basePath)
val killLink = if (killEnabled) {
@@ -83,12 +358,13 @@ private[ui] class StageTableBase(
val killLinkUri = s"$basePathUri/stages/stage/kill/"
*/
- val killLinkUri = s"$basePathUri/stages/stage/kill/?id=${s.stageId}&terminate=true"
+ val killLinkUri = s"$basePathUri/stages/stage/kill/?id=${s.stageId}"
(kill)
+ } else {
+ Seq.empty
}
val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}"
@@ -111,12 +387,7 @@ private[ui] class StageTableBase(
}
- val stageDesc = for {
- stageData <- listener.stageIdToData.get((s.stageId, s.attemptId))
- desc <- stageData.description
- } yield {
- UIUtils.makeDescription(desc, basePathUri)
- }
+ val stageDesc = descriptionOption.map(UIUtils.makeDescription(_, basePathUri))
{stageDesc.getOrElse("")} {killLink} {nameLink} {details}
}
@@ -132,22 +403,60 @@ private[ui] class StageTableBase(
++ // Shuffle Read
// Shuffle Write
}
+}
+
+private[ui] class StageDataSource(
+ stages: Seq[StageInfo],
+ listener: JobProgressListener,
+ currentTime: Long,
+ pageSize: Int,
+ sortColumn: String,
+ desc: Boolean) extends PagedDataSource[StageTableRowData](pageSize) {
+ // Convert StageInfo to StageTableRowData which contains the final contents to show in the table
+ // so that we can avoid creating duplicate contents during sorting the data
+ private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc))
- protected def stageRow(s: StageInfo): Seq[Node] = {
+ private var _slicedStageIds: Set[Int] = _
+
+ override def dataSize: Int = data.size
+
+ override def sliceData(from: Int, to: Int): Seq[StageTableRowData] = {
+ val r = data.slice(from, to)
+ _slicedStageIds = r.map(_.stageId).toSet
+ r
+ }
+
+ private def stageRow(s: StageInfo): StageTableRowData = {
val stageDataOption = listener.stageIdToData.get((s.stageId, s.attemptId))
+
if (stageDataOption.isEmpty) {
- return missingStageRow(s.stageId)
+ return new MissingStageTableRowData(s, s.stageId, s.attemptId)
}
-
val stageData = stageDataOption.get
- val submissionTime = s.submissionTime match {
+
+ val description = stageData.description
+
+ val formattedSubmissionTime = s.submissionTime match {
case Some(t) => UIUtils.formatDate(new Date(t))
case None => "Unknown"
}
- val finishTime = s.completionTime.getOrElse(System.currentTimeMillis)
- val duration = s.submissionTime.map { t =>
- if (finishTime > t) finishTime - t else System.currentTimeMillis - t
- }
+ val finishTime = s.completionTime.getOrElse(currentTime)
+
+ // The submission time for a stage is misleading because it counts the time
+ // the stage waits to be launched. (SPARK-10930)
+ val taskLaunchTimes =
+ stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0)
+ val duration: Option[Long] =
+ if (taskLaunchTimes.nonEmpty) {
+ val startTime = taskLaunchTimes.min
+ if (finishTime > startTime) {
+ Some(finishTime - startTime)
+ } else {
+ Some(currentTime - startTime)
+ }
+ } else {
+ None
+ }
val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown")
val inputRead = stageData.inputBytes
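The duration computation above (SPARK-10930) is the interesting part: instead of measuring from the stage's submission time, which includes time spent waiting in the scheduler, it measures from the earliest task launch time. A hedged, standalone sketch of that calculation, with simplified inputs standing in for the listener data:

    // Sketch of the SPARK-10930 rule: measure from the first task launch,
    // not from stage submission, so scheduler wait time is excluded.
    object StageDuration {
      def duration(taskLaunchTimes: Seq[Long],
                   completionTime: Option[Long],
                   currentTime: Long): Option[Long] = {
        val launched = taskLaunchTimes.filter(_ > 0)
        if (launched.isEmpty) {
          None // no task has started yet, so there is nothing to measure
        } else {
          val startTime = launched.min
          val finishTime = completionTime.getOrElse(currentTime)
          // If the recorded finish time predates the first launch, fall back
          // to "elapsed so far".
          Some(if (finishTime > startTime) finishTime - startTime else currentTime - startTime)
        }
      }
    }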
@@ -159,76 +468,51 @@ private[ui] class StageTableBase(
val shuffleWrite = stageData.shuffleWriteBytes
val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else ""
- {if (s.attemptId > 0) {
- {s.stageId} (retry {s.attemptId})
- } else {
- {s.stageId}
- }} ++
- {if (isFairScheduler) {
-
-
- {stageData.schedulingPool}
-
-
- } else {
- Seq.empty
- }} ++
- {makeDescription(s)}
-
- {submissionTime}
-
- {formattedDuration}
-
- {UIUtils.makeProgressBar(started = stageData.numActiveTasks,
- completed = stageData.completedIndices.size, failed = stageData.numFailedTasks,
- skipped = 0, total = s.numTasks)}
-
- {inputReadWithUnit}
- {outputWriteWithUnit}
- {shuffleReadWithUnit}
- {shuffleWriteWithUnit}
- }
-
- /** Render an HTML row that represents a stage */
- private def renderStageRow(s: StageInfo): Seq[Node] =
- {stageRow(s)}
-}
-
-private[ui] class FailedStageTable(
- stages: Seq[StageInfo],
- basePath: String,
- listener: JobProgressListener,
- isFairScheduler: Boolean)
- extends StageTableBase(stages, basePath, listener, isFairScheduler, killEnabled = false) {
- override protected def columns: Seq[Node] = super.columns ++ Failure Reason
+ new StageTableRowData(
+ s,
+ stageDataOption,
+ s.stageId,
+ s.attemptId,
+ stageData.schedulingPool,
+ description,
+ s.submissionTime.getOrElse(0),
+ formattedSubmissionTime,
+ duration.getOrElse(-1),
+ formattedDuration,
+ inputRead,
+ inputReadWithUnit,
+ outputWrite,
+ outputWriteWithUnit,
+ shuffleRead,
+ shuffleReadWithUnit,
+ shuffleWrite,
+ shuffleWriteWithUnit
+ )
+ }
- override protected def stageRow(s: StageInfo): Seq[Node] = {
- val basicColumns = super.stageRow(s)
- val failureReason = s.failureReason.getOrElse("")
- val isMultiline = failureReason.indexOf('\n') >= 0
- // Display the first line by default
- val failureReasonSummary = StringEscapeUtils.escapeHtml4(
- if (isMultiline) {
- failureReason.substring(0, failureReason.indexOf('\n'))
- } else {
- failureReason
- })
- val details = if (isMultiline) {
- // scalastyle:off
- ++
-
- {failureReason}
-
- // scalastyle:on
+ /**
+ * Return Ordering according to sortColumn and desc
+ */
+ private def ordering(sortColumn: String, desc: Boolean): Ordering[StageTableRowData] = {
+ val ordering: Ordering[StageTableRowData] = sortColumn match {
+ case "Stage Id" => Ordering.by(_.stageId)
+ case "Pool Name" => Ordering.by(_.schedulingPool)
+ case "Description" => Ordering.by(x => (x.descriptionOption, x.stageInfo.name))
+ case "Submitted" => Ordering.by(_.submissionTime)
+ case "Duration" => Ordering.by(_.duration)
+ case "Input" => Ordering.by(_.inputRead)
+ case "Output" => Ordering.by(_.outputWrite)
+ case "Shuffle Read" => Ordering.by(_.shuffleRead)
+ case "Shuffle Write" => Ordering.by(_.shuffleWrite)
+ case "Tasks: Succeeded/Total" =>
+ throw new IllegalArgumentException(s"Unsortable column: $sortColumn")
+ case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn")
+ }
+ if (desc) {
+ ordering.reverse
} else {
- ""
+ ordering
}
- val failureReasonHtml = {failureReasonSummary}{details}
- basicColumns ++ failureReasonHtml
}
}
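A broader pattern in the new StageDataSource above: each StageInfo is converted to a StageTableRowData once, up front, and the paged table then only sorts and slices those precomputed rows rather than re-deriving formatted strings on every comparison. A compact sketch of that precompute, sort, and slice flow; the class below is a simplified stand-in for PagedDataSource, not its real shape:

    // Rows are materialized once, sorted with an Ordering chosen elsewhere by
    // column name, and only the requested page is sliced out.
    class SimplePagedSource[T](rows: Seq[T], pageSize: Int, ordering: Ordering[T]) {
      private val sorted = rows.sorted(ordering)

      def dataSize: Int = sorted.size

      def pageData(page: Int): Seq[T] = {
        require(page >= 1, s"Page must be >= 1, got $page")
        val from = (page - 1) * pageSize
        if (from >= sorted.size) {
          throw new IndexOutOfBoundsException(s"Page $page is out of range")
        }
        sorted.slice(from, from + pageSize)
      }
    }

    object SimplePagedSourceExample extends App {
      case class StageRow(stageId: Int, duration: Long)
      val rows = Seq(StageRow(1, 30L), StageRow(2, 10L), StageRow(3, 20L))
      val byDurationDesc = Ordering.by[StageRow, Long](_.duration).reverse
      val source = new SimplePagedSource(rows, 2, byDurationDesc)
      println(source.pageData(1)) // the two longest stages
    }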
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
index 5989f0035b27..799d76962639 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest
import org.apache.spark.scheduler.SchedulingMode
-import org.apache.spark.ui.{SparkUI, SparkUITab}
+import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}
/** Web UI showing progress status of all stages in the given SparkContext. */
private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
@@ -29,24 +29,27 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages"
val killEnabled = parent.killEnabled
val progressListener = parent.jobProgressListener
val operationGraphListener = parent.operationGraphListener
+ val executorsListener = parent.executorsListener
attachPage(new AllStagesPage(this))
attachPage(new StagePage(this))
attachPage(new PoolPage(this))
- def isFairScheduler: Boolean = progressListener.schedulingMode.exists(_ == SchedulingMode.FAIR)
+ def isFairScheduler: Boolean = progressListener.schedulingMode == Some(SchedulingMode.FAIR)
def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
- val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
- val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt
- if (stageId >= 0 && killFlag && progressListener.activeStages.contains(stageId)) {
- sc.get.cancelStage(stageId)
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
+ stageId.foreach { id =>
+ if (progressListener.activeStages.contains(id)) {
+ sc.foreach(_.cancelStage(id, "killed via the Web UI"))
+ // Do a quick pause here to give Spark time to kill the stage so it shows up as
+ // killed after the refresh. Note that this will block the serving thread so the
+ // time should be limited in duration.
+ Thread.sleep(100)
+ }
}
- // Do a quick pause here to give Spark time to kill the stage so it shows up as
- // killed after the refresh. Note that this will block the serving thread so the
- // time should be limited in duration.
- Thread.sleep(100)
}
}
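The handleKillRequest change above tightens the flow: the `terminate` flag is gone, the `id` parameter is sanitized and parsed as an Option, and the short sleep is only taken when a kill was actually issued. A hedged sketch of the same control flow against a generic cancel callback (this is not the real SparkContext or servlet API surface, just the shape of the logic):

    // Sanitize the id, act only when the stage is still active, and pause
    // only when a cancellation was actually issued.
    object KillRequestFlow {
      def handle(rawId: Option[String],
                 activeStages: Set[Int],
                 cancelStage: Int => Unit): Unit = {
        val stageId = rawId.map(_.trim).filter(_.nonEmpty).map(_.toInt)
        stageId.foreach { id =>
          if (activeStages.contains(id)) {
            cancelStage(id)
            // Give the scheduler a moment to mark the stage as killed so the
            // refreshed page reflects it; keep the pause short since it blocks
            // the serving thread.
            Thread.sleep(100)
          }
        }
      }
    }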
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index f008d4018061..25aa5042e0e0 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -17,20 +17,24 @@
package org.apache.spark.ui.jobs
+import scala.collection.mutable
+import scala.collection.mutable.{HashMap, LinkedHashMap}
+
+import com.google.common.collect.Interners
+
import org.apache.spark.JobExecutionStatus
-import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.executor._
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
+import org.apache.spark.util.AccumulatorContext
import org.apache.spark.util.collection.OpenHashSet
-import scala.collection.mutable
-import scala.collection.mutable.HashMap
-
private[spark] object UIData {
class ExecutorSummary {
var taskTime : Long = 0
var failedTasks : Int = 0
var succeededTasks : Int = 0
+ var reasonToNumKilled : Map[String, Int] = Map.empty
var inputBytes : Long = 0
var inputRecords : Long = 0
var outputBytes : Long = 0
@@ -41,6 +45,7 @@ private[spark] object UIData {
var shuffleWriteRecords : Long = 0
var memoryBytesSpilled : Long = 0
var diskBytesSpilled : Long = 0
+ var isBlacklisted : Int = 0
}
class JobUIData(
@@ -61,6 +66,7 @@ private[spark] object UIData {
var numCompletedTasks: Int = 0,
var numSkippedTasks: Int = 0,
var numFailedTasks: Int = 0,
+ var reasonToNumKilled: Map[String, Int] = Map.empty,
/* Stages */
var numActiveStages: Int = 0,
// This needs to be a set instead of a simple count to prevent double-counting of rerun stages:
@@ -74,8 +80,10 @@ private[spark] object UIData {
var numCompleteTasks: Int = _
var completedIndices = new OpenHashSet[Int]()
var numFailedTasks: Int = _
+ var reasonToNumKilled: Map[String, Int] = Map.empty
var executorRunTime: Long = _
+ var executorCpuTime: Long = _
var inputBytes: Long = _
var inputRecords: Long = _
@@ -87,31 +95,213 @@ private[spark] object UIData {
var shuffleWriteRecords: Long = _
var memoryBytesSpilled: Long = _
var diskBytesSpilled: Long = _
+ var isBlacklisted: Int = _
var schedulingPool: String = ""
var description: Option[String] = None
var accumulables = new HashMap[Long, AccumulableInfo]
- var taskData = new HashMap[Long, TaskUIData]
+ var taskData = new LinkedHashMap[Long, TaskUIData]
var executorSummary = new HashMap[String, ExecutorSummary]
def hasInput: Boolean = inputBytes > 0
def hasOutput: Boolean = outputBytes > 0
def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0
def hasShuffleWrite: Boolean = shuffleWriteBytes > 0
- def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 && diskBytesSpilled > 0
+ def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 || diskBytesSpilled > 0
}
/**
* These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation.
*/
- case class TaskUIData(
- var taskInfo: TaskInfo,
- var taskMetrics: Option[TaskMetrics] = None,
- var errorMessage: Option[String] = None)
-
- case class ExecutorUIData(
- val startTime: Long,
- var finishTime: Option[Long] = None,
- var finishReason: Option[String] = None)
+ class TaskUIData private(private var _taskInfo: TaskInfo) {
+
+ private[this] var _metrics: Option[TaskMetricsUIData] = Some(TaskMetricsUIData.EMPTY)
+
+ var errorMessage: Option[String] = None
+
+ def taskInfo: TaskInfo = _taskInfo
+
+ def metrics: Option[TaskMetricsUIData] = _metrics
+
+ def updateTaskInfo(taskInfo: TaskInfo): Unit = {
+ _taskInfo = TaskUIData.dropInternalAndSQLAccumulables(taskInfo)
+ }
+
+ def updateTaskMetrics(metrics: Option[TaskMetrics]): Unit = {
+ _metrics = metrics.map(TaskMetricsUIData.fromTaskMetrics)
+ }
+
+ def taskDuration: Option[Long] = {
+ if (taskInfo.status == "RUNNING") {
+ Some(_taskInfo.timeRunning(System.currentTimeMillis))
+ } else {
+ _metrics.map(_.executorRunTime)
+ }
+ }
+ }
+
+ object TaskUIData {
+
+ private val stringInterner = Interners.newWeakInterner[String]()
+
+ /** String interning to reduce the memory usage. */
+ private def weakIntern(s: String): String = {
+ stringInterner.intern(s)
+ }
+
+ def apply(taskInfo: TaskInfo): TaskUIData = {
+ new TaskUIData(dropInternalAndSQLAccumulables(taskInfo))
+ }
+
+ /**
+ * We don't need to store internal or SQL accumulables because their values will be shown in
+ * other places, so drop them to reduce memory usage.
+ */
+ private[spark] def dropInternalAndSQLAccumulables(taskInfo: TaskInfo): TaskInfo = {
+ val newTaskInfo = new TaskInfo(
+ taskId = taskInfo.taskId,
+ index = taskInfo.index,
+ attemptNumber = taskInfo.attemptNumber,
+ launchTime = taskInfo.launchTime,
+ executorId = weakIntern(taskInfo.executorId),
+ host = weakIntern(taskInfo.host),
+ taskLocality = taskInfo.taskLocality,
+ speculative = taskInfo.speculative
+ )
+ newTaskInfo.gettingResultTime = taskInfo.gettingResultTime
+ newTaskInfo.setAccumulables(taskInfo.accumulables.filter {
+ accum => !accum.internal && accum.metadata != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)
+ })
+ newTaskInfo.finishTime = taskInfo.finishTime
+ newTaskInfo.failed = taskInfo.failed
+ newTaskInfo.killed = taskInfo.killed
+ newTaskInfo
+ }
+ }
+
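Two memory optimizations sit in the TaskUIData companion above: executor IDs and host names are deduplicated through a Guava weak interner, and internal/SQL accumulables are dropped from the copied TaskInfo. A small standalone sketch of the interning part; `Interners.newWeakInterner` is the same Guava call used above, while the sample strings are illustrative:

    import com.google.common.collect.Interners

    // Many tasks share the same executor id / host strings; interning them
    // through a weak interner keeps one canonical instance per distinct value
    // without pinning the strings in memory forever.
    object StringInterningExample {
      private val interner = Interners.newWeakInterner[String]()

      def weakIntern(s: String): String = interner.intern(s)

      def main(args: Array[String]): Unit = {
        val a = weakIntern(new String("executor-1"))
        val b = weakIntern(new String("executor-1"))
        println(a eq b) // true: both references point at the interned instance
      }
    }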
+ case class TaskMetricsUIData(
+ executorDeserializeTime: Long,
+ executorDeserializeCpuTime: Long,
+ executorRunTime: Long,
+ executorCpuTime: Long,
+ resultSize: Long,
+ jvmGCTime: Long,
+ resultSerializationTime: Long,
+ memoryBytesSpilled: Long,
+ diskBytesSpilled: Long,
+ peakExecutionMemory: Long,
+ inputMetrics: InputMetricsUIData,
+ outputMetrics: OutputMetricsUIData,
+ shuffleReadMetrics: ShuffleReadMetricsUIData,
+ shuffleWriteMetrics: ShuffleWriteMetricsUIData)
+
+ object TaskMetricsUIData {
+ def fromTaskMetrics(m: TaskMetrics): TaskMetricsUIData = {
+ TaskMetricsUIData(
+ executorDeserializeTime = m.executorDeserializeTime,
+ executorDeserializeCpuTime = m.executorDeserializeCpuTime,
+ executorRunTime = m.executorRunTime,
+ executorCpuTime = m.executorCpuTime,
+ resultSize = m.resultSize,
+ jvmGCTime = m.jvmGCTime,
+ resultSerializationTime = m.resultSerializationTime,
+ memoryBytesSpilled = m.memoryBytesSpilled,
+ diskBytesSpilled = m.diskBytesSpilled,
+ peakExecutionMemory = m.peakExecutionMemory,
+ inputMetrics = InputMetricsUIData(m.inputMetrics),
+ outputMetrics = OutputMetricsUIData(m.outputMetrics),
+ shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics),
+ shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics))
+ }
+
+ val EMPTY: TaskMetricsUIData = fromTaskMetrics(TaskMetrics.empty)
+ }
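The EMPTY values above, and in the per-metric companions that follow, implement a simple flyweight: when a metrics snapshot is all zeroes, the shared immutable EMPTY instance is returned instead of allocating a new case-class instance per task. A hedged sketch of the pattern in isolation, with the field set reduced and the names invented for illustration:

    // Flyweight sketch: all-zero snapshots collapse onto one shared immutable
    // instance instead of allocating a fresh object per task.
    case class InputMetricsSnapshot(bytesRead: Long, recordsRead: Long)

    object InputMetricsSnapshot {
      private val EMPTY = new InputMetricsSnapshot(0, 0)

      def of(bytesRead: Long, recordsRead: Long): InputMetricsSnapshot =
        if (bytesRead == 0 && recordsRead == 0) EMPTY
        else InputMetricsSnapshot(bytesRead, recordsRead)
    }

    object FlyweightExample extends App {
      val a = InputMetricsSnapshot.of(0, 0)
      val b = InputMetricsSnapshot.of(0, 0)
      println(a eq b)                         // true: zero snapshots share one instance
      println(InputMetricsSnapshot.of(10, 1)) // non-zero values get their own object
    }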
+
+ case class InputMetricsUIData(bytesRead: Long, recordsRead: Long)
+ object InputMetricsUIData {
+ def apply(metrics: InputMetrics): InputMetricsUIData = {
+ if (metrics.bytesRead == 0 && metrics.recordsRead == 0) {
+ EMPTY
+ } else {
+ new InputMetricsUIData(
+ bytesRead = metrics.bytesRead,
+ recordsRead = metrics.recordsRead)
+ }
+ }
+ private val EMPTY = InputMetricsUIData(0, 0)
+ }
+
+ case class OutputMetricsUIData(bytesWritten: Long, recordsWritten: Long)
+ object OutputMetricsUIData {
+ def apply(metrics: OutputMetrics): OutputMetricsUIData = {
+ if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0) {
+ EMPTY
+ } else {
+ new OutputMetricsUIData(
+ bytesWritten = metrics.bytesWritten,
+ recordsWritten = metrics.recordsWritten)
+ }
+ }
+ private val EMPTY = OutputMetricsUIData(0, 0)
+ }
+
+ case class ShuffleReadMetricsUIData(
+ remoteBlocksFetched: Long,
+ localBlocksFetched: Long,
+ remoteBytesRead: Long,
+ localBytesRead: Long,
+ fetchWaitTime: Long,
+ recordsRead: Long,
+ totalBytesRead: Long,
+ totalBlocksFetched: Long)
+
+ object ShuffleReadMetricsUIData {
+ def apply(metrics: ShuffleReadMetrics): ShuffleReadMetricsUIData = {
+ if (
+ metrics.remoteBlocksFetched == 0 &&
+ metrics.localBlocksFetched == 0 &&
+ metrics.remoteBytesRead == 0 &&
+ metrics.localBytesRead == 0 &&
+ metrics.fetchWaitTime == 0 &&
+ metrics.recordsRead == 0 &&
+ metrics.totalBytesRead == 0 &&
+ metrics.totalBlocksFetched == 0) {
+ EMPTY
+ } else {
+ new ShuffleReadMetricsUIData(
+ remoteBlocksFetched = metrics.remoteBlocksFetched,
+ localBlocksFetched = metrics.localBlocksFetched,
+ remoteBytesRead = metrics.remoteBytesRead,
+ localBytesRead = metrics.localBytesRead,
+ fetchWaitTime = metrics.fetchWaitTime,
+ recordsRead = metrics.recordsRead,
+ totalBytesRead = metrics.totalBytesRead,
+ totalBlocksFetched = metrics.totalBlocksFetched
+ )
+ }
+ }
+ private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0)
+ }
+
+ case class ShuffleWriteMetricsUIData(
+ bytesWritten: Long,
+ recordsWritten: Long,
+ writeTime: Long)
+
+ object ShuffleWriteMetricsUIData {
+ def apply(metrics: ShuffleWriteMetrics): ShuffleWriteMetricsUIData = {
+ if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0 && metrics.writeTime == 0) {
+ EMPTY
+ } else {
+ new ShuffleWriteMetricsUIData(
+ bytesWritten = metrics.bytesWritten,
+ recordsWritten = metrics.recordsWritten,
+ writeTime = metrics.writeTime
+ )
+ }
+ }
+ private val EMPTY = ShuffleWriteMetricsUIData(0, 0, 0)
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
index 81f168a447ea..43bfe0aacf35 100644
--- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
+++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
@@ -17,12 +17,16 @@
package org.apache.spark.ui.scope
+import java.util.Objects
+
import scala.collection.mutable
-import scala.collection.mutable.{StringBuilder, ListBuffer}
+import scala.collection.mutable.{ListBuffer, StringBuilder}
+
+import org.apache.commons.lang3.StringEscapeUtils
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{RDDInfo, StorageLevel}
/**
* A representation of a generic cluster graph used for storing information on RDD operations.
@@ -38,7 +42,7 @@ private[ui] case class RDDOperationGraph(
rootCluster: RDDOperationCluster)
/** A node in an RDDOperationGraph. This represents an RDD. */
-private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean)
+private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: String)
/**
* A directed edge connecting two nodes in an RDDOperationGraph.
@@ -70,6 +74,22 @@ private[ui] class RDDOperationCluster(val id: String, private var _name: String)
def getCachedNodes: Seq[RDDOperationNode] = {
_childNodes.filter(_.cached) ++ _childClusters.flatMap(_.getCachedNodes)
}
+
+ def canEqual(other: Any): Boolean = other.isInstanceOf[RDDOperationCluster]
+
+ override def equals(other: Any): Boolean = other match {
+ case that: RDDOperationCluster =>
+ (that canEqual this) &&
+ _childClusters == that._childClusters &&
+ id == that.id &&
+ _name == that._name
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ val state = Seq(_childClusters, id, _name)
+ state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
+ }
}
private[ui] object RDDOperationGraph extends Logging {
@@ -87,7 +107,7 @@ private[ui] object RDDOperationGraph extends Logging {
* supporting in the future if we decide to group certain stages within the same job under
* a common scope (e.g. part of a SQL query).
*/
- def makeOperationGraph(stage: StageInfo): RDDOperationGraph = {
+ def makeOperationGraph(stage: StageInfo, retainedNodes: Int): RDDOperationGraph = {
val edges = new ListBuffer[RDDOperationEdge]
val nodes = new mutable.HashMap[Int, RDDOperationNode]
val clusters = new mutable.HashMap[String, RDDOperationCluster] // indexed by cluster ID
@@ -99,18 +119,37 @@ private[ui] object RDDOperationGraph extends Logging {
{ if (stage.attemptId == 0) "" else s" (attempt ${stage.attemptId})" }
val rootCluster = new RDDOperationCluster(stageClusterId, stageClusterName)
+ var rootNodeCount = 0
+ val addRDDIds = new mutable.HashSet[Int]()
+ val dropRDDIds = new mutable.HashSet[Int]()
+
// Find nodes, edges, and operation scopes that belong to this stage
- stage.rddInfos.foreach { rdd =>
- edges ++= rdd.parentIds.map { parentId => RDDOperationEdge(parentId, rdd.id) }
+ stage.rddInfos.sortBy(_.id).foreach { rdd =>
+ val parentIds = rdd.parentIds
+ val isAllowed =
+ if (parentIds.isEmpty) {
+ rootNodeCount += 1
+ rootNodeCount <= retainedNodes
+ } else {
+ parentIds.exists(id => addRDDIds.contains(id) || !dropRDDIds.contains(id))
+ }
- // TODO: differentiate between the intention to cache an RDD and whether it's actually cached
- val node = nodes.getOrElseUpdate(
- rdd.id, RDDOperationNode(rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE))
+ if (isAllowed) {
+ addRDDIds += rdd.id
+ edges ++= parentIds.filter(id => !dropRDDIds.contains(id)).map(RDDOperationEdge(_, rdd.id))
+ } else {
+ dropRDDIds += rdd.id
+ }
+ // TODO: differentiate between the intention to cache an RDD and whether it's actually cached
+ val node = nodes.getOrElseUpdate(rdd.id, RDDOperationNode(
+ rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE, rdd.callSite))
if (rdd.scope.isEmpty) {
// This RDD has no encompassing scope, so we put it directly in the root cluster
// This should happen only if an RDD is instantiated outside of a public RDD API
- rootCluster.attachChildNode(node)
+ if (isAllowed) {
+ rootCluster.attachChildNode(node)
+ }
} else {
// Otherwise, this RDD belongs to an inner cluster,
// which may be nested inside of other clusters
@@ -129,8 +168,14 @@ private[ui] object RDDOperationGraph extends Logging {
}
}
// Attach the outermost cluster to the root cluster, and the RDD to the innermost cluster
- rddClusters.headOption.foreach { cluster => rootCluster.attachChildCluster(cluster) }
- rddClusters.lastOption.foreach { cluster => cluster.attachChildNode(node) }
+ rddClusters.headOption.foreach { cluster =>
+ if (!rootCluster.childClusters.contains(cluster)) {
+ rootCluster.attachChildCluster(cluster)
+ }
+ }
+ if (isAllowed) {
+ rddClusters.lastOption.foreach { cluster => cluster.attachChildNode(node) }
+ }
}
}
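The retainedNodes logic above caps how many root RDDs (those with no parents) are drawn in the DAG visualization: the first `retainedNodes` roots are admitted, later roots are dropped, and a non-root RDD is kept only if at least one of its parents survived. A simplified sketch of that admission rule over a plain (id, parentIds) list, illustrative rather than the full graph builder:

    import scala.collection.mutable

    // Cap the number of root RDDs and keep a child only if at least one
    // parent was kept (or never seen as a dropped node).
    object DagPruning {
      def allowedIds(rdds: Seq[(Int, Seq[Int])], retainedNodes: Int): Set[Int] = {
        val added = new mutable.HashSet[Int]()
        val dropped = new mutable.HashSet[Int]()
        var rootCount = 0
        rdds.sortBy(_._1).foreach { case (id, parentIds) =>
          val isAllowed =
            if (parentIds.isEmpty) {
              rootCount += 1
              rootCount <= retainedNodes
            } else {
              parentIds.exists(p => added.contains(p) || !dropped.contains(p))
            }
          if (isAllowed) added += id else dropped += id
        }
        added.toSet
      }
    }

    object DagPruningExample extends App {
      // Two roots (0 and 1) but only one retained; 2 depends on the kept root.
      println(DagPruning.allowedIds(Seq(0 -> Nil, 1 -> Nil, 2 -> Seq(0)), retainedNodes = 1))
    }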
@@ -177,7 +222,13 @@ private[ui] object RDDOperationGraph extends Logging {
/** Return the dot representation of a node in an RDDOperationGraph. */
private def makeDotNode(node: RDDOperationNode): String = {
- s"""${node.id} [label="${node.name} [${node.id}]"]"""
+ val isCached = if (node.cached) {
+ " [Cached]"
+ } else {
+ ""
+ }
+ val label = s"${node.name} [${node.id}]$isCached\n${node.callsite}"
+ s"""${node.id} [label="${StringEscapeUtils.escapeJava(label)}"]"""
}
/** Update the dot representation of the RDDOperationGraph in cluster to subgraph. */
@@ -186,7 +237,7 @@ private[ui] object RDDOperationGraph extends Logging {
cluster: RDDOperationCluster,
indent: String): Unit = {
subgraph.append(indent).append(s"subgraph cluster${cluster.id} {\n")
- subgraph.append(indent).append(s""" label="${cluster.name}";\n""")
+ .append(indent).append(s""" label="${StringEscapeUtils.escapeJava(cluster.name)}";\n""")
cluster.childNodes.foreach { node =>
subgraph.append(indent).append(s" ${makeDotNode(node)};\n")
}
diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
index 89119cd3579e..37a12a864693 100644
--- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
@@ -41,6 +41,10 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen
private[ui] val jobIds = new mutable.ArrayBuffer[Int]
private[ui] val stageIds = new mutable.ArrayBuffer[Int]
+ // How many root nodes to retain in DAG Graph
+ private[ui] val retainedNodes =
+ conf.getInt("spark.ui.dagGraph.retainedRootRDDs", Int.MaxValue)
+
// How many jobs or stages to retain graph metadata for
private val retainedJobs =
conf.getInt("spark.ui.retainedJobs", SparkUI.DEFAULT_RETAINED_JOBS)
@@ -52,9 +56,8 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen
* An empty list is returned if one or more of its stages has been cleaned up.
*/
def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = synchronized {
- val skippedStageIds = jobIdToSkippedStageIds.get(jobId).getOrElse(Seq.empty)
- val graphs = jobIdToStageIds.get(jobId)
- .getOrElse(Seq.empty)
+ val skippedStageIds = jobIdToSkippedStageIds.getOrElse(jobId, Seq.empty)
+ val graphs = jobIdToStageIds.getOrElse(jobId, Seq.empty)
.flatMap { sid => stageIdToGraph.get(sid) }
// Mark any skipped stages as such
graphs.foreach { g =>
@@ -83,7 +86,7 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen
val stageId = stageInfo.stageId
stageIds += stageId
stageIdToJobId(stageId) = jobId
- stageIdToGraph(stageId) = RDDOperationGraph.makeOperationGraph(stageInfo)
+ stageIdToGraph(stageId) = RDDOperationGraph.makeOperationGraph(stageInfo, retainedNodes)
trimStagesIfNecessary()
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index fd6cc3ed759b..317e0aa5ea25 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -31,18 +31,21 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
private val listener = parent.listener
def render(request: HttpServletRequest): Seq[Node] = {
- val parameterId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
- val parameterBlockPage = request.getParameter("block.page")
- val parameterBlockSortColumn = request.getParameter("block.sort")
- val parameterBlockSortDesc = request.getParameter("block.desc")
- val parameterBlockPageSize = request.getParameter("block.pageSize")
+ val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page"))
+ val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort"))
+ val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc"))
+ val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize"))
+ val parameterBlockPrevPageSize = UIUtils.stripXSS(request.getParameter("block.prevPageSize"))
val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1)
val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name")
val blockSortDesc = Option(parameterBlockSortDesc).map(_.toBoolean).getOrElse(false)
val blockPageSize = Option(parameterBlockPageSize).map(_.toInt).getOrElse(100)
+ val blockPrevPageSize = Option(parameterBlockPrevPageSize).map(_.toInt).getOrElse(blockPageSize)
val rddId = parameterId.toInt
val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true)
@@ -56,17 +59,26 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
rddStorageInfo.dataDistribution.get, id = Some("rdd-storage-by-worker-table"))
// Block table
- val (blockTable, blockTableHTML) = try {
+ val page: Int = {
+ // If the user has changed to a larger page size, then go to page 1 in order to avoid
+ // IndexOutOfBoundsException.
+ if (blockPageSize <= blockPrevPageSize) {
+ blockPage
+ } else {
+ 1
+ }
+ }
+ val blockTableHTML = try {
val _blockTable = new BlockPagedTable(
UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}",
rddStorageInfo.partitions.get,
blockPageSize,
blockSortColumn,
blockSortDesc)
- (_blockTable, _blockTable.table(blockPage))
+ _blockTable.table(page)
} catch {
case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
- (null, {e.getMessage})
+ {e.getMessage}
}
val jsForScrollingDownToBlockTable =
@@ -136,7 +148,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
/** Header fields for the worker table */
private def workerHeader = Seq(
"Host",
- "Memory Usage",
+ "On Heap Memory Usage",
+ "Off Heap Memory Usage",
"Disk Usage")
/** Render an HTML row representing a worker */
@@ -144,8 +157,12 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
{worker.address}
- {Utils.bytesToString(worker.memoryUsed)}
- ({Utils.bytesToString(worker.memoryRemaining)} Remaining)
+ {Utils.bytesToString(worker.onHeapMemoryUsed.getOrElse(0L))}
+ ({Utils.bytesToString(worker.onHeapMemoryRemaining.getOrElse(0L))} Remaining)
+
+
+ {Utils.bytesToString(worker.offHeapMemoryUsed.getOrElse(0L))}
+ ({Utils.bytesToString(worker.offHeapMemoryRemaining.getOrElse(0L))} Remaining)
{Utils.bytesToString(worker.diskUsed)}
@@ -186,27 +203,12 @@ private[ui] class BlockDataSource(
* Return Ordering according to sortColumn and desc
*/
private def ordering(sortColumn: String, desc: Boolean): Ordering[BlockTableRowData] = {
- val ordering = sortColumn match {
- case "Block Name" => new Ordering[BlockTableRowData] {
- override def compare(x: BlockTableRowData, y: BlockTableRowData): Int =
- Ordering.String.compare(x.blockName, y.blockName)
- }
- case "Storage Level" => new Ordering[BlockTableRowData] {
- override def compare(x: BlockTableRowData, y: BlockTableRowData): Int =
- Ordering.String.compare(x.storageLevel, y.storageLevel)
- }
- case "Size in Memory" => new Ordering[BlockTableRowData] {
- override def compare(x: BlockTableRowData, y: BlockTableRowData): Int =
- Ordering.Long.compare(x.memoryUsed, y.memoryUsed)
- }
- case "Size on Disk" => new Ordering[BlockTableRowData] {
- override def compare(x: BlockTableRowData, y: BlockTableRowData): Int =
- Ordering.Long.compare(x.diskUsed, y.diskUsed)
- }
- case "Executors" => new Ordering[BlockTableRowData] {
- override def compare(x: BlockTableRowData, y: BlockTableRowData): Int =
- Ordering.String.compare(x.executors, y.executors)
- }
+ val ordering: Ordering[BlockTableRowData] = sortColumn match {
+ case "Block Name" => Ordering.by(_.blockName)
+ case "Storage Level" => Ordering.by(_.storageLevel)
+ case "Size in Memory" => Ordering.by(_.memoryUsed)
+ case "Size on Disk" => Ordering.by(_.diskUsed)
+ case "Executors" => Ordering.by(_.executors)
case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn")
}
if (desc) {
@@ -226,7 +228,14 @@ private[ui] class BlockPagedTable(
override def tableId: String = "rdd-storage-by-block-table"
- override def tableCssClass: String = "table table-bordered table-condensed table-striped"
+ override def tableCssClass: String =
+ "table table-bordered table-condensed table-striped table-head-clickable"
+
+ override def pageSizeFormField: String = "block.pageSize"
+
+ override def prevPageSizeFormField: String = "block.prevPageSize"
+
+ override def pageNumberFormField: String = "block.page"
override val dataSource: BlockDataSource = new BlockDataSource(
rddPartitions,
@@ -236,24 +245,16 @@ private[ui] class BlockPagedTable(
override def pageLink(page: Int): String = {
val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
- s"${basePath}&block.page=$page&block.sort=${encodedSortColumn}&block.desc=${desc}" +
- s"&block.pageSize=${pageSize}"
+ basePath +
+ s"&$pageNumberFormField=$page" +
+ s"&block.sort=$encodedSortColumn" +
+ s"&block.desc=$desc" +
+ s"&$pageSizeFormField=$pageSize"
}
- override def goButtonJavascriptFunction: (String, String) = {
- val jsFuncName = "goToBlockPage"
+ override def goButtonFormPath: String = {
val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
- val jsFunc = s"""
- |currentBlockPageSize = ${pageSize}
- |function goToBlockPage(page, pageSize) {
- | // Set page to 1 if the page size changes
- | page = pageSize == currentBlockPageSize ? page : 1;
- | var url = "${basePath}&block.sort=${encodedSortColumn}&block.desc=${desc}" +
- | "&block.page=" + page + "&block.pageSize=" + pageSize;
- | window.location.href = url;
- |}
- """.stripMargin
- (jsFuncName, jsFunc)
+ s"$basePath&block.sort=$encodedSortColumn&block.desc=$desc"
}
override def headers: Seq[Node] = {
@@ -271,22 +272,27 @@ private[ui] class BlockPagedTable(
val headerRow: Seq[Node] = {
blockHeaders.map { header =>
if (header == sortColumn) {
- val headerLink =
- s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}&block.desc=${!desc}" +
- s"&block.pageSize=${pageSize}"
- val js = Unparsed(s"window.location.href='${headerLink}'")
+ val headerLink = Unparsed(
+ basePath +
+ s"&block.sort=${URLEncoder.encode(header, "UTF-8")}" +
+ s"&block.desc=${!desc}" +
+ s"&block.pageSize=$pageSize")
val arrow = if (desc) "▾" else "▴" // UP or DOWN
-
- {header}
- {Unparsed(arrow)}
+
+
+ {header}
+ {Unparsed(arrow)}
+
} else {
- val headerLink =
- s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}" +
- s"&block.pageSize=${pageSize}"
- val js = Unparsed(s"window.location.href='${headerLink}'")
-
- {header}
+ val headerLink = Unparsed(
+ basePath +
+ s"&block.sort=${URLEncoder.encode(header, "UTF-8")}" +
+ s"&block.pageSize=$pageSize")
+
+
+ {header}
+
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
index 04f584621e71..aa84788f1df8 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
@@ -54,7 +54,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
"Cached Partitions",
"Fraction Cached",
"Size in Memory",
- "Size in ExternalBlockStore",
"Size on Disk")
/** Render an HTML row representing an RDD */
@@ -71,7 +70,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
{rdd.numCachedPartitions.toString}
{"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)}
{Utils.bytesToString(rdd.memSize)}
- {Utils.bytesToString(rdd.externalBlockStoreSize)}
{Utils.bytesToString(rdd.diskSize)}
// scalastyle:on
@@ -104,7 +102,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
"Executor ID",
"Address",
"Total Size in Memory",
- "Total Size in ExternalBlockStore",
"Total Size on Disk",
"Stream Blocks")
@@ -119,9 +116,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
{Utils.bytesToString(status.totalMemSize)}
-
- {Utils.bytesToString(status.totalExternalBlockStoreSize)}
-
{Utils.bytesToString(status.totalDiskSize)}
@@ -157,12 +151,12 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
/** Render a stream block */
private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = {
val replications = block._2
- assert(replications.size > 0) // This must be true because it's the result of "groupBy"
+ assert(replications.nonEmpty) // This must be true because it's the result of "groupBy"
if (replications.size == 1) {
streamBlockTableSubrow(block._1, replications.head, replications.size, true)
} else {
streamBlockTableSubrow(block._1, replications.head, replications.size, true) ++
- replications.tail.map(streamBlockTableSubrow(block._1, _, replications.size, false)).flatten
+ replications.tail.flatMap(streamBlockTableSubrow(block._1, _, replications.size, false))
}
}
@@ -195,8 +189,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
("Memory", block.memSize)
} else if (block.storageLevel.useMemory && !block.storageLevel.deserialized) {
("Memory Serialized", block.memSize)
- } else if (block.storageLevel.useOffHeap) {
- ("External", block.externalBlockStoreSize)
} else {
throw new IllegalStateException(s"Invalid Storage Level: ${block.storageLevel}")
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index 22e2993b3b5b..148efb134e14 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -20,9 +20,9 @@ package org.apache.spark.ui.storage
import scala.collection.mutable
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ui._
import org.apache.spark.scheduler._
import org.apache.spark.storage._
+import org.apache.spark.ui._
/** Web UI showing storage status of all RDD's in the given SparkContext. */
private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") {
@@ -39,11 +39,12 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag
* This class is thread-safe (unlike JobProgressListener)
*/
@DeveloperApi
+@deprecated("This class will be removed in a future release.", "2.2.0")
class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener {
private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing
- def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList
+ def activeStorageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList
/** Filter RDD info to include only those with cached partitions */
def rddInfoList: Seq[RDDInfo] = synchronized {
@@ -54,23 +55,12 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc
private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = {
val rddIdsToUpdate = updatedBlocks.flatMap { case (bid, _) => bid.asRDDId.map(_.rddId) }.toSet
val rddInfosToUpdate = _rddInfoMap.values.toSeq.filter { s => rddIdsToUpdate.contains(s.id) }
- StorageUtils.updateRddInfo(rddInfosToUpdate, storageStatusList)
- }
-
- /**
- * Assumes the storage status list is fully up-to-date. This implies the corresponding
- * StorageStatusSparkListener must process the SparkListenerTaskEnd event before this listener.
- */
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
- val metrics = taskEnd.taskMetrics
- if (metrics != null && metrics.updatedBlocks.isDefined) {
- updateRDDInfo(metrics.updatedBlocks.get)
- }
+ StorageUtils.updateRddInfo(rddInfosToUpdate, activeStorageStatusList)
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized {
val rddInfos = stageSubmitted.stageInfo.rddInfos
- rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) }
+ rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info).name = info.name }
}
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized {
@@ -84,4 +74,14 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc
override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized {
_rddInfoMap.remove(unpersistRDD.rddId)
}
+
+ override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+ super.onBlockUpdated(blockUpdated)
+ val blockId = blockUpdated.blockUpdatedInfo.blockId
+ val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel
+ val memSize = blockUpdated.blockUpdatedInfo.memSize
+ val diskSize = blockUpdated.blockUpdatedInfo.diskSize
+ val blockStatus = BlockStatus(storageLevel, memSize, diskSize)
+ updateRDDInfo(Seq((blockId, blockStatus)))
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
new file mode 100644
index 000000000000..5df17ccb627a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -0,0 +1,508 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.{lang => jl}
+import java.io.ObjectInputStream
+import java.util.{ArrayList, Collections}
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext}
+import org.apache.spark.scheduler.AccumulableInfo
+
+
+private[spark] case class AccumulatorMetadata(
+ id: Long,
+ name: Option[String],
+ countFailedValues: Boolean) extends Serializable
+
+
+/**
+ * The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of
+ * type `OUT`.
+ *
+ * `OUT` should be a type that can be read atomically (e.g., Int, Long), or thread-safely
+ * (e.g., synchronized collections) because it will be read from other threads.
+ */
+abstract class AccumulatorV2[IN, OUT] extends Serializable {
+ private[spark] var metadata: AccumulatorMetadata = _
+ private[this] var atDriverSide = true
+
+ private[spark] def register(
+ sc: SparkContext,
+ name: Option[String] = None,
+ countFailedValues: Boolean = false): Unit = {
+ if (this.metadata != null) {
+ throw new IllegalStateException("Cannot register an Accumulator twice.")
+ }
+ this.metadata = AccumulatorMetadata(AccumulatorContext.newId(), name, countFailedValues)
+ AccumulatorContext.register(this)
+ sc.cleaner.foreach(_.registerAccumulatorForCleanup(this))
+ }
+
+ /**
+ * Returns true if this accumulator has been registered.
+ *
+ * @note All accumulators must be registered before use, or an exception will be thrown.
+ */
+ final def isRegistered: Boolean =
+ metadata != null && AccumulatorContext.get(metadata.id).isDefined
+
+ private def assertMetadataNotNull(): Unit = {
+ if (metadata == null) {
+ throw new IllegalStateException("The metadata of this accumulator has not been assigned yet.")
+ }
+ }
+
+ /**
+ * Returns the id of this accumulator, can only be called after registration.
+ */
+ final def id: Long = {
+ assertMetadataNotNull()
+ metadata.id
+ }
+
+ /**
+ * Returns the name of this accumulator, can only be called after registration.
+ */
+ final def name: Option[String] = {
+ assertMetadataNotNull()
+
+ if (atDriverSide) {
+ metadata.name.orElse(AccumulatorContext.get(id).flatMap(_.metadata.name))
+ } else {
+ metadata.name
+ }
+ }
+
+ /**
+ * Whether to accumulate values from failed tasks. This is set to true for system and time
+ * metrics like serialization time or bytes spilled, and false for things with absolute values
+ * like number of input rows. This should be used for internal metrics only.
+ */
+ private[spark] final def countFailedValues: Boolean = {
+ assertMetadataNotNull()
+ metadata.countFailedValues
+ }
+
+ /**
+ * Creates an [[AccumulableInfo]] representation of this [[AccumulatorV2]] with the provided
+ * values.
+ */
+ private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+ val isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))
+ new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
+ }
+
+ final private[spark] def isAtDriverSide: Boolean = atDriverSide
+
+ /**
+ * Returns whether this accumulator is zero value or not, e.g. for a counter accumulator, 0 is zero
+ * value; for a list accumulator, Nil is zero value.
+ */
+ def isZero: Boolean
+
+ /**
+ * Creates a new copy of this accumulator with its zero value, i.e. calling `isZero` on the
+ * copy must return true.
+ */
+ def copyAndReset(): AccumulatorV2[IN, OUT] = {
+ val copyAcc = copy()
+ copyAcc.reset()
+ copyAcc
+ }
+
+ /**
+ * Creates a new copy of this accumulator.
+ */
+ def copy(): AccumulatorV2[IN, OUT]
+
+ /**
+ * Resets this accumulator to its zero value, i.e. calling `isZero` afterwards must
+ * return true.
+ */
+ def reset(): Unit
+
+ /**
+ * Takes an input value and accumulates it into this accumulator.
+ */
+ def add(v: IN): Unit
+
+ /**
+ * Merges another same-type accumulator into this one and updates its state, i.e. the merge
+ * happens in place.
+ */
+ def merge(other: AccumulatorV2[IN, OUT]): Unit
+
+ /**
+ * Returns the current value of this accumulator.
+ */
+ def value: OUT
+
+ // Called by Java when serializing an object
+ final protected def writeReplace(): Any = {
+ if (atDriverSide) {
+ if (!isRegistered) {
+ throw new UnsupportedOperationException(
+ "Accumulator must be registered before send to executor")
+ }
+ val copyAcc = copyAndReset()
+ assert(copyAcc.isZero, "copyAndReset must return a zero value copy")
+ val isInternalAcc = name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)
+ if (isInternalAcc) {
+ // Do not serialize the name of internal accumulator and send it to executor.
+ copyAcc.metadata = metadata.copy(name = None)
+ } else {
+ // For non-internal accumulators, we still need to send the name because users may need to
+ // access the accumulator name on the executor side, or they may keep the accumulators sent
+ // from executors and access the name after the registered accumulator has already been
+ // garbage collected (e.g. SQLMetrics).
+ copyAcc.metadata = metadata
+ }
+ copyAcc
+ } else {
+ this
+ }
+ }
+
+ // Called by Java when deserializing an object
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+ in.defaultReadObject()
+ if (atDriverSide) {
+ atDriverSide = false
+
+ // Automatically register the accumulator when it is deserialized with the task closure.
+ // This is for external accumulators and internal ones that do not represent task level
+ // metrics, e.g. internal SQL metrics, which are per-operator.
+ val taskContext = TaskContext.get()
+ if (taskContext != null) {
+ taskContext.registerAccumulator(this)
+ }
+ } else {
+ atDriverSide = true
+ }
+ }
+
+ override def toString: String = {
+ if (metadata == null) {
+ "Un-registered Accumulator: " + getClass.getSimpleName
+ } else {
+ getClass.getSimpleName + s"(id: $id, name: $name, value: $value)"
+ }
+ }
+}
+
+
+/**
+ * An internal class used to track accumulators by Spark itself.
+ */
+private[spark] object AccumulatorContext {
+
+ /**
+ * This global map holds the original accumulator objects that are created on the driver.
+ * It keeps weak references to these objects so that accumulators can be garbage-collected
+ * once the RDDs and user-code that reference them are cleaned up.
+ * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
+ */
+ private val originals = new ConcurrentHashMap[Long, jl.ref.WeakReference[AccumulatorV2[_, _]]]
+
+ private[this] val nextId = new AtomicLong(0L)
+
+ /**
+ * Returns a globally unique ID for a new [[AccumulatorV2]].
+ * Note: once you copy the [[AccumulatorV2]], the ID is no longer unique.
+ */
+ def newId(): Long = nextId.getAndIncrement
+
+ /** Returns the number of accumulators registered. Used in testing. */
+ def numAccums: Int = originals.size
+
+ /**
+ * Registers an [[AccumulatorV2]] created on the driver such that it can be used on the executors.
+ *
+ * All accumulators registered here can later be used as a container for accumulating partial
+ * values across multiple tasks. This is what `org.apache.spark.scheduler.DAGScheduler` does.
+ * Note: if an accumulator is registered here, it should also be registered with the active
+ * context cleaner for cleanup so as to avoid memory leaks.
+ *
+ * If an [[AccumulatorV2]] with the same ID was already registered, this does nothing instead
+ * of overwriting it. We should never register the same accumulator twice; this is just a
+ * sanity check.
+ */
+ def register(a: AccumulatorV2[_, _]): Unit = {
+ originals.putIfAbsent(a.id, new jl.ref.WeakReference[AccumulatorV2[_, _]](a))
+ }
+
+ /**
+ * Unregisters the [[AccumulatorV2]] with the given ID, if any.
+ */
+ def remove(id: Long): Unit = {
+ originals.remove(id)
+ }
+
+ /**
+ * Returns the [[AccumulatorV2]] registered with the given ID, if any.
+ */
+ def get(id: Long): Option[AccumulatorV2[_, _]] = {
+ Option(originals.get(id)).map { ref =>
+ // Since we are storing weak references, we must check whether the underlying data is valid.
+ val acc = ref.get
+ if (acc eq null) {
+ throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id")
+ }
+ acc
+ }
+ }
+
+ /**
+ * Clears all registered [[AccumulatorV2]]s. For testing only.
+ */
+ def clear(): Unit = {
+ originals.clear()
+ }
+
+ // Identifier for distinguishing SQL metrics from other accumulators
+ private[spark] val SQL_ACCUM_IDENTIFIER = "sql"
+}
+
+
+/**
+ * An [[AccumulatorV2 accumulator]] for computing the sum, count, and average of 64-bit integers.
+ *
+ * @since 2.0.0
+ */
+class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
+ private var _sum = 0L
+ private var _count = 0L
+
+ /**
+ * Returns true if this accumulator holds its zero value, i.e. both sum and count are zero.
+ * @since 2.0.0
+ */
+ override def isZero: Boolean = _sum == 0L && _count == 0
+
+ override def copy(): LongAccumulator = {
+ val newAcc = new LongAccumulator
+ newAcc._count = this._count
+ newAcc._sum = this._sum
+ newAcc
+ }
+
+ override def reset(): Unit = {
+ _sum = 0L
+ _count = 0L
+ }
+
+ /**
+ * Adds v to the accumulator, i.e. increments the sum by v and the count by 1.
+ * @since 2.0.0
+ */
+ override def add(v: jl.Long): Unit = {
+ _sum += v
+ _count += 1
+ }
+
+ /**
+ * Adds v to the accumulator, i.e. increments the sum by v and the count by 1.
+ * @since 2.0.0
+ */
+ def add(v: Long): Unit = {
+ _sum += v
+ _count += 1
+ }
+
+ /**
+ * Returns the number of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def count: Long = _count
+
+ /**
+ * Returns the sum of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def sum: Long = _sum
+
+ /**
+ * Returns the average of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def avg: Double = _sum.toDouble / _count
+
+ override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
+ case o: LongAccumulator =>
+ _sum += o.sum
+ _count += o.count
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ private[spark] def setValue(newValue: Long): Unit = _sum = newValue
+
+ override def value: jl.Long = _sum
+}
+
+
+/**
+ * An [[AccumulatorV2 accumulator]] for computing the sum, count, and average of
+ * double-precision floating-point numbers.
+ *
+ * @since 2.0.0
+ */
+class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
+ private var _sum = 0.0
+ private var _count = 0L
+
+ override def isZero: Boolean = _sum == 0.0 && _count == 0
+
+ override def copy(): DoubleAccumulator = {
+ val newAcc = new DoubleAccumulator
+ newAcc._count = this._count
+ newAcc._sum = this._sum
+ newAcc
+ }
+
+ override def reset(): Unit = {
+ _sum = 0.0
+ _count = 0L
+ }
+
+ /**
+ * Adds v to the accumulator, i.e. increments the sum by v and the count by 1.
+ * @since 2.0.0
+ */
+ override def add(v: jl.Double): Unit = {
+ _sum += v
+ _count += 1
+ }
+
+ /**
+ * Adds v to the accumulator, i.e. increments the sum by v and the count by 1.
+ * @since 2.0.0
+ */
+ def add(v: Double): Unit = {
+ _sum += v
+ _count += 1
+ }
+
+ /**
+ * Returns the number of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def count: Long = _count
+
+ /**
+ * Returns the sum of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def sum: Double = _sum
+
+ /**
+ * Returns the average of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def avg: Double = _sum / _count
+
+ override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
+ case o: DoubleAccumulator =>
+ _sum += o.sum
+ _count += o.count
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ private[spark] def setValue(newValue: Double): Unit = _sum = newValue
+
+ override def value: jl.Double = _sum
+}
+
+
+/**
+ * An [[AccumulatorV2 accumulator]] for collecting a list of elements.
+ *
+ * @since 2.0.0
+ */
+class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
+ private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]())
+
+ override def isZero: Boolean = _list.isEmpty
+
+ override def copyAndReset(): CollectionAccumulator[T] = new CollectionAccumulator
+
+ override def copy(): CollectionAccumulator[T] = {
+ val newAcc = new CollectionAccumulator[T]
+ _list.synchronized {
+ newAcc._list.addAll(_list)
+ }
+ newAcc
+ }
+
+ override def reset(): Unit = _list.clear()
+
+ override def add(v: T): Unit = _list.add(v)
+
+ override def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit = other match {
+ case o: CollectionAccumulator[T] => _list.addAll(o.value)
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ override def value: java.util.List[T] = _list.synchronized {
+ java.util.Collections.unmodifiableList(new ArrayList[T](_list))
+ }
+
+ private[spark] def setValue(newValue: java.util.List[T]): Unit = {
+ _list.clear()
+ _list.addAll(newValue)
+ }
+}
+
+
+class LegacyAccumulatorWrapper[R, T](
+ initialValue: R,
+ param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
+ private[spark] var _value = initialValue // Current value on driver
+
+ @transient private lazy val _zero = param.zero(initialValue)
+
+ override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef])
+
+ override def copy(): LegacyAccumulatorWrapper[R, T] = {
+ val acc = new LegacyAccumulatorWrapper(initialValue, param)
+ acc._value = _value
+ acc
+ }
+
+ override def reset(): Unit = {
+ _value = _zero
+ }
+
+ override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
+
+ override def merge(other: AccumulatorV2[T, R]): Unit = other match {
+ case o: LegacyAccumulatorWrapper[R, T] => _value = param.addInPlace(_value, o.value)
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ override def value: R = _value
+}
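The file above defines the new accumulator API. As a quick illustration of the copy/merge contract described in its scaladoc (the driver keeps the original accumulator, tasks receive zero-value copies, and the driver merges the partial results back in place), here is a minimal, self-contained sketch. It is not part of the patch and does not require a SparkContext; it only uses the public LongAccumulator and CollectionAccumulator members shown above.

    import org.apache.spark.util.{CollectionAccumulator, LongAccumulator}

    object AccumulatorV2Sketch {
      def main(args: Array[String]): Unit = {
        val driverAcc = new LongAccumulator

        // Simulate executors: each task works on a zero-value copy of the driver's accumulator.
        val taskCopies = Seq.fill(3)(driverAcc.copyAndReset())
        taskCopies.zipWithIndex.foreach { case (acc, i) => acc.add(i + 1L) }

        // The driver folds the task-side partial results back in with merge() (in place).
        taskCopies.foreach(driverAcc.merge)
        assert(driverAcc.sum == 6L && driverAcc.count == 3L)

        // CollectionAccumulator exposes an unmodifiable snapshot of the collected elements.
        val names = new CollectionAccumulator[String]
        names.add("a"); names.add("b")
        assert(names.value.size == 2)
      }
    }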
diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala
deleted file mode 100644
index 81a7cbde01ce..000000000000
--- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import akka.actor.Actor
-import org.slf4j.Logger
-
-/**
- * A trait to enable logging all Akka actor messages. Here's an example of using this:
- *
- * {{{
- * class BlockManagerMasterActor extends Actor with ActorLogReceive with Logging {
- * ...
- * override def receiveWithLogging = {
- * case GetLocations(blockId) =>
- * sender ! getLocations(blockId)
- * ...
- * }
- * ...
- * }
- * }}}
- *
- */
-private[spark] trait ActorLogReceive {
- self: Actor =>
-
- override def receive: Actor.Receive = new Actor.Receive {
-
- private val _receiveWithLogging = receiveWithLogging
-
- override def isDefinedAt(o: Any): Boolean = {
- val handled = _receiveWithLogging.isDefinedAt(o)
- if (!handled) {
- log.debug(s"Received unexpected actor system event: $o")
- }
- handled
- }
-
- override def apply(o: Any): Unit = {
- if (log.isDebugEnabled) {
- log.debug(s"[actor] received message $o from ${self.sender}")
- }
- val start = System.nanoTime
- _receiveWithLogging.apply(o)
- val timeTaken = (System.nanoTime - start).toDouble / 1000000
- if (log.isDebugEnabled) {
- log.debug(s"[actor] handled message ($timeTaken ms) $o from ${self.sender}")
- }
- }
- }
-
- def receiveWithLogging: Actor.Receive
-
- protected def log: Logger
-}
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
deleted file mode 100644
index 1738258a0c79..000000000000
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ /dev/null
@@ -1,242 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import scala.collection.JavaConverters._
-
-import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem}
-import akka.pattern.ask
-
-import com.typesafe.config.ConfigFactory
-import org.apache.log4j.{Level, Logger}
-
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
-import org.apache.spark.rpc.RpcTimeout
-
-/**
- * Various utility classes for working with Akka.
- */
-private[spark] object AkkaUtils extends Logging {
-
- /**
- * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
- * ActorSystem itself and its port (which is hard to get from Akka).
- *
- * Note: the `name` parameter is important, as even if a client sends a message to right
- * host + port, if the system name is incorrect, Akka will drop the message.
- *
- * If indestructible is set to true, the Actor System will continue running in the event
- * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]].
- */
- def createActorSystem(
- name: String,
- host: String,
- port: Int,
- conf: SparkConf,
- securityManager: SecurityManager): (ActorSystem, Int) = {
- val startService: Int => (ActorSystem, Int) = { actualPort =>
- doCreateActorSystem(name, host, actualPort, conf, securityManager)
- }
- Utils.startServiceOnPort(port, startService, conf, name)
- }
-
- private def doCreateActorSystem(
- name: String,
- host: String,
- port: Int,
- conf: SparkConf,
- securityManager: SecurityManager): (ActorSystem, Int) = {
-
- val akkaThreads = conf.getInt("spark.akka.threads", 4)
- val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15)
- val akkaTimeoutS = conf.getTimeAsSeconds("spark.akka.timeout",
- conf.get("spark.network.timeout", "120s"))
- val akkaFrameSize = maxFrameSizeBytes(conf)
- val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false)
- val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off"
- if (!akkaLogLifecycleEvents) {
- // As a workaround for Akka issue #3787, we coerce the "EndpointWriter" log to be silent.
- // See: https://www.assembla.com/spaces/akka/tickets/3787#/
- Option(Logger.getLogger("akka.remote.EndpointWriter")).map(l => l.setLevel(Level.FATAL))
- }
-
- val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off"
-
- val akkaHeartBeatPausesS = conf.getTimeAsSeconds("spark.akka.heartbeat.pauses", "6000s")
- val akkaHeartBeatIntervalS = conf.getTimeAsSeconds("spark.akka.heartbeat.interval", "1000s")
-
- val secretKey = securityManager.getSecretKey()
- val isAuthOn = securityManager.isAuthenticationEnabled()
- if (isAuthOn && secretKey == null) {
- throw new Exception("Secret key is null with authentication on")
- }
- val requireCookie = if (isAuthOn) "on" else "off"
- val secureCookie = if (isAuthOn) secretKey else ""
- logDebug(s"In createActorSystem, requireCookie is: $requireCookie")
-
- val akkaSslConfig = securityManager.akkaSSLOptions.createAkkaConfig
- .getOrElse(ConfigFactory.empty())
-
- val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap.asJava)
- .withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString(
- s"""
- |akka.daemonic = on
- |akka.loggers = [""akka.event.slf4j.Slf4jLogger""]
- |akka.stdout-loglevel = "ERROR"
- |akka.jvm-exit-on-fatal-error = off
- |akka.remote.require-cookie = "$requireCookie"
- |akka.remote.secure-cookie = "$secureCookie"
- |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatIntervalS s
- |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPausesS s
- |akka.actor.provider = "akka.remote.RemoteActorRefProvider"
- |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport"
- |akka.remote.netty.tcp.hostname = "$host"
- |akka.remote.netty.tcp.port = $port
- |akka.remote.netty.tcp.tcp-nodelay = on
- |akka.remote.netty.tcp.connection-timeout = $akkaTimeoutS s
- |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}B
- |akka.remote.netty.tcp.execution-pool-size = $akkaThreads
- |akka.actor.default-dispatcher.throughput = $akkaBatchSize
- |akka.log-config-on-start = $logAkkaConfig
- |akka.remote.log-remote-lifecycle-events = $lifecycleEvents
- |akka.log-dead-letters = $lifecycleEvents
- |akka.log-dead-letters-during-shutdown = $lifecycleEvents
- """.stripMargin))
-
- val actorSystem = ActorSystem(name, akkaConf)
- val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider
- val boundPort = provider.getDefaultAddress.port.get
- (actorSystem, boundPort)
- }
-
- private val AKKA_MAX_FRAME_SIZE_IN_MB = Int.MaxValue / 1024 / 1024
-
- /** Returns the configured max frame size for Akka messages in bytes. */
- def maxFrameSizeBytes(conf: SparkConf): Int = {
- val frameSizeInMB = conf.getInt("spark.akka.frameSize", 128)
- if (frameSizeInMB > AKKA_MAX_FRAME_SIZE_IN_MB) {
- throw new IllegalArgumentException(
- s"spark.akka.frameSize should not be greater than $AKKA_MAX_FRAME_SIZE_IN_MB MB")
- }
- frameSizeInMB * 1024 * 1024
- }
-
- /** Space reserved for extra data in an Akka message besides serialized task or task result. */
- val reservedSizeBytes = 200 * 1024
-
- /**
- * Send a message to the given actor and get its result within a default timeout, or
- * throw a SparkException if this fails.
- */
- def askWithReply[T](
- message: Any,
- actor: ActorRef,
- timeout: RpcTimeout): T = {
- askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout)
- }
-
- /**
- * Send a message to the given actor and get its result within a default timeout, or
- * throw a SparkException if this fails even after the specified number of retries.
- */
- def askWithReply[T](
- message: Any,
- actor: ActorRef,
- maxAttempts: Int,
- retryInterval: Long,
- timeout: RpcTimeout): T = {
- // TODO: Consider removing multiple attempts
- if (actor == null) {
- throw new SparkException(s"Error sending message [message = $message]" +
- " as actor is null ")
- }
- var attempts = 0
- var lastException: Exception = null
- while (attempts < maxAttempts) {
- attempts += 1
- try {
- val future = actor.ask(message)(timeout.duration)
- val result = timeout.awaitResult(future)
- if (result == null) {
- throw new SparkException("Actor returned null")
- }
- return result.asInstanceOf[T]
- } catch {
- case ie: InterruptedException => throw ie
- case e: Exception =>
- lastException = e
- logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
- }
- if (attempts < maxAttempts) {
- Thread.sleep(retryInterval)
- }
- }
-
- throw new SparkException(
- s"Error sending message [message = $message]", lastException)
- }
-
- def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = {
- val driverActorSystemName = SparkEnv.driverActorSystemName
- val driverHost: String = conf.get("spark.driver.host", "localhost")
- val driverPort: Int = conf.getInt("spark.driver.port", 7077)
- Utils.checkHost(driverHost, "Expected hostname")
- val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name)
- val timeout = RpcUtils.lookupRpcTimeout(conf)
- logInfo(s"Connecting to $name: $url")
- timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
- }
-
- def makeExecutorRef(
- name: String,
- conf: SparkConf,
- host: String,
- port: Int,
- actorSystem: ActorSystem): ActorRef = {
- val executorActorSystemName = SparkEnv.executorActorSystemName
- Utils.checkHost(host, "Expected hostname")
- val url = address(protocol(actorSystem), executorActorSystemName, host, port, name)
- val timeout = RpcUtils.lookupRpcTimeout(conf)
- logInfo(s"Connecting to $name: $url")
- timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
- }
-
- def protocol(actorSystem: ActorSystem): String = {
- val akkaConf = actorSystem.settings.config
- val sslProp = "akka.remote.netty.tcp.enable-ssl"
- protocol(akkaConf.hasPath(sslProp) && akkaConf.getBoolean(sslProp))
- }
-
- def protocol(ssl: Boolean = false): String = {
- if (ssl) {
- "akka.ssl.tcp"
- } else {
- "akka.tcp"
- }
- }
-
- def address(
- protocol: String,
- systemName: String,
- host: String,
- port: Int,
- actorName: String): String = {
- s"$protocol://$systemName@$host:$port/user/$actorName"
- }
-
-}
diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala
deleted file mode 100644
index 61b5a4cecddc..000000000000
--- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala
+++ /dev/null
@@ -1,180 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import java.util.concurrent._
-import java.util.concurrent.atomic.AtomicBoolean
-
-import com.google.common.annotations.VisibleForTesting
-import org.apache.spark.SparkContext
-
-/**
- * Asynchronously passes events to registered listeners.
- *
- * Until `start()` is called, all posted events are only buffered. Only after this listener bus
- * has started will events be actually propagated to all attached listeners. This listener bus
- * is stopped when `stop()` is called, and it will drop further events after stopping.
- *
- * @param name name of the listener bus, will be the name of the listener thread.
- * @tparam L type of listener
- * @tparam E type of event
- */
-private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: String)
- extends ListenerBus[L, E] {
-
- self =>
-
- private var sparkContext: SparkContext = null
-
- /* Cap the capacity of the event queue so we get an explicit error (rather than
- * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
- private val EVENT_QUEUE_CAPACITY = 10000
- private val eventQueue = new LinkedBlockingQueue[E](EVENT_QUEUE_CAPACITY)
-
- // Indicate if `start()` is called
- private val started = new AtomicBoolean(false)
- // Indicate if `stop()` is called
- private val stopped = new AtomicBoolean(false)
-
- // Indicate if we are processing some event
- // Guarded by `self`
- private var processingEvent = false
-
- // A counter that represents the number of events produced and consumed in the queue
- private val eventLock = new Semaphore(0)
-
- private val listenerThread = new Thread(name) {
- setDaemon(true)
- override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) {
- while (true) {
- eventLock.acquire()
- self.synchronized {
- processingEvent = true
- }
- try {
- val event = eventQueue.poll
- if (event == null) {
- // Get out of the while loop and shutdown the daemon thread
- if (!stopped.get) {
- throw new IllegalStateException("Polling `null` from eventQueue means" +
- " the listener bus has been stopped. So `stopped` must be true")
- }
- return
- }
- postToAll(event)
- } finally {
- self.synchronized {
- processingEvent = false
- }
- }
- }
- }
- }
-
- /**
- * Start sending events to attached listeners.
- *
- * This first sends out all buffered events posted before this listener bus has started, then
- * listens for any additional events asynchronously while the listener bus is still running.
- * This should only be called once.
- *
- * @param sc Used to stop the SparkContext in case the listener thread dies.
- */
- def start(sc: SparkContext) {
- if (started.compareAndSet(false, true)) {
- sparkContext = sc
- listenerThread.start()
- } else {
- throw new IllegalStateException(s"$name already started!")
- }
- }
-
- def post(event: E) {
- if (stopped.get) {
- // Drop further events to make `listenerThread` exit ASAP
- logError(s"$name has already stopped! Dropping event $event")
- return
- }
- val eventAdded = eventQueue.offer(event)
- if (eventAdded) {
- eventLock.release()
- } else {
- onDropEvent(event)
- }
- }
-
- /**
- * For testing only. Wait until there are no more events in the queue, or until the specified
- * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue
- * emptied.
- */
- @VisibleForTesting
- @throws(classOf[TimeoutException])
- def waitUntilEmpty(timeoutMillis: Long): Unit = {
- val finishTime = System.currentTimeMillis + timeoutMillis
- while (!queueIsEmpty) {
- if (System.currentTimeMillis > finishTime) {
- throw new TimeoutException(
- s"The event queue is not empty after $timeoutMillis milliseconds")
- }
- /* Sleep rather than using wait/notify, because this is used only for testing and
- * wait/notify add overhead in the general case. */
- Thread.sleep(10)
- }
- }
-
- /**
- * For testing only. Return whether the listener daemon thread is still alive.
- */
- @VisibleForTesting
- def listenerThreadIsAlive: Boolean = listenerThread.isAlive
-
- /**
- * Return whether the event queue is empty.
- *
- * The use of synchronized here guarantees that all events that once belonged to this queue
- * have already been processed by all attached listeners, if this returns true.
- */
- private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent }
-
- /**
- * Stop the listener bus. It will wait until the queued events have been processed, but drop the
- * new events after stopping.
- */
- def stop() {
- if (!started.get()) {
- throw new IllegalStateException(s"Attempted to stop $name that has not yet started!")
- }
- if (stopped.compareAndSet(false, true)) {
- // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know
- // `stop` is called.
- eventLock.release()
- listenerThread.join()
- } else {
- // Keep quiet
- }
- }
-
- /**
- * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be
- * notified with the dropped events.
- *
- * Note: `onDropEvent` can be called in any thread.
- */
- def onDropEvent(event: E): Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala
new file mode 100644
index 000000000000..7def44bd2a2b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.{OutputStream, PrintStream}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.util.Try
+
+import org.apache.commons.io.output.TeeOutputStream
+import org.apache.commons.lang3.SystemUtils
+
+/**
+ * Utility class to benchmark components. An example of how to use this is:
+ * val benchmark = new Benchmark("My Benchmark", valuesPerIteration)
+ * benchmark.addCase("V1")()
+ * benchmark.addCase("V2")()
+ * benchmark.run
+ * This will output the average time to run each function and the rate of each function.
+ *
+ * The benchmark function takes one argument that is the iteration that's being run.
+ *
+ * @param name name of this benchmark.
+ * @param valuesPerIteration number of values used in the test case, used to compute rows/s.
+ * @param minNumIters the min number of iterations that will be run per case, not counting warm-up.
+ * @param warmupTime amount of time to spend running dummy case iterations for JIT warm-up.
+ * @param minTime further iterations will be run for each case until this time is used up.
+ * @param outputPerIteration if true, the timing for each run will be printed to stdout.
+ * @param output optional output stream to write benchmark results to
+ */
+private[spark] class Benchmark(
+ name: String,
+ valuesPerIteration: Long,
+ minNumIters: Int = 2,
+ warmupTime: FiniteDuration = 2.seconds,
+ minTime: FiniteDuration = 2.seconds,
+ outputPerIteration: Boolean = false,
+ output: Option[OutputStream] = None) {
+ import Benchmark._
+ val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case]
+
+ val out = if (output.isDefined) {
+ new PrintStream(new TeeOutputStream(System.out, output.get))
+ } else {
+ System.out
+ }
+
+ /**
+ * Adds a case to run when run() is called. The given function will be run for several
+ * iterations to collect timing statistics.
+ *
+ * @param name of the benchmark case
+ * @param numIters if non-zero, forces exactly this many iterations to be run
+ */
+ def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = {
+ addTimerCase(name, numIters) { timer =>
+ timer.startTiming()
+ f(timer.iteration)
+ timer.stopTiming()
+ }
+ }
+
+ /**
+ * Adds a case with manual timing control. When the function is run, timing does not start
+ * until timer.startTiming() is called within the given function. The corresponding
+ * timer.stopTiming() method must be called before the function returns.
+ *
+ * @param name of the benchmark case
+ * @param numIters if non-zero, forces exactly this many iterations to be run
+ */
+ def addTimerCase(name: String, numIters: Int = 0)(f: Benchmark.Timer => Unit): Unit = {
+ benchmarks += Benchmark.Case(name, f, numIters)
+ }
+
+ /**
+ * Runs the benchmark and outputs the results to stdout. This should be copied and added as
+ * a comment with the benchmark. Although the results vary from machine to machine, it should
+ * provide some baseline.
+ */
+ def run(): Unit = {
+ require(benchmarks.nonEmpty)
+ // scalastyle:off
+ println("Running benchmark: " + name)
+
+ val results = benchmarks.map { c =>
+ println(" Running case: " + c.name)
+ measure(valuesPerIteration, c.numIters)(c.fn)
+ }
+ println
+
+ val firstBest = results.head.bestMs
+ // The results are going to be processor specific so it is useful to include that.
+ out.println(Benchmark.getJVMOSInfo())
+ out.println(Benchmark.getProcessorName())
+ out.printf("%-40s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)",
+ "Per Row(ns)", "Relative")
+ out.println("-" * 96)
+ results.zip(benchmarks).foreach { case (result, benchmark) =>
+ out.printf("%-40s %16s %12s %13s %10s\n",
+ benchmark.name,
+ "%5.0f / %4.0f" format (result.bestMs, result.avgMs),
+ "%10.1f" format result.bestRate,
+ "%6.1f" format (1000 / result.bestRate),
+ "%3.1fX" format (firstBest / result.bestMs))
+ }
+ out.println
+ // scalastyle:on
+ }
+
+ /**
+ * Runs a single function `f` for the configured number of iterations, returning the average
+ * time the function took and the rate of the function.
+ */
+ def measure(num: Long, overrideNumIters: Int)(f: Timer => Unit): Result = {
+ System.gc() // ensures garbage from previous cases doesn't impact this one
+ val warmupDeadline = warmupTime.fromNow
+ while (!warmupDeadline.isOverdue) {
+ f(new Benchmark.Timer(-1))
+ }
+ val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters
+ val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos
+ val runTimes = ArrayBuffer[Long]()
+ var i = 0
+ while (i < minIters || runTimes.sum < minDuration) {
+ val timer = new Benchmark.Timer(i)
+ f(timer)
+ val runTime = timer.totalTime()
+ runTimes += runTime
+
+ if (outputPerIteration) {
+ // scalastyle:off
+ println(s"Iteration $i took ${runTime / 1000} microseconds")
+ // scalastyle:on
+ }
+ i += 1
+ }
+ // scalastyle:off
+ println(s" Stopped after $i iterations, ${runTimes.sum / 1000000} ms")
+ // scalastyle:on
+ val best = runTimes.min
+ val avg = runTimes.sum / runTimes.size
+ Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0)
+ }
+}
+
+private[spark] object Benchmark {
+
+ /**
+ * Object available to benchmark code to control timing e.g. to exclude set-up time.
+ *
+ * @param iteration specifies this is the nth iteration of running the benchmark case
+ */
+ class Timer(val iteration: Int) {
+ private var accumulatedTime: Long = 0L
+ private var timeStart: Long = 0L
+
+ def startTiming(): Unit = {
+ assert(timeStart == 0L, "Already started timing.")
+ timeStart = System.nanoTime
+ }
+
+ def stopTiming(): Unit = {
+ assert(timeStart != 0L, "Have not started timing.")
+ accumulatedTime += System.nanoTime - timeStart
+ timeStart = 0L
+ }
+
+ def totalTime(): Long = {
+ assert(timeStart == 0L, "Have not stopped timing.")
+ accumulatedTime
+ }
+ }
+
+ case class Case(name: String, fn: Timer => Unit, numIters: Int)
+ case class Result(avgMs: Double, bestRate: Double, bestMs: Double)
+
+ /**
+ * Returns human-readable processor information. Getting at this depends on the OS.
+ * The result should look like "Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz".
+ */
+ def getProcessorName(): String = {
+ val cpu = if (SystemUtils.IS_OS_MAC_OSX) {
+ Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string"))
+ } else if (SystemUtils.IS_OS_LINUX) {
+ Try {
+ val grepPath = Utils.executeAndGetOutput(Seq("which", "grep")).stripLineEnd
+ Utils.executeAndGetOutput(Seq(grepPath, "-m", "1", "model name", "/proc/cpuinfo"))
+ .stripLineEnd.replaceFirst("model name[\\s*]:[\\s*]", "")
+ }.getOrElse("Unknown processor")
+ } else {
+ System.getenv("PROCESSOR_IDENTIFIER")
+ }
+ cpu
+ }
+
+ /**
+ * Returns human-readable JVM and OS information, for example
+ * "OpenJDK 64-Bit Server VM 1.8.0_65-b17 on Linux 4.1.13-100.fc21.x86_64".
+ */
+ def getJVMOSInfo(): String = {
+ val vmName = System.getProperty("java.vm.name")
+ val runtimeVersion = System.getProperty("java.runtime.version")
+ val osName = System.getProperty("os.name")
+ val osVersion = System.getProperty("os.version")
+ s"${vmName} ${runtimeVersion} on ${osName} ${osVersion}"
+ }
+}
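To make the Benchmark API added above concrete, here is a hedged usage sketch. It is not part of the patch; since the class is private[spark], the sketch lives in the same package, and the case names and bodies are placeholders chosen only for illustration.

    package org.apache.spark.util

    object BenchmarkSketch {
      def main(args: Array[String]): Unit = {
        val valuesPerIteration = 1000000L
        val benchmark = new Benchmark("string building", valuesPerIteration)

        // Each case is a function of the iteration number; timing covers the whole body.
        benchmark.addCase("StringBuilder") { _ =>
          val sb = new StringBuilder
          var i = 0
          while (i < 100000) { sb.append('x'); i += 1 }
        }

        benchmark.addCase("string concatenation") { _ =>
          var s = ""
          var i = 0
          while (i < 100000) { s += "x"; i += 1 }
        }

        // Prints the JVM/OS and CPU header followed by Best/Avg, Rate and Relative columns.
        benchmark.run()
      }
    }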
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
index 54de4d4ee8ca..50dc948e6c41 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
@@ -20,14 +20,13 @@ package org.apache.spark.util
import java.io.InputStream
import java.nio.ByteBuffer
-import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.StorageUtils
/**
- * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose()
- * at the end of the stream (e.g. to close a memory-mapped file).
+ * Reads data from a ByteBuffer.
*/
private[spark]
-class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false)
+class ByteBufferInputStream(private var buffer: ByteBuffer)
extends InputStream {
override def read(): Int = {
@@ -68,13 +67,10 @@ class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = f
}
/**
- * Clean up the buffer, and potentially dispose of it using BlockManager.dispose().
+ * Clean up the buffer, and potentially dispose of it using StorageUtils.dispose().
*/
private def cleanUp() {
if (buffer != null) {
- if (dispose) {
- BlockManager.dispose(buffer)
- }
buffer = null
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
new file mode 100644
index 000000000000..9077b86f9ba1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.ByteArrayOutputStream
+import java.nio.ByteBuffer
+
+/**
+ * Provides a zero-copy way to convert the data in a ByteArrayOutputStream to a ByteBuffer.
+ */
+private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutputStream(capacity) {
+
+ def this() = this(32)
+
+ def getCount(): Int = count
+
+ private[this] var closed: Boolean = false
+
+ override def write(b: Int): Unit = {
+ require(!closed, "cannot write to a closed ByteBufferOutputStream")
+ super.write(b)
+ }
+
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+ require(!closed, "cannot write to a closed ByteBufferOutputStream")
+ super.write(b, off, len)
+ }
+
+ override def reset(): Unit = {
+ require(!closed, "cannot reset a closed ByteBufferOutputStream")
+ super.reset()
+ }
+
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ closed = true
+ }
+ }
+
+ def toByteBuffer: ByteBuffer = {
+ require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed")
+ ByteBuffer.wrap(buf, 0, count)
+ }
+}
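A short, hedged sketch of the intended call order for the class above: write, close, then toByteBuffer, which wraps the internal array without copying. It is not part of the patch; since the class is private[spark], the sketch is placed in the same package.

    package org.apache.spark.util

    import java.nio.charset.StandardCharsets

    object ByteBufferOutputStreamSketch {
      def main(args: Array[String]): Unit = {
        val out = new ByteBufferOutputStream()
        out.write("hello".getBytes(StandardCharsets.UTF_8))
        out.close()                    // toByteBuffer refuses to run on an open stream
        val buf = out.toByteBuffer     // zero-copy: wraps the underlying byte array
        assert(buf.remaining() == 5)
      }
    }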
diff --git a/core/src/main/scala/org/apache/spark/util/CausedBy.scala b/core/src/main/scala/org/apache/spark/util/CausedBy.scala
new file mode 100644
index 000000000000..73df446d981c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/CausedBy.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * Extractor object for pulling out the root cause of an error.
+ * If the error has no cause, the error itself is returned.
+ *
+ * Usage:
+ * try {
+ * ...
+ * } catch {
+ * case CausedBy(ex: CommitDeniedException) => ...
+ * }
+ */
+private[spark] object CausedBy {
+
+ def unapply(e: Throwable): Option[Throwable] = {
+ Option(e.getCause).flatMap(cause => unapply(cause)).orElse(Some(e))
+ }
+}
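The scaladoc above shows the extractor in a catch block; the sketch below (not part of the patch) spells out what unapply returns: it recursively follows getCause and yields the innermost throwable, or the throwable itself when there is no cause.

    package org.apache.spark.util

    object CausedBySketch {
      def main(args: Array[String]): Unit = {
        val wrapped = new RuntimeException("outer",
          new IllegalStateException("middle", new java.io.IOException("root")))

        wrapped match {
          case CausedBy(root: java.io.IOException) => println(s"root cause: ${root.getMessage}")
          case CausedBy(other) => println(s"unexpected root cause: $other")
        }
      }
    }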
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 1b49dca9dc78..2d5d3f863daa 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -19,12 +19,14 @@ package org.apache.spark.util
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-import scala.collection.mutable.{Map, Set}
+import scala.collection.mutable.{Map, Set, Stack}
+import scala.language.existentials
-import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
-import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
+import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type}
+import org.apache.xbean.asm5.Opcodes._
-import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.internal.Logging
/**
* A cleaner that renders closures serializable if they can be done so safely.
@@ -76,35 +78,67 @@ private[spark] object ClosureCleaner extends Logging {
*/
private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = {
val seen = Set[Class[_]](obj.getClass)
- var stack = List[Class[_]](obj.getClass)
+ val stack = Stack[Class[_]](obj.getClass)
while (!stack.isEmpty) {
- val cr = getClassReader(stack.head)
- stack = stack.tail
+ val cr = getClassReader(stack.pop())
val set = Set[Class[_]]()
cr.accept(new InnerClosureFinder(set), 0)
for (cls <- set -- seen) {
seen += cls
- stack = cls :: stack
+ stack.push(cls)
}
}
(seen - obj.getClass).toList
}
- private def createNullValue(cls: Class[_]): AnyRef = {
- if (cls.isPrimitive) {
- cls match {
- case java.lang.Boolean.TYPE => new java.lang.Boolean(false)
- case java.lang.Character.TYPE => new java.lang.Character('\u0000')
- case java.lang.Void.TYPE =>
- // This should not happen because `Foo(void x) {}` does not compile.
- throw new IllegalStateException("Unexpected void parameter in constructor")
- case _ => new java.lang.Byte(0: Byte)
+ /** Initializes the accessed fields for outer classes and their super classes. */
+ private def initAccessedFields(
+ accessedFields: Map[Class[_], Set[String]],
+ outerClasses: Seq[Class[_]]): Unit = {
+ for (cls <- outerClasses) {
+ var currentClass = cls
+ assert(currentClass != null, "The outer class can't be null.")
+
+ while (currentClass != null) {
+ accessedFields(currentClass) = Set.empty[String]
+ currentClass = currentClass.getSuperclass()
}
- } else {
- null
}
}
+ /** Sets accessed fields for given class in clone object based on given object. */
+ private def setAccessedFields(
+ outerClass: Class[_],
+ clone: AnyRef,
+ obj: AnyRef,
+ accessedFields: Map[Class[_], Set[String]]): Unit = {
+ for (fieldName <- accessedFields(outerClass)) {
+ val field = outerClass.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ val value = field.get(obj)
+ field.set(clone, value)
+ }
+ }
+
+ /** Clones a given object and sets accessed fields in cloned object. */
+ private def cloneAndSetFields(
+ parent: AnyRef,
+ obj: AnyRef,
+ outerClass: Class[_],
+ accessedFields: Map[Class[_], Set[String]]): AnyRef = {
+ val clone = instantiateClass(outerClass, parent)
+
+ var currentClass = outerClass
+ assert(currentClass != null, "The outer class can't be null.")
+
+ while (currentClass != null) {
+ setAccessedFields(currentClass, clone, obj, accessedFields)
+ currentClass = currentClass.getSuperclass()
+ }
+
+ clone
+ }
+
/**
* Clean the given closure in place.
*
@@ -214,9 +248,8 @@ private[spark] object ClosureCleaner extends Logging {
logDebug(s" + populating accessed fields because this is the starting closure")
// Initialize accessed fields with the outer classes first
// This step is needed to associate the fields to the correct classes later
- for (cls <- outerClasses) {
- accessedFields(cls) = Set[String]()
- }
+ initAccessedFields(accessedFields, outerClasses)
+
// Populate accessed fields by visiting all fields and methods accessed by this and
// all of its inner closures. If transitive cleaning is enabled, this may recursively
// visits methods that belong to other classes in search of transitively referenced fields.
@@ -232,16 +265,24 @@ private[spark] object ClosureCleaner extends Logging {
// Note that all outer objects but the outermost one (first one in this list) must be closures
var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse
var parent: AnyRef = null
- if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) {
- // The closure is ultimately nested inside a class; keep the object of that
- // class without cloning it since we don't want to clone the user's objects.
- // Note that we still need to keep around the outermost object itself because
- // we need it to clone its child closure later (see below).
- logDebug(s" + outermost object is not a closure, so do not clone it: ${outerPairs.head}")
- parent = outerPairs.head._2 // e.g. SparkContext
- outerPairs = outerPairs.tail
- } else if (outerPairs.size > 0) {
- logDebug(s" + outermost object is a closure, so we just keep it: ${outerPairs.head}")
+ if (outerPairs.size > 0) {
+ val (outermostClass, outermostObject) = outerPairs.head
+ if (isClosure(outermostClass)) {
+ logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}")
+ } else if (outermostClass.getName.startsWith("$line")) {
+ // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it
+ // as it may carry a lot of unnecessary information, e.g. hadoop conf, spark conf, etc.
+ logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}")
+ } else {
+ // The closure is ultimately nested inside a class; keep the object of that
+ // class without cloning it since we don't want to clone the user's objects.
+ // Note that we still need to keep around the outermost object itself because
+ // we need it to clone its child closure later (see below).
+ logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " +
+ outerPairs.head)
+ parent = outermostObject // e.g. SparkContext
+ outerPairs = outerPairs.tail
+ }
} else {
logDebug(" + there are no enclosing objects!")
}
@@ -254,13 +295,8 @@ private[spark] object ClosureCleaner extends Logging {
// required fields from the original object. We need the parent here because the Java
// language specification requires the first constructor parameter of any closure to be
// its enclosing object.
- val clone = instantiateClass(cls, parent)
- for (fieldName <- accessedFields(cls)) {
- val field = cls.getDeclaredField(fieldName)
- field.setAccessible(true)
- val value = field.get(obj)
- field.set(clone, value)
- }
+ val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
+
// If transitive cleaning is enabled, we recursively clean any enclosing closure using
// the already populated accessed fields map of the starting closure
if (cleanTransitively && isClosure(clone.getClass)) {
@@ -325,11 +361,11 @@ private[spark] object ClosureCleaner extends Logging {
private[spark] class ReturnStatementInClosureException
extends SparkException("Return statements aren't allowed in Spark closures")
-private class ReturnStatementFinder extends ClassVisitor(ASM4) {
+private class ReturnStatementFinder extends ClassVisitor(ASM5) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
if (name.contains("apply")) {
- new MethodVisitor(ASM4) {
+ new MethodVisitor(ASM5) {
override def visitTypeInsn(op: Int, tp: String) {
if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
throw new ReturnStatementInClosureException
@@ -337,7 +373,7 @@ private class ReturnStatementFinder extends ClassVisitor(ASM4) {
}
}
} else {
- new MethodVisitor(ASM4) {}
+ new MethodVisitor(ASM5) {}
}
}
}
@@ -361,7 +397,7 @@ private[util] class FieldAccessFinder(
findTransitively: Boolean,
specificMethod: Option[MethodIdentifier[_]] = None,
visitedMethods: Set[MethodIdentifier[_]] = Set.empty)
- extends ClassVisitor(ASM4) {
+ extends ClassVisitor(ASM5) {
override def visitMethod(
access: Int,
@@ -376,7 +412,7 @@ private[util] class FieldAccessFinder(
return null
}
- new MethodVisitor(ASM4) {
+ new MethodVisitor(ASM5) {
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
if (op == GETFIELD) {
for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
@@ -385,7 +421,8 @@ private[util] class FieldAccessFinder(
}
}
- override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
+ override def visitMethodInsn(
+ op: Int, owner: String, name: String, desc: String, itf: Boolean) {
for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
// Check for calls a getter method for a variable in an interpreter wrapper object.
// This means that the corresponding field will be accessed, so we should save it.
@@ -398,8 +435,15 @@ private[util] class FieldAccessFinder(
if (!visitedMethods.contains(m)) {
// Keep track of visited methods to avoid potential infinite cycles
visitedMethods += m
- ClosureCleaner.getClassReader(cl).accept(
- new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+
+ var currentClass = cl
+ assert(currentClass != null, "The outer class can't be null.")
+
+ while (currentClass != null) {
+ ClosureCleaner.getClassReader(currentClass).accept(
+ new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+ currentClass = currentClass.getSuperclass()
+ }
}
}
}
@@ -408,7 +452,7 @@ private[util] class FieldAccessFinder(
}
}
-private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
+private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) {
var myName: String = null
// TODO: Recursively find inner closures that we indirectly reference, e.g.
@@ -423,9 +467,9 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- new MethodVisitor(ASM4) {
- override def visitMethodInsn(op: Int, owner: String, name: String,
- desc: String) {
+ new MethodVisitor(ASM5) {
+ override def visitMethodInsn(
+ op: Int, owner: String, name: String, desc: String, itf: Boolean) {
val argTypes = Type.getArgumentTypes(desc)
if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
&& argTypes(0).toString.startsWith("L") // is it an object?
diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala
new file mode 100644
index 000000000000..d73901686b70
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.PrintStream
+
+import org.apache.spark.SparkException
+
+/**
+ * Contains basic command line parsing functionality and methods to parse some common Spark CLI
+ * options.
+ */
+private[spark] trait CommandLineUtils {
+
+ // Exposed for testing
+ private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode)
+
+ private[spark] var printStream: PrintStream = System.err
+
+ // scalastyle:off println
+
+ private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str)
+
+ private[spark] def printErrorAndExit(str: String): Unit = {
+ printStream.println("Error: " + str)
+ printStream.println("Run with --help for usage help or --verbose for debug output")
+ exitFn(1)
+ }
+
+ // scalastyle:on println
+
+ private[spark] def parseSparkConfProperty(pair: String): (String, String) = {
+ pair.split("=", 2).toSeq match {
+ case Seq(k, v) => (k, v)
+ case _ => printErrorAndExit(s"Spark config without '=': $pair")
+ throw new SparkException(s"Spark config without '=': $pair")
+ }
+ }
+
+ def main(args: Array[String]): Unit
+}
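As a hedged sketch of how the trait above is meant to be mixed in; the object name and the --conf= option are illustrative, not from the patch.

    package org.apache.spark.util

    object ToyCliSketch extends CommandLineUtils {
      override def main(args: Array[String]): Unit = {
        args.foreach {
          case arg if arg.startsWith("--conf=") =>
            // parseSparkConfProperty splits a "key=value" pair and reports an error otherwise.
            val (key, value) = parseSparkConfProperty(arg.stripPrefix("--conf="))
            printStream.println(s"conf: $key = $value")
          case other =>
            printWarning(s"ignoring unrecognized argument: $other")
        }
      }
    }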
diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala
index e9b2b8d24b47..3ea9139e1102 100644
--- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala
+++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala
@@ -17,12 +17,12 @@
package org.apache.spark.util
-import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
+import java.util.concurrent.atomic.AtomicBoolean
import scala.util.control.NonFatal
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
/**
* An event loop to receive events from the caller and process all events in the event thread. It
@@ -47,13 +47,12 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging {
try {
onReceive(event)
} catch {
- case NonFatal(e) => {
+ case NonFatal(e) =>
try {
onError(e)
} catch {
case NonFatal(e) => logError("Unexpected error in " + name, e)
}
- }
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index ee2eb58cf5e2..8296c4294242 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -19,19 +19,21 @@ package org.apache.spark.util
import java.util.{Properties, UUID}
-import org.apache.spark.scheduler.cluster.ExecutorInfo
-
import scala.collection.JavaConverters._
import scala.collection.Map
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.json4s.DefaultFormats
-import org.json4s.JsonDSL._
import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark._
import org.apache.spark.executor._
import org.apache.spark.rdd.RDDOperationScope
import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.storage._
/**
@@ -54,6 +56,8 @@ private[spark] object JsonProtocol {
private implicit val format = DefaultFormats
+ private val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
+
/** ------------------------------------------------- *
* JSON serialization methods for SparkListenerEvents |
* -------------------------------------------------- */
@@ -96,26 +100,27 @@ private[spark] object JsonProtocol {
executorMetricsUpdateToJson(metricsUpdate)
case blockUpdated: SparkListenerBlockUpdated =>
throw new MatchError(blockUpdated) // TODO(ekl) implement this
+ case _ => parse(mapper.writeValueAsString(event))
}
}
def stageSubmittedToJson(stageSubmitted: SparkListenerStageSubmitted): JValue = {
val stageInfo = stageInfoToJson(stageSubmitted.stageInfo)
val properties = propertiesToJson(stageSubmitted.properties)
- ("Event" -> Utils.getFormattedClassName(stageSubmitted)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageSubmitted) ~
("Stage Info" -> stageInfo) ~
("Properties" -> properties)
}
def stageCompletedToJson(stageCompleted: SparkListenerStageCompleted): JValue = {
val stageInfo = stageInfoToJson(stageCompleted.stageInfo)
- ("Event" -> Utils.getFormattedClassName(stageCompleted)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageCompleted) ~
("Stage Info" -> stageInfo)
}
def taskStartToJson(taskStart: SparkListenerTaskStart): JValue = {
val taskInfo = taskStart.taskInfo
- ("Event" -> Utils.getFormattedClassName(taskStart)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskStart) ~
("Stage ID" -> taskStart.stageId) ~
("Stage Attempt ID" -> taskStart.stageAttemptId) ~
("Task Info" -> taskInfoToJson(taskInfo))
@@ -123,7 +128,7 @@ private[spark] object JsonProtocol {
def taskGettingResultToJson(taskGettingResult: SparkListenerTaskGettingResult): JValue = {
val taskInfo = taskGettingResult.taskInfo
- ("Event" -> Utils.getFormattedClassName(taskGettingResult)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskGettingResult) ~
("Task Info" -> taskInfoToJson(taskInfo))
}
@@ -132,7 +137,7 @@ private[spark] object JsonProtocol {
val taskInfo = taskEnd.taskInfo
val taskMetrics = taskEnd.taskMetrics
val taskMetricsJson = if (taskMetrics != null) taskMetricsToJson(taskMetrics) else JNothing
- ("Event" -> Utils.getFormattedClassName(taskEnd)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskEnd) ~
("Stage ID" -> taskEnd.stageId) ~
("Stage Attempt ID" -> taskEnd.stageAttemptId) ~
("Task Type" -> taskEnd.taskType) ~
@@ -143,7 +148,7 @@ private[spark] object JsonProtocol {
def jobStartToJson(jobStart: SparkListenerJobStart): JValue = {
val properties = propertiesToJson(jobStart.properties)
- ("Event" -> Utils.getFormattedClassName(jobStart)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobStart) ~
("Job ID" -> jobStart.jobId) ~
("Submission Time" -> jobStart.time) ~
("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0
@@ -153,7 +158,7 @@ private[spark] object JsonProtocol {
def jobEndToJson(jobEnd: SparkListenerJobEnd): JValue = {
val jobResult = jobResultToJson(jobEnd.jobResult)
- ("Event" -> Utils.getFormattedClassName(jobEnd)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobEnd) ~
("Job ID" -> jobEnd.jobId) ~
("Completion Time" -> jobEnd.time) ~
("Job Result" -> jobResult)
@@ -165,7 +170,7 @@ private[spark] object JsonProtocol {
val sparkProperties = mapToJson(environmentDetails("Spark Properties").toMap)
val systemProperties = mapToJson(environmentDetails("System Properties").toMap)
val classpathEntries = mapToJson(environmentDetails("Classpath Entries").toMap)
- ("Event" -> Utils.getFormattedClassName(environmentUpdate)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.environmentUpdate) ~
("JVM Information" -> jvmInformation) ~
("Spark Properties" -> sparkProperties) ~
("System Properties" -> systemProperties) ~
@@ -174,26 +179,28 @@ private[spark] object JsonProtocol {
def blockManagerAddedToJson(blockManagerAdded: SparkListenerBlockManagerAdded): JValue = {
val blockManagerId = blockManagerIdToJson(blockManagerAdded.blockManagerId)
- ("Event" -> Utils.getFormattedClassName(blockManagerAdded)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerAdded) ~
("Block Manager ID" -> blockManagerId) ~
("Maximum Memory" -> blockManagerAdded.maxMem) ~
- ("Timestamp" -> blockManagerAdded.time)
+ ("Timestamp" -> blockManagerAdded.time) ~
+ ("Maximum Onheap Memory" -> blockManagerAdded.maxOnHeapMem) ~
+ ("Maximum Offheap Memory" -> blockManagerAdded.maxOffHeapMem)
}
def blockManagerRemovedToJson(blockManagerRemoved: SparkListenerBlockManagerRemoved): JValue = {
val blockManagerId = blockManagerIdToJson(blockManagerRemoved.blockManagerId)
- ("Event" -> Utils.getFormattedClassName(blockManagerRemoved)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerRemoved) ~
("Block Manager ID" -> blockManagerId) ~
("Timestamp" -> blockManagerRemoved.time)
}
def unpersistRDDToJson(unpersistRDD: SparkListenerUnpersistRDD): JValue = {
- ("Event" -> Utils.getFormattedClassName(unpersistRDD)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.unpersistRDD) ~
("RDD ID" -> unpersistRDD.rddId)
}
def applicationStartToJson(applicationStart: SparkListenerApplicationStart): JValue = {
- ("Event" -> Utils.getFormattedClassName(applicationStart)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.applicationStart) ~
("App Name" -> applicationStart.appName) ~
("App ID" -> applicationStart.appId.map(JString(_)).getOrElse(JNothing)) ~
("Timestamp" -> applicationStart.time) ~
@@ -203,39 +210,39 @@ private[spark] object JsonProtocol {
}
def applicationEndToJson(applicationEnd: SparkListenerApplicationEnd): JValue = {
- ("Event" -> Utils.getFormattedClassName(applicationEnd)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.applicationEnd) ~
("Timestamp" -> applicationEnd.time)
}
def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = {
- ("Event" -> Utils.getFormattedClassName(executorAdded)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorAdded) ~
("Timestamp" -> executorAdded.time) ~
("Executor ID" -> executorAdded.executorId) ~
("Executor Info" -> executorInfoToJson(executorAdded.executorInfo))
}
def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = {
- ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorRemoved) ~
("Timestamp" -> executorRemoved.time) ~
("Executor ID" -> executorRemoved.executorId) ~
("Removed Reason" -> executorRemoved.reason)
}
def logStartToJson(logStart: SparkListenerLogStart): JValue = {
- ("Event" -> Utils.getFormattedClassName(logStart)) ~
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.logStart) ~
("Spark Version" -> SPARK_VERSION)
}
def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = {
val execId = metricsUpdate.execId
- val taskMetrics = metricsUpdate.taskMetrics
- ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~
+ val accumUpdates = metricsUpdate.accumUpdates
+ ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.metricsUpdate) ~
("Executor ID" -> execId) ~
- ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) =>
+ ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) =>
("Task ID" -> taskId) ~
("Stage ID" -> stageId) ~
("Stage Attempt ID" -> stageAttemptId) ~
- ("Task Metrics" -> taskMetricsToJson(metrics))
+ ("Accumulator Updates" -> JArray(updates.map(accumulableInfoToJson).toList))
})
}
@@ -259,8 +266,7 @@ private[spark] object JsonProtocol {
("Submission Time" -> submissionTime) ~
("Completion Time" -> completionTime) ~
("Failure Reason" -> failureReason) ~
- ("Accumulables" -> JArray(
- stageInfo.accumulables.values.map(accumulableInfoToJson).toList))
+ ("Accumulables" -> accumulablesToJson(stageInfo.accumulables.values))
}
def taskInfoToJson(taskInfo: TaskInfo): JValue = {
@@ -275,36 +281,85 @@ private[spark] object JsonProtocol {
("Getting Result Time" -> taskInfo.gettingResultTime) ~
("Finish Time" -> taskInfo.finishTime) ~
("Failed" -> taskInfo.failed) ~
- ("Accumulables" -> JArray(taskInfo.accumulables.map(accumulableInfoToJson).toList))
+ ("Killed" -> taskInfo.killed) ~
+ ("Accumulables" -> accumulablesToJson(taskInfo.accumulables))
+ }
+
+ private lazy val accumulableBlacklist = Set("internal.metrics.updatedBlockStatuses")
+
+ def accumulablesToJson(accumulables: Traversable[AccumulableInfo]): JArray = {
+ JArray(accumulables
+ .filterNot(_.name.exists(accumulableBlacklist.contains))
+ .toList.map(accumulableInfoToJson))
}
def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = {
+ val name = accumulableInfo.name
("ID" -> accumulableInfo.id) ~
- ("Name" -> accumulableInfo.name) ~
- ("Update" -> accumulableInfo.update.map(new JString(_)).getOrElse(JNothing)) ~
- ("Value" -> accumulableInfo.value) ~
- ("Internal" -> accumulableInfo.internal)
+ ("Name" -> name) ~
+ ("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~
+ ("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~
+ ("Internal" -> accumulableInfo.internal) ~
+ ("Count Failed Values" -> accumulableInfo.countFailedValues) ~
+ ("Metadata" -> accumulableInfo.metadata)
+ }
+
+ /**
+ * Serialize the value of an accumulator to JSON.
+ *
+ * For accumulators representing internal task metrics, this looks up the relevant
+ * [[AccumulatorParam]] to serialize the value accordingly. For all other accumulators,
+ * this will simply serialize the value as a string.
+ *
+ * The behavior here must match that of [[accumValueFromJson]]. Exposed for testing.
+ */
+ private[util] def accumValueToJson(name: Option[String], value: Any): JValue = {
+ if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) {
+ value match {
+ case v: Int => JInt(v)
+ case v: Long => JInt(v)
+ // We only have 3 kinds of internal accumulator values, so if it's not an int or a long, it
+ // must be the blocks accumulator, whose type is `java.util.List[(BlockId, BlockStatus)]`
+ case v =>
+ JArray(v.asInstanceOf[java.util.List[(BlockId, BlockStatus)]].asScala.toList.map {
+ case (id, status) =>
+ ("Block ID" -> id.toString) ~
+ ("Status" -> blockStatusToJson(status))
+ })
+ }
+ } else {
+ // For all external accumulators, just use strings
+ JString(value.toString)
+ }
}
def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = {
- val shuffleReadMetrics =
- taskMetrics.shuffleReadMetrics.map(shuffleReadMetricsToJson).getOrElse(JNothing)
- val shuffleWriteMetrics =
- taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing)
- val inputMetrics =
- taskMetrics.inputMetrics.map(inputMetricsToJson).getOrElse(JNothing)
- val outputMetrics =
- taskMetrics.outputMetrics.map(outputMetricsToJson).getOrElse(JNothing)
+ val shuffleReadMetrics: JValue =
+ ("Remote Blocks Fetched" -> taskMetrics.shuffleReadMetrics.remoteBlocksFetched) ~
+ ("Local Blocks Fetched" -> taskMetrics.shuffleReadMetrics.localBlocksFetched) ~
+ ("Fetch Wait Time" -> taskMetrics.shuffleReadMetrics.fetchWaitTime) ~
+ ("Remote Bytes Read" -> taskMetrics.shuffleReadMetrics.remoteBytesRead) ~
+ ("Local Bytes Read" -> taskMetrics.shuffleReadMetrics.localBytesRead) ~
+ ("Total Records Read" -> taskMetrics.shuffleReadMetrics.recordsRead)
+ val shuffleWriteMetrics: JValue =
+ ("Shuffle Bytes Written" -> taskMetrics.shuffleWriteMetrics.bytesWritten) ~
+ ("Shuffle Write Time" -> taskMetrics.shuffleWriteMetrics.writeTime) ~
+ ("Shuffle Records Written" -> taskMetrics.shuffleWriteMetrics.recordsWritten)
+ val inputMetrics: JValue =
+ ("Bytes Read" -> taskMetrics.inputMetrics.bytesRead) ~
+ ("Records Read" -> taskMetrics.inputMetrics.recordsRead)
+ val outputMetrics: JValue =
+ ("Bytes Written" -> taskMetrics.outputMetrics.bytesWritten) ~
+ ("Records Written" -> taskMetrics.outputMetrics.recordsWritten)
val updatedBlocks =
- taskMetrics.updatedBlocks.map { blocks =>
- JArray(blocks.toList.map { case (id, status) =>
- ("Block ID" -> id.toString) ~
+ JArray(taskMetrics.updatedBlockStatuses.toList.map { case (id, status) =>
+ ("Block ID" -> id.toString) ~
("Status" -> blockStatusToJson(status))
- })
- }.getOrElse(JNothing)
- ("Host Name" -> taskMetrics.hostname) ~
+ })
("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~
+ ("Executor Deserialize CPU Time" -> taskMetrics.executorDeserializeCpuTime) ~
("Executor Run Time" -> taskMetrics.executorRunTime) ~
+ ("Executor CPU Time" -> taskMetrics.executorCpuTime) ~
("Result Size" -> taskMetrics.resultSize) ~
("JVM GC Time" -> taskMetrics.jvmGCTime) ~
("Result Serialization Time" -> taskMetrics.resultSerializationTime) ~
@@ -317,33 +372,6 @@ private[spark] object JsonProtocol {
("Updated Blocks" -> updatedBlocks)
}
- def shuffleReadMetricsToJson(shuffleReadMetrics: ShuffleReadMetrics): JValue = {
- ("Remote Blocks Fetched" -> shuffleReadMetrics.remoteBlocksFetched) ~
- ("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~
- ("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~
- ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) ~
- ("Local Bytes Read" -> shuffleReadMetrics.localBytesRead) ~
- ("Total Records Read" -> shuffleReadMetrics.recordsRead)
- }
-
- def shuffleWriteMetricsToJson(shuffleWriteMetrics: ShuffleWriteMetrics): JValue = {
- ("Shuffle Bytes Written" -> shuffleWriteMetrics.shuffleBytesWritten) ~
- ("Shuffle Write Time" -> shuffleWriteMetrics.shuffleWriteTime) ~
- ("Shuffle Records Written" -> shuffleWriteMetrics.shuffleRecordsWritten)
- }
-
- def inputMetricsToJson(inputMetrics: InputMetrics): JValue = {
- ("Data Read Method" -> inputMetrics.readMethod.toString) ~
- ("Bytes Read" -> inputMetrics.bytesRead) ~
- ("Records Read" -> inputMetrics.recordsRead)
- }
-
- def outputMetricsToJson(outputMetrics: OutputMetrics): JValue = {
- ("Data Write Method" -> outputMetrics.writeMethod.toString) ~
- ("Bytes Written" -> outputMetrics.bytesWritten) ~
- ("Records Written" -> outputMetrics.recordsWritten)
- }
-
def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = {
val reason = Utils.getFormattedClassName(taskEndReason)
val json: JObject = taskEndReason match {
@@ -357,12 +385,12 @@ private[spark] object JsonProtocol {
("Message" -> fetchFailed.message)
case exceptionFailure: ExceptionFailure =>
val stackTrace = stackTraceToJson(exceptionFailure.stackTrace)
- val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing)
+ val accumUpdates = accumulablesToJson(exceptionFailure.accumUpdates)
("Class Name" -> exceptionFailure.className) ~
("Description" -> exceptionFailure.description) ~
("Stack Trace" -> stackTrace) ~
("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~
- ("Metrics" -> metrics)
+ ("Accumulator Updates" -> accumUpdates)
case taskCommitDenied: TaskCommitDenied =>
("Job ID" -> taskCommitDenied.jobID) ~
("Partition ID" -> taskCommitDenied.partitionID) ~
@@ -371,6 +399,8 @@ private[spark] object JsonProtocol {
("Executor ID" -> executorId) ~
("Exit Caused By App" -> exitCausedByApp) ~
("Loss Reason" -> reason.map(_.toString))
+ case taskKilled: TaskKilled =>
+ ("Kill Reason" -> taskKilled.reason)
case _ => Utils.emptyJson
}
("Reason" -> reason) ~ json
@@ -398,19 +428,18 @@ private[spark] object JsonProtocol {
("RDD ID" -> rddInfo.id) ~
("Name" -> rddInfo.name) ~
("Scope" -> rddInfo.scope.map(_.toJson)) ~
+ ("Callsite" -> rddInfo.callSite) ~
("Parent IDs" -> parentIds) ~
("Storage Level" -> storageLevel) ~
("Number of Partitions" -> rddInfo.numPartitions) ~
("Number of Cached Partitions" -> rddInfo.numCachedPartitions) ~
("Memory Size" -> rddInfo.memSize) ~
- ("ExternalBlockStore Size" -> rddInfo.externalBlockStoreSize) ~
("Disk Size" -> rddInfo.diskSize)
}
def storageLevelToJson(storageLevel: StorageLevel): JValue = {
("Use Disk" -> storageLevel.useDisk) ~
("Use Memory" -> storageLevel.useMemory) ~
- ("Use ExternalBlockStore" -> storageLevel.useOffHeap) ~
("Deserialized" -> storageLevel.deserialized) ~
("Replication" -> storageLevel.replication)
}
@@ -419,7 +448,6 @@ private[spark] object JsonProtocol {
val storageLevel = storageLevelToJson(blockStatus.storageLevel)
("Storage Level" -> storageLevel) ~
("Memory Size" -> blockStatus.memSize) ~
- ("ExternalBlockStore Size" -> blockStatus.externalBlockStoreSize) ~
("Disk Size" -> blockStatus.diskSize)
}
@@ -468,7 +496,7 @@ private[spark] object JsonProtocol {
* JSON deserialization methods for SparkListenerEvents |
* ---------------------------------------------------- */
- def sparkEventFromJson(json: JValue): SparkListenerEvent = {
+ private object SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES {
val stageSubmitted = Utils.getFormattedClassName(SparkListenerStageSubmitted)
val stageCompleted = Utils.getFormattedClassName(SparkListenerStageCompleted)
val taskStart = Utils.getFormattedClassName(SparkListenerTaskStart)
@@ -486,6 +514,10 @@ private[spark] object JsonProtocol {
val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved)
val logStart = Utils.getFormattedClassName(SparkListenerLogStart)
val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate)
+ }
+
+ def sparkEventFromJson(json: JValue): SparkListenerEvent = {
+ import SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES._
(json \ "Event").extract[String] match {
case `stageSubmitted` => stageSubmittedFromJson(json)
@@ -505,6 +537,8 @@ private[spark] object JsonProtocol {
case `executorRemoved` => executorRemovedFromJson(json)
case `logStart` => logStartFromJson(json)
case `metricsUpdate` => executorMetricsUpdateFromJson(json)
+ case other => mapper.readValue(compact(render(json)), Utils.classForName(other))
+ .asInstanceOf[SparkListenerEvent]
}
}
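The two fallback branches (the Jackson write in sparkEventToJson and the class-name keyed read above) are what let event types unknown to JsonProtocol survive a trip through the event log. A hedged round-trip sketch follows; the event class and package are hypothetical, and it assumes SparkListenerEvent is extensible from user code (which is what the fallback exists to support) and that the event class is on the classpath at replay time. JsonProtocol itself is private[spark], hence the org.apache.spark package.

```scala
// Hypothetical custom-event round trip through the new Jackson fallbacks (illustrative only).
package org.apache.spark.sketch

import org.apache.spark.scheduler.SparkListenerEvent
import org.apache.spark.util.JsonProtocol

case class PipelineStageReached(pipeline: String, stage: Int) extends SparkListenerEvent

object CustomEventRoundTrip {
  def main(args: Array[String]): Unit = {
    val original = PipelineStageReached("nightly-etl", 3)
    // No branch matches, so this falls through to parse(mapper.writeValueAsString(event)).
    val json = JsonProtocol.sparkEventToJson(original)
    // On the way back, the class name recorded under "Event" drives mapper.readValue(...).
    val restored = JsonProtocol.sparkEventFromJson(json)
    assert(restored == original, s"round trip changed the event: $restored")
  }
}
```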
@@ -521,7 +555,8 @@ private[spark] object JsonProtocol {
def taskStartFromJson(json: JValue): SparkListenerTaskStart = {
val stageId = (json \ "Stage ID").extract[Int]
- val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0)
+ val stageAttemptId =
+ Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
val taskInfo = taskInfoFromJson(json \ "Task Info")
SparkListenerTaskStart(stageId, stageAttemptId, taskInfo)
}
@@ -533,7 +568,8 @@ private[spark] object JsonProtocol {
def taskEndFromJson(json: JValue): SparkListenerTaskEnd = {
val stageId = (json \ "Stage ID").extract[Int]
- val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0)
+ val stageAttemptId =
+ Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
val taskType = (json \ "Task Type").extract[String]
val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason")
val taskInfo = taskInfoFromJson(json \ "Task Info")
@@ -550,7 +586,9 @@ private[spark] object JsonProtocol {
// The "Stage Infos" field was added in Spark 1.2.0
val stageInfos = Utils.jsonOption(json \ "Stage Infos")
.map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse {
- stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown"))
+ stageIds.map { id =>
+ new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")
+ }
}
SparkListenerJobStart(jobId, submissionTime, stageInfos, properties)
}
@@ -576,7 +614,9 @@ private[spark] object JsonProtocol {
val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID")
val maxMem = (json \ "Maximum Memory").extract[Long]
val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L)
- SparkListenerBlockManagerAdded(time, blockManagerId, maxMem)
+ val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long])
+ val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long])
+ SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem)
}
def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = {
@@ -624,14 +664,15 @@ private[spark] object JsonProtocol {
def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = {
val execInfo = (json \ "Executor ID").extract[String]
- val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json =>
+ val accumUpdates = (json \ "Metrics Updated").extract[List[JValue]].map { json =>
val taskId = (json \ "Task ID").extract[Long]
val stageId = (json \ "Stage ID").extract[Int]
val stageAttemptId = (json \ "Stage Attempt ID").extract[Int]
- val metrics = taskMetricsFromJson(json \ "Task Metrics")
- (taskId, stageId, stageAttemptId, metrics)
+ val updates =
+ (json \ "Accumulator Updates").extract[List[JValue]].map(accumulableInfoFromJson)
+ (taskId, stageId, stageAttemptId, updates)
}
- SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics)
+ SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates)
}
/** --------------------------------------------------------------------- *
@@ -640,20 +681,22 @@ private[spark] object JsonProtocol {
def stageInfoFromJson(json: JValue): StageInfo = {
val stageId = (json \ "Stage ID").extract[Int]
- val attemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0)
+ val attemptId = Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
val stageName = (json \ "Stage Name").extract[String]
val numTasks = (json \ "Number of Tasks").extract[Int]
val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson)
val parentIds = Utils.jsonOption(json \ "Parent IDs")
.map { l => l.extract[List[JValue]].map(_.extract[Int]) }
.getOrElse(Seq.empty)
- val details = (json \ "Details").extractOpt[String].getOrElse("")
+ val details = Utils.jsonOption(json \ "Details").map(_.extract[String]).getOrElse("")
val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long])
val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long])
val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String])
- val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match {
- case Some(values) => values.map(accumulableInfoFromJson(_))
- case None => Seq[AccumulableInfo]()
+ val accumulatedValues = {
+ Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match {
+ case Some(values) => values.map(accumulableInfoFromJson)
+ case None => Seq[AccumulableInfo]()
+ }
}
val stageInfo = new StageInfo(
@@ -670,17 +713,18 @@ private[spark] object JsonProtocol {
def taskInfoFromJson(json: JValue): TaskInfo = {
val taskId = (json \ "Task ID").extract[Long]
val index = (json \ "Index").extract[Int]
- val attempt = (json \ "Attempt").extractOpt[Int].getOrElse(1)
+ val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1)
val launchTime = (json \ "Launch Time").extract[Long]
- val executorId = (json \ "Executor ID").extract[String]
- val host = (json \ "Host").extract[String]
+ val executorId = (json \ "Executor ID").extract[String].intern()
+ val host = (json \ "Host").extract[String].intern()
val taskLocality = TaskLocality.withName((json \ "Locality").extract[String])
- val speculative = (json \ "Speculative").extractOpt[Boolean].getOrElse(false)
+ val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean])
val gettingResultTime = (json \ "Getting Result Time").extract[Long]
val finishTime = (json \ "Finish Time").extract[Long]
val failed = (json \ "Failed").extract[Boolean]
- val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match {
- case Some(values) => values.map(accumulableInfoFromJson(_))
+ val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean])
+ val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match {
+ case Some(values) => values.map(accumulableInfoFromJson)
case None => Seq[AccumulableInfo]()
}
@@ -689,88 +733,124 @@ private[spark] object JsonProtocol {
taskInfo.gettingResultTime = gettingResultTime
taskInfo.finishTime = finishTime
taskInfo.failed = failed
- accumulables.foreach { taskInfo.accumulables += _ }
+ taskInfo.killed = killed
+ taskInfo.setAccumulables(accumulables)
taskInfo
}
def accumulableInfoFromJson(json: JValue): AccumulableInfo = {
val id = (json \ "ID").extract[Long]
- val name = (json \ "Name").extract[String]
- val update = Utils.jsonOption(json \ "Update").map(_.extract[String])
- val value = (json \ "Value").extract[String]
- val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false)
- AccumulableInfo(id, name, update, value, internal)
+ val name = Utils.jsonOption(json \ "Name").map(_.extract[String])
+ val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) }
+ val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) }
+ val internal = Utils.jsonOption(json \ "Internal").exists(_.extract[Boolean])
+ val countFailedValues =
+ Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean])
+ val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String])
+ new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata)
+ }
+
+ /**
+ * Deserialize the value of an accumulator from JSON.
+ *
+ * For accumulators representing internal task metrics, this looks up the relevant
+ * [[AccumulatorParam]] to deserialize the value accordingly. For all other
+ * accumulators, this will simply deserialize the value as a string.
+ *
+ * The behavior here must match that of [[accumValueToJson]]. Exposed for testing.
+ */
+ private[util] def accumValueFromJson(name: Option[String], value: JValue): Any = {
+ if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) {
+ value match {
+ case JInt(v) => v.toLong
+ case JArray(v) =>
+ v.map { blockJson =>
+ val id = BlockId((blockJson \ "Block ID").extract[String])
+ val status = blockStatusFromJson(blockJson \ "Status")
+ (id, status)
+ }.asJava
+ case _ => throw new IllegalArgumentException(s"unexpected json value $value for " +
+ "accumulator " + name.get)
+ }
+ } else {
+ value.extract[String]
+ }
}
def taskMetricsFromJson(json: JValue): TaskMetrics = {
+ val metrics = TaskMetrics.empty
if (json == JNothing) {
- return TaskMetrics.empty
+ return metrics
}
- val metrics = new TaskMetrics
- metrics.setHostname((json \ "Host Name").extract[String])
metrics.setExecutorDeserializeTime((json \ "Executor Deserialize Time").extract[Long])
+ metrics.setExecutorDeserializeCpuTime((json \ "Executor Deserialize CPU Time") match {
+ case JNothing => 0
+ case x => x.extract[Long]
+ })
metrics.setExecutorRunTime((json \ "Executor Run Time").extract[Long])
+ metrics.setExecutorCpuTime((json \ "Executor CPU Time") match {
+ case JNothing => 0
+ case x => x.extract[Long]
+ })
metrics.setResultSize((json \ "Result Size").extract[Long])
metrics.setJvmGCTime((json \ "JVM GC Time").extract[Long])
metrics.setResultSerializationTime((json \ "Result Serialization Time").extract[Long])
metrics.incMemoryBytesSpilled((json \ "Memory Bytes Spilled").extract[Long])
metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long])
- metrics.setShuffleReadMetrics(
- Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson))
- metrics.shuffleWriteMetrics =
- Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
- metrics.setInputMetrics(
- Utils.jsonOption(json \ "Input Metrics").map(inputMetricsFromJson))
- metrics.outputMetrics =
- Utils.jsonOption(json \ "Output Metrics").map(outputMetricsFromJson)
- metrics.updatedBlocks =
- Utils.jsonOption(json \ "Updated Blocks").map { value =>
- value.extract[List[JValue]].map { block =>
- val id = BlockId((block \ "Block ID").extract[String])
- val status = blockStatusFromJson(block \ "Status")
- (id, status)
- }
- }
- metrics
- }
- def shuffleReadMetricsFromJson(json: JValue): ShuffleReadMetrics = {
- val metrics = new ShuffleReadMetrics
- metrics.incRemoteBlocksFetched((json \ "Remote Blocks Fetched").extract[Int])
- metrics.incLocalBlocksFetched((json \ "Local Blocks Fetched").extract[Int])
- metrics.incFetchWaitTime((json \ "Fetch Wait Time").extract[Long])
- metrics.incRemoteBytesRead((json \ "Remote Bytes Read").extract[Long])
- metrics.incLocalBytesRead((json \ "Local Bytes Read").extractOpt[Long].getOrElse(0))
- metrics.incRecordsRead((json \ "Total Records Read").extractOpt[Long].getOrElse(0))
- metrics
- }
+ // Shuffle read metrics
+ Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson =>
+ val readMetrics = metrics.createTempShuffleReadMetrics()
+ readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int])
+ readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int])
+ readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long])
+ readMetrics.incLocalBytesRead(
+ Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L))
+ readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long])
+ readMetrics.incRecordsRead(
+ Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L))
+ metrics.mergeShuffleReadMetrics()
+ }
- def shuffleWriteMetricsFromJson(json: JValue): ShuffleWriteMetrics = {
- val metrics = new ShuffleWriteMetrics
- metrics.incShuffleBytesWritten((json \ "Shuffle Bytes Written").extract[Long])
- metrics.incShuffleWriteTime((json \ "Shuffle Write Time").extract[Long])
- metrics.setShuffleRecordsWritten((json \ "Shuffle Records Written")
- .extractOpt[Long].getOrElse(0))
- metrics
- }
+ // Shuffle write metrics
+ // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes.
+ Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson =>
+ val writeMetrics = metrics.shuffleWriteMetrics
+ writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long])
+ writeMetrics.incRecordsWritten(
+ Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L))
+ writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long])
+ }
- def inputMetricsFromJson(json: JValue): InputMetrics = {
- val metrics = new InputMetrics(
- DataReadMethod.withName((json \ "Data Read Method").extract[String]))
- metrics.incBytesRead((json \ "Bytes Read").extract[Long])
- metrics.incRecordsRead((json \ "Records Read").extractOpt[Long].getOrElse(0))
- metrics
- }
+ // Output metrics
+ Utils.jsonOption(json \ "Output Metrics").foreach { outJson =>
+ val outputMetrics = metrics.outputMetrics
+ outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long])
+ outputMetrics.setRecordsWritten(
+ Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L))
+ }
+
+ // Input metrics
+ Utils.jsonOption(json \ "Input Metrics").foreach { inJson =>
+ val inputMetrics = metrics.inputMetrics
+ inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long])
+ inputMetrics.incRecordsRead(
+ Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L))
+ }
+
+ // Updated blocks
+ Utils.jsonOption(json \ "Updated Blocks").foreach { blocksJson =>
+ metrics.setUpdatedBlockStatuses(blocksJson.extract[List[JValue]].map { blockJson =>
+ val id = BlockId((blockJson \ "Block ID").extract[String])
+ val status = blockStatusFromJson(blockJson \ "Status")
+ (id, status)
+ })
+ }
- def outputMetricsFromJson(json: JValue): OutputMetrics = {
- val metrics = new OutputMetrics(
- DataWriteMethod.withName((json \ "Data Write Method").extract[String]))
- metrics.setBytesWritten((json \ "Bytes Written").extract[Long])
- metrics.setRecordsWritten((json \ "Records Written").extractOpt[Long].getOrElse(0))
metrics
}
- def taskEndReasonFromJson(json: JValue): TaskEndReason = {
+ private object TASK_END_REASON_FORMATTED_CLASS_NAMES {
val success = Utils.getFormattedClassName(Success)
val resubmitted = Utils.getFormattedClassName(Resubmitted)
val fetchFailed = Utils.getFormattedClassName(FetchFailed)
@@ -780,6 +860,10 @@ private[spark] object JsonProtocol {
val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied)
val executorLostFailure = Utils.getFormattedClassName(ExecutorLostFailure)
val unknownReason = Utils.getFormattedClassName(UnknownReason)
+ }
+
+ def taskEndReasonFromJson(json: JValue): TaskEndReason = {
+ import TASK_END_REASON_FORMATTED_CLASS_NAMES._
(json \ "Reason").extract[String] match {
case `success` => Success
@@ -796,12 +880,20 @@ private[spark] object JsonProtocol {
val className = (json \ "Class Name").extract[String]
val description = (json \ "Description").extract[String]
val stackTrace = stackTraceFromJson(json \ "Stack Trace")
- val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace").
- map(_.extract[String]).orNull
- val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson)
- ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None)
+ val fullStackTrace =
+ Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull
+ // Fall back to the accumulator updates recorded in TaskMetrics, which is what Spark 1.x logged
+ val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates")
+ .map(_.extract[List[JValue]].map(accumulableInfoFromJson))
+ .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators().map(acc => {
+ acc.toInfo(Some(acc.value), None)
+ }))
+ ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates)
case `taskResultLost` => TaskResultLost
- case `taskKilled` => TaskKilled
+ case `taskKilled` =>
+ val killReason = Utils.jsonOption(json \ "Kill Reason")
+ .map(_.extract[String]).getOrElse("unknown reason")
+ TaskKilled(killReason)
case `taskCommitDenied` =>
// Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON
// de/serialization logic was not added until 1.5.1. To provide backward compatibility
@@ -827,15 +919,19 @@ private[spark] object JsonProtocol {
if (json == JNothing) {
return null
}
- val executorId = (json \ "Executor ID").extract[String]
- val host = (json \ "Host").extract[String]
+ val executorId = (json \ "Executor ID").extract[String].intern()
+ val host = (json \ "Host").extract[String].intern()
val port = (json \ "Port").extract[Int]
BlockManagerId(executorId, host, port)
}
- def jobResultFromJson(json: JValue): JobResult = {
+ private object JOB_RESULT_FORMATTED_CLASS_NAMES {
val jobSucceeded = Utils.getFormattedClassName(JobSucceeded)
val jobFailed = Utils.getFormattedClassName(JobFailed)
+ }
+
+ def jobResultFromJson(json: JValue): JobResult = {
+ import JOB_RESULT_FORMATTED_CLASS_NAMES._
(json \ "Result").extract[String] match {
case `jobSucceeded` => JobSucceeded
@@ -851,6 +947,7 @@ private[spark] object JsonProtocol {
val scope = Utils.jsonOption(json \ "Scope")
.map(_.extract[String])
.map(RDDOperationScope.fromJson)
+ val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("")
val parentIds = Utils.jsonOption(json \ "Parent IDs")
.map { l => l.extract[List[JValue]].map(_.extract[Int]) }
.getOrElse(Seq.empty)
@@ -858,15 +955,11 @@ private[spark] object JsonProtocol {
val numPartitions = (json \ "Number of Partitions").extract[Int]
val numCachedPartitions = (json \ "Number of Cached Partitions").extract[Int]
val memSize = (json \ "Memory Size").extract[Long]
- // fallback to tachyon for backward compatibility
- val externalBlockStoreSize = (json \ "ExternalBlockStore Size").toSome
- .getOrElse(json \ "Tachyon Size").extract[Long]
val diskSize = (json \ "Disk Size").extract[Long]
- val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, scope)
+ val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, callsite, scope)
rddInfo.numCachedPartitions = numCachedPartitions
rddInfo.memSize = memSize
- rddInfo.externalBlockStoreSize = externalBlockStoreSize
rddInfo.diskSize = diskSize
rddInfo
}
@@ -874,22 +967,16 @@ private[spark] object JsonProtocol {
def storageLevelFromJson(json: JValue): StorageLevel = {
val useDisk = (json \ "Use Disk").extract[Boolean]
val useMemory = (json \ "Use Memory").extract[Boolean]
- // fallback to tachyon for backward compatability
- val useExternalBlockStore = (json \ "Use ExternalBlockStore").toSome
- .getOrElse(json \ "Use Tachyon").extract[Boolean]
val deserialized = (json \ "Deserialized").extract[Boolean]
val replication = (json \ "Replication").extract[Int]
- StorageLevel(useDisk, useMemory, useExternalBlockStore, deserialized, replication)
+ StorageLevel(useDisk, useMemory, deserialized, replication)
}
def blockStatusFromJson(json: JValue): BlockStatus = {
val storageLevel = storageLevelFromJson(json \ "Storage Level")
val memorySize = (json \ "Memory Size").extract[Long]
val diskSize = (json \ "Disk Size").extract[Long]
- // fallback to tachyon for backward compatability
- val externalBlockStoreSize = (json \ "ExternalBlockStore Size").toSome
- .getOrElse(json \ "Tachyon Size").extract[Long]
- BlockStatus(storageLevel, memorySize, diskSize, externalBlockStoreSize)
+ BlockStatus(storageLevel, memorySize, diskSize)
}
def executorInfoFromJson(json: JValue): ExecutorInfo = {
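Many of the readers above switch from extractOpt to Utils.jsonOption so that fields missing from logs written by older Spark versions fall back to a default instead of failing. Since Utils is private[spark], the sketch below re-implements the same one-liner to show the pattern on a made-up log line.

```scala
// Standalone sketch of the jsonOption pattern used throughout the readers above (sample JSON invented).
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

object JsonOptionSketch {
  private implicit val format = DefaultFormats

  // Same shape as Utils.jsonOption: absent fields come back as JNothing, which becomes None.
  private def jsonOption(json: JValue): Option[JValue] = json match {
    case JNothing => None
    case value => Some(value)
  }

  def main(args: Array[String]): Unit = {
    val oldLogLine = parse("""{"Stage ID": 3}""") // an old log with no "Stage Attempt ID" field
    val stageAttemptId =
      jsonOption(oldLogLine \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
    println(stageAttemptId) // 0, instead of an extraction failure
  }
}
```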
diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
index 13cb516b583e..fa5ad4e8d81e 100644
--- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.util.control.NonFatal
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
/**
* An event bus which posts events to its listeners.
@@ -36,23 +36,31 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
/**
* Add a listener to listen events. This method is thread-safe and can be called in any thread.
*/
- final def addListener(listener: L) {
+ final def addListener(listener: L): Unit = {
listeners.add(listener)
}
+ /**
+ * Remove a listener and it won't receive any events. This method is thread-safe and can be
+ * called in any thread.
+ */
+ final def removeListener(listener: L): Unit = {
+ listeners.remove(listener)
+ }
+
/**
* Post the event to all registered listeners. The `postToAll` caller should guarantee calling
* `postToAll` in the same thread for all events.
*/
- final def postToAll(event: E): Unit = {
+ def postToAll(event: E): Unit = {
// JavaConverters can create a JIterableWrapper if we use asScala.
- // However, this method will be called frequently. To avoid the wrapper cost, here ewe use
+ // However, this method will be called frequently. To avoid the wrapper cost, here we use
// Java Iterator directly.
val iter = listeners.iterator
while (iter.hasNext) {
val listener = iter.next()
try {
- onPostEvent(listener, event)
+ doPostEvent(listener, event)
} catch {
case NonFatal(e) =>
logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
@@ -62,9 +70,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
/**
* Post an event to the specified listener. `onPostEvent` is guaranteed to be called in the same
- * thread.
+ * thread for all listeners.
*/
- def onPostEvent(listener: L, event: E): Unit
+ protected def doPostEvent(listener: L, event: E): Unit
private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = {
val c = implicitly[ClassTag[T]].runtimeClass
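With postToAll now overridable and the per-listener hook renamed to doPostEvent, a concrete bus only has to supply the dispatch step. A hypothetical sketch follows (invented listener and bus names; it has to sit in an org.apache.spark package because ListenerBus is private[spark]).

```scala
// Hypothetical concrete bus built on the trait above (illustrative only).
package org.apache.spark.util

private[spark] trait GreetingListener {
  def onGreeting(name: String): Unit
}

private[spark] class GreetingBus extends ListenerBus[GreetingListener, String] {
  // postToAll() walks the registered listeners and hands each event to this hook, catching
  // and logging any non-fatal exception a listener throws.
  override protected def doPostEvent(listener: GreetingListener, event: String): Unit = {
    listener.onGreeting(event)
  }
}

// Usage:
//   val bus = new GreetingBus
//   bus.addListener(new GreetingListener {
//     override def onGreeting(name: String): Unit = println(s"hello, $name")
//   })
//   bus.postToAll("world")
//   bus.removeListener(...)  // the listener stops receiving events from later postToAll calls
```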
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
deleted file mode 100644
index a8bbad086849..000000000000
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ /dev/null
@@ -1,110 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import java.util.{Timer, TimerTask}
-
-import org.apache.spark.{Logging, SparkConf}
-
-/**
- * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
- */
-private[spark] class MetadataCleaner(
- cleanerType: MetadataCleanerType.MetadataCleanerType,
- cleanupFunc: (Long) => Unit,
- conf: SparkConf)
- extends Logging
-{
- val name = cleanerType.toString
-
- private val delaySeconds = MetadataCleaner.getDelaySeconds(conf, cleanerType)
- private val periodSeconds = math.max(10, delaySeconds / 10)
- private val timer = new Timer(name + " cleanup timer", true)
-
-
- private val task = new TimerTask {
- override def run() {
- try {
- cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
- logInfo("Ran metadata cleaner for " + name)
- } catch {
- case e: Exception => logError("Error running cleanup task for " + name, e)
- }
- }
- }
-
- if (delaySeconds > 0) {
- logDebug(
- "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " +
- "and period of " + periodSeconds + " secs")
- timer.schedule(task, delaySeconds * 1000, periodSeconds * 1000)
- }
-
- def cancel() {
- timer.cancel()
- }
-}
-
-private[spark] object MetadataCleanerType extends Enumeration {
-
- val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER,
- SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
-
- type MetadataCleanerType = Value
-
- def systemProperty(which: MetadataCleanerType.MetadataCleanerType): String = {
- "spark.cleaner.ttl." + which.toString
- }
-}
-
-// TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the
-// initialization of StreamingContext. It's okay for users trying to configure stuff themselves.
-private[spark] object MetadataCleaner {
- def getDelaySeconds(conf: SparkConf): Int = {
- conf.getTimeAsSeconds("spark.cleaner.ttl", "-1").toInt
- }
-
- def getDelaySeconds(
- conf: SparkConf,
- cleanerType: MetadataCleanerType.MetadataCleanerType): Int = {
- conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt
- }
-
- def setDelaySeconds(
- conf: SparkConf,
- cleanerType: MetadataCleanerType.MetadataCleanerType,
- delay: Int) {
- conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString)
- }
-
- /**
- * Set the default delay time (in seconds).
- * @param conf SparkConf instance
- * @param delay default delay time to set
- * @param resetAll whether to reset all to default
- */
- def setDelaySeconds(conf: SparkConf, delay: Int, resetAll: Boolean = true) {
- conf.set("spark.cleaner.ttl", delay.toString)
- if (resetAll) {
- for (cleanerType <- MetadataCleanerType.values) {
- System.clearProperty(MetadataCleanerType.systemProperty(cleanerType))
- }
- }
- }
-}
-
diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala
index 945217203be7..034826c57ef1 100644
--- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala
+++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala
@@ -17,9 +17,8 @@
package org.apache.spark.util
-import java.net.{URLClassLoader, URL}
+import java.net.{URL, URLClassLoader}
import java.util.Enumeration
-import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
@@ -48,32 +47,12 @@ private[spark] class ChildFirstURLClassLoader(urls: Array[URL], parent: ClassLoa
private val parentClassLoader = new ParentClassLoader(parent)
- /**
- * Used to implement fine-grained class loading locks similar to what is done by Java 7. This
- * prevents deadlock issues when using non-hierarchical class loaders.
- *
- * Note that due to some issues with implementing class loaders in
- * Scala, Java 7's `ClassLoader.registerAsParallelCapable` method is not called.
- */
- private val locks = new ConcurrentHashMap[String, Object]()
-
override def loadClass(name: String, resolve: Boolean): Class[_] = {
- var lock = locks.get(name)
- if (lock == null) {
- val newLock = new Object()
- lock = locks.putIfAbsent(name, newLock)
- if (lock == null) {
- lock = newLock
- }
- }
-
- lock.synchronized {
- try {
- super.loadClass(name, resolve)
- } catch {
- case e: ClassNotFoundException =>
- parentClassLoader.loadClass(name, resolve)
- }
+ try {
+ super.loadClass(name, resolve)
+ } catch {
+ case e: ClassNotFoundException =>
+ parentClassLoader.loadClass(name, resolve)
}
}
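With the fine-grained locks removed, loadClass above is simply: try this loader's own URLs first, then delegate to the parent. A hypothetical usage sketch follows (jar path and class name invented; the loader is private[spark], so real callers live under org.apache.spark).

```scala
// Hypothetical child-first (user-classpath-first) loading sketch (illustrative only).
package org.apache.spark.sketch

import java.net.URL

import org.apache.spark.util.ChildFirstURLClassLoader

object ChildFirstSketch {
  def main(args: Array[String]): Unit = {
    val loader = new ChildFirstURLClassLoader(
      Array(new URL("file:///tmp/user-app.jar")),
      Thread.currentThread().getContextClassLoader)
    // A class bundled in user-app.jar wins over a same-named class on the parent classpath;
    // anything not found there is still delegated to the parent.
    val cls = loader.loadClass("com.example.UserDefinedUdf")
    println(cls.getClassLoader)
  }
}
```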
diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala
index 73d126ff6254..c9b7493fcdc1 100644
--- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala
+++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala
@@ -18,7 +18,7 @@
package org.apache.spark.util
/**
- * A class loader which makes some protected methods in ClassLoader accesible.
+ * A class loader which makes some protected methods in ClassLoader accessible.
*/
private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) {
diff --git a/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
new file mode 100644
index 000000000000..ce06e18879a4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.mutable
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
+ * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to
+ * the distributed data type (RDD, Graph, etc.).
+ *
+ * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
+ * as well as unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new Dataset has been created,
+ * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are
+ * responsible for materializing the Dataset to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When update() is called, this does the following:
+ * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
+ * - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
+ * - If using checkpointing and the checkpoint interval has been reached,
+ * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
+ * - Remove older checkpoints.
+ *
+ * WARNINGS:
+ * - This class should NOT be copied (since copies may conflict on which Datasets should be
+ * checkpointed).
+ * - This class removes checkpoint files once later Datasets have been checkpointed.
+ * However, references to the older Datasets will still return isCheckpointed = true.
+ *
+ * @param checkpointInterval Datasets will be checkpointed at this interval.
+ * If this interval was set as -1, then checkpointing will be disabled.
+ * @param sc SparkContext for the Datasets given to this checkpointer
+ * @tparam T Dataset type, such as RDD[Double]
+ */
+private[spark] abstract class PeriodicCheckpointer[T](
+ val checkpointInterval: Int,
+ val sc: SparkContext) extends Logging {
+
+ /** FIFO queue of past checkpointed Datasets */
+ private val checkpointQueue = mutable.Queue[T]()
+
+ /** FIFO queue of past persisted Datasets */
+ private val persistedQueue = mutable.Queue[T]()
+
+ /** Number of times [[update()]] has been called */
+ private var updateCount = 0
+
+ /**
+ * Update with a new Dataset. Handle persistence and checkpointing as needed.
+ * Since this handles persistence and checkpointing, this should be called before the Dataset
+ * has been materialized.
+ *
+ * @param newData New Dataset created from previous Datasets in the lineage.
+ */
+ def update(newData: T): Unit = {
+ persist(newData)
+ persistedQueue.enqueue(newData)
+ // We try to keep at most 3 Datasets in persistedQueue to support the semantics of this class:
+ // Users should call [[update()]] when a new Dataset has been created,
+ // before the Dataset has been materialized.
+ while (persistedQueue.size > 3) {
+ val dataToUnpersist = persistedQueue.dequeue()
+ unpersist(dataToUnpersist)
+ }
+ updateCount += 1
+
+ // Handle checkpointing (after persisting)
+ if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0
+ && sc.getCheckpointDir.nonEmpty) {
+ // Add new checkpoint before removing old checkpoints.
+ checkpoint(newData)
+ checkpointQueue.enqueue(newData)
+ // Remove checkpoints before the latest one.
+ var canDelete = true
+ while (checkpointQueue.size > 1 && canDelete) {
+ // Delete the oldest checkpoint only if the next checkpoint exists.
+ if (isCheckpointed(checkpointQueue.head)) {
+ removeCheckpointFile()
+ } else {
+ canDelete = false
+ }
+ }
+ }
+ }
+
+ /** Checkpoint the Dataset */
+ protected def checkpoint(data: T): Unit
+
+ /** Return true iff the Dataset is checkpointed */
+ protected def isCheckpointed(data: T): Boolean
+
+ /**
+ * Persist the Dataset.
+ * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
+ */
+ protected def persist(data: T): Unit
+
+ /** Unpersist the Dataset */
+ protected def unpersist(data: T): Unit
+
+ /** Get list of checkpoint files for this given Dataset */
+ protected def getCheckpointFiles(data: T): Iterable[String]
+
+ /**
+ * Call this to unpersist the Dataset.
+ */
+ def unpersistDataSet(): Unit = {
+ while (persistedQueue.nonEmpty) {
+ val dataToUnpersist = persistedQueue.dequeue()
+ unpersist(dataToUnpersist)
+ }
+ }
+
+ /**
+ * Call this at the end to delete any remaining checkpoint files.
+ */
+ def deleteAllCheckpoints(): Unit = {
+ while (checkpointQueue.nonEmpty) {
+ removeCheckpointFile()
+ }
+ }
+
+ /**
+ * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint.
+ * Note that there may not be any checkpoints at all.
+ */
+ def deleteAllCheckpointsButLast(): Unit = {
+ while (checkpointQueue.size > 1) {
+ removeCheckpointFile()
+ }
+ }
+
+ /**
+ * Get all current checkpoint files.
+ * This is useful in combination with [[deleteAllCheckpointsButLast()]].
+ */
+ def getAllCheckpointFiles: Array[String] = {
+ checkpointQueue.flatMap(getCheckpointFiles).toArray
+ }
+
+ /**
+ * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
+ * This prints a warning but does not fail if the files cannot be removed.
+ */
+ private def removeCheckpointFile(): Unit = {
+ val old = checkpointQueue.dequeue()
+ // Since the old checkpoint is not deleted by Spark, we manually delete it.
+ getCheckpointFiles(old).foreach(
+ PeriodicCheckpointer.removeCheckpointFile(_, sc.hadoopConfiguration))
+ }
+}
+
+private[spark] object PeriodicCheckpointer extends Logging {
+
+ /** Delete a checkpoint file, and log a warning if deletion fails. */
+ def removeCheckpointFile(checkpointFile: String, conf: Configuration): Unit = {
+ try {
+ val path = new Path(checkpointFile)
+ val fs = path.getFileSystem(conf)
+ fs.delete(path, true)
+ } catch {
+ case e: Exception =>
+ logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+ checkpointFile)
+ }
+ }
+}
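A concrete checkpointer only needs to map the five abstract hooks onto its data type. The sketch below does this for plain RDDs; the class and package names are invented and the code is illustrative only, though Spark's ML and GraphX code contains similar subclasses. It must live under org.apache.spark because the abstract class is private[spark].

```scala
// Hedged sketch of a PeriodicCheckpointer subclass for plain RDDs (illustrative only).
package org.apache.spark.sketch

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.PeriodicCheckpointer

private[spark] class SketchRDDCheckpointer[T](
    checkpointInterval: Int,
    sc: SparkContext)
  extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {

  override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()

  override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed

  override protected def persist(data: RDD[T]): Unit = {
    // Only persist if the caller has not already chosen a storage level.
    if (data.getStorageLevel == StorageLevel.NONE) {
      data.persist()
    }
  }

  override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)

  override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] =
    data.getCheckpointFile.toList
}

// Typical driving loop: call update() on each new RDD in the lineage before materializing it,
// then materialize it (e.g. with count()) so the requested persist/checkpoint actually happens.
```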
diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
index 7578a3b1d85f..46a5cb2cff5a 100644
--- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
@@ -17,23 +17,19 @@
package org.apache.spark.util
-import scala.concurrent.duration.FiniteDuration
-import scala.language.postfixOps
-
-import org.apache.spark.{SparkEnv, SparkConf}
+import org.apache.spark.SparkConf
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout}
-object RpcUtils {
+private[spark] object RpcUtils {
/**
- * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name.
+ * Retrieve a `RpcEndpointRef` which is located in the driver via its name.
*/
def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = {
- val driverActorSystemName = SparkEnv.driverActorSystemName
val driverHost: String = conf.get("spark.driver.host", "localhost")
val driverPort: Int = conf.getInt("spark.driver.port", 7077)
Utils.checkHost(driverHost, "Expected hostname")
- rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name)
+ rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name)
}
/** Returns the configured number of times to retry connecting */
@@ -47,22 +43,24 @@ object RpcUtils {
}
/** Returns the default Spark timeout to use for RPC ask operations. */
- private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = {
+ def askRpcTimeout(conf: SparkConf): RpcTimeout = {
RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s")
}
- @deprecated("use askRpcTimeout instead, this method was not intended to be public", "1.5.0")
- def askTimeout(conf: SparkConf): FiniteDuration = {
- askRpcTimeout(conf).duration
- }
-
/** Returns the default Spark timeout to use for RPC remote endpoint lookup. */
- private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = {
+ def lookupRpcTimeout(conf: SparkConf): RpcTimeout = {
RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s")
}
- @deprecated("use lookupRpcTimeout instead, this method was not intended to be public", "1.5.0")
- def lookupTimeout(conf: SparkConf): FiniteDuration = {
- lookupRpcTimeout(conf).duration
+ private val MAX_MESSAGE_SIZE_IN_MB = Int.MaxValue / 1024 / 1024
+
+ /** Returns the configured max message size for messages in bytes. */
+ def maxMessageSizeBytes(conf: SparkConf): Int = {
+ val maxSizeInMB = conf.getInt("spark.rpc.message.maxSize", 128)
+ if (maxSizeInMB > MAX_MESSAGE_SIZE_IN_MB) {
+ throw new IllegalArgumentException(
+ s"spark.rpc.message.maxSize should not be greater than $MAX_MESSAGE_SIZE_IN_MB MB")
+ }
+ maxSizeInMB * 1024 * 1024
}
}
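A small sketch of the new message-size handling: the configured value is read in MB, validated against Int.MaxValue / 1024 / 1024 (2047 MB), and returned in bytes. The value chosen below is arbitrary, and RpcUtils is now private[spark], so the snippet assumes an org.apache.spark package.

```scala
// Illustrative only: exercises maxMessageSizeBytes with an arbitrary configuration value.
package org.apache.spark.sketch

import org.apache.spark.SparkConf
import org.apache.spark.util.RpcUtils

object MaxMessageSizeSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().set("spark.rpc.message.maxSize", "256") // interpreted as MB
    println(RpcUtils.maxMessageSizeBytes(conf)) // 268435456 bytes
    // Values above 2047 MB are rejected with an IllegalArgumentException.
  }
}
```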
diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
index db4a8b304ec3..4001fac3c3d5 100644
--- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
+++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
@@ -20,11 +20,11 @@ package org.apache.spark.util
import java.io.File
import java.util.PriorityQueue
-import scala.util.{Failure, Success, Try}
-import tachyon.client.TachyonFile
+import scala.util.Try
import org.apache.hadoop.fs.FileSystem
-import org.apache.spark.Logging
+
+import org.apache.spark.internal.Logging
/**
* Various utility methods used by Spark.
@@ -52,12 +52,14 @@ private[spark] object ShutdownHookManager extends Logging {
}
private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
- private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]()
// Add a shutdown hook to delete the temp dirs when the JVM exits
+ logDebug("Adding shutdown hook") // force eager creation of logger
addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () =>
logInfo("Shutdown hook called")
- shutdownDeletePaths.foreach { dirPath =>
+ // we need to materialize the paths to delete because deleteRecursively removes items from
+ // shutdownDeletePaths as we are traversing through it.
+ shutdownDeletePaths.toArray.foreach { dirPath =>
try {
logInfo("Deleting directory " + dirPath)
Utils.deleteRecursively(new File(dirPath))
@@ -75,14 +77,6 @@ private[spark] object ShutdownHookManager extends Logging {
}
}
- // Register the tachyon path to be deleted via shutdown hook
- def registerShutdownDeleteDir(tachyonfile: TachyonFile) {
- val absolutePath = tachyonfile.getPath()
- shutdownDeleteTachyonPaths.synchronized {
- shutdownDeleteTachyonPaths += absolutePath
- }
- }
-
// Remove the path to be deleted via shutdown hook
def removeShutdownDeleteDir(file: File) {
val absolutePath = file.getAbsolutePath()
@@ -91,14 +85,6 @@ private[spark] object ShutdownHookManager extends Logging {
}
}
- // Remove the tachyon path to be deleted via shutdown hook
- def removeShutdownDeleteDir(tachyonfile: TachyonFile) {
- val absolutePath = tachyonfile.getPath()
- shutdownDeleteTachyonPaths.synchronized {
- shutdownDeleteTachyonPaths.remove(absolutePath)
- }
- }
-
// Is the path already registered to be deleted via a shutdown hook ?
def hasShutdownDeleteDir(file: File): Boolean = {
val absolutePath = file.getAbsolutePath()
@@ -107,14 +93,6 @@ private[spark] object ShutdownHookManager extends Logging {
}
}
- // Is the path already registered to be deleted via a shutdown hook ?
- def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = {
- val absolutePath = file.getPath()
- shutdownDeleteTachyonPaths.synchronized {
- shutdownDeleteTachyonPaths.contains(absolutePath)
- }
- }
-
// Note: if file is child of some registered path, while not equal to it, then return true;
// else false. This is to ensure that two shutdown hooks do not try to delete each others
// paths - resulting in IOException and incomplete cleanup.
@@ -131,22 +109,6 @@ private[spark] object ShutdownHookManager extends Logging {
retval
}
- // Note: if file is child of some registered path, while not equal to it, then return true;
- // else false. This is to ensure that two shutdown hooks do not try to delete each others
- // paths - resulting in Exception and incomplete cleanup.
- def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = {
- val absolutePath = file.getPath()
- val retval = shutdownDeleteTachyonPaths.synchronized {
- shutdownDeleteTachyonPaths.exists { path =>
- !absolutePath.equals(path) && absolutePath.startsWith(path)
- }
- }
- if (retval) {
- logInfo("path = " + file + ", already present as root for deletion.")
- }
- retval
- }
-
/**
* Detect whether this thread might be executing a shutdown hook. Will always return true if
* the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g.
@@ -160,7 +122,9 @@ private[spark] object ShutdownHookManager extends Logging {
val hook = new Thread {
override def run() {}
}
+ // scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(hook)
+ // scalastyle:on runtimeaddshutdownhook
Runtime.getRuntime.removeShutdownHook(hook)
} catch {
case ise: IllegalStateException => return true
@@ -204,54 +168,40 @@ private[spark] object ShutdownHookManager extends Logging {
private [util] class SparkShutdownHookManager {
private val hooks = new PriorityQueue[SparkShutdownHook]()
- private var shuttingDown = false
+ @volatile private var shuttingDown = false
/**
- * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not
- * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for
- * the best.
+ * Install a hook to run at shutdown and run all registered hooks in order.
*/
def install(): Unit = {
val hookTask = new Runnable() {
override def run(): Unit = runAll()
}
- Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match {
- case Success(shmClass) =>
- val fsPriority = classOf[FileSystem]
- .getField("SHUTDOWN_HOOK_PRIORITY")
- .get(null) // static field, the value is not used
- .asInstanceOf[Int]
- val shm = shmClass.getMethod("get").invoke(null)
- shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int])
- .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30))
-
- case Failure(_) =>
- Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook"));
- }
+ org.apache.hadoop.util.ShutdownHookManager.get().addShutdownHook(
+ hookTask, FileSystem.SHUTDOWN_HOOK_PRIORITY + 30)
}
- def runAll(): Unit = synchronized {
+ def runAll(): Unit = {
shuttingDown = true
- while (!hooks.isEmpty()) {
- Try(Utils.logUncaughtExceptions(hooks.poll().run()))
+ var nextHook: SparkShutdownHook = null
+ while ({ nextHook = hooks.synchronized { hooks.poll() }; nextHook != null }) {
+ Try(Utils.logUncaughtExceptions(nextHook.run()))
}
}
- def add(priority: Int, hook: () => Unit): AnyRef = synchronized {
- checkState()
- val hookRef = new SparkShutdownHook(priority, hook)
- hooks.add(hookRef)
- hookRef
- }
-
- def remove(ref: AnyRef): Boolean = synchronized {
- hooks.remove(ref)
+ def add(priority: Int, hook: () => Unit): AnyRef = {
+ hooks.synchronized {
+ if (shuttingDown) {
+ throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.")
+ }
+ val hookRef = new SparkShutdownHook(priority, hook)
+ hooks.add(hookRef)
+ hookRef
+ }
}
- private def checkState(): Unit = {
- if (shuttingDown) {
- throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.")
- }
+ def remove(ref: AnyRef): Boolean = {
+ hooks.synchronized { hooks.remove(ref) }
}
}
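
A minimal sketch (illustrative only, not part of the patch) of how the reworked hook manager is typically driven, assuming the object's addShutdownHook(priority)(hook) and removeShutdownHook wrappers defined earlier in this file, outside the hunks shown. With runAll() now draining the queue under hooks.synchronized, registering a hook after shutdown has started fails with IllegalStateException.

  package org.apache.spark.util

  object ShutdownHookSketch {
    def main(args: Array[String]): Unit = {
      // Hooks with higher priority run earlier during shutdown.
      val ref = ShutdownHookManager.addShutdownHook(priority = 100) { () =>
        println("flushing state before the JVM exits")
      }
      // Hooks can still be removed as long as shutdown has not begun.
      ShutdownHookManager.removeShutdownHook(ref)
    }
  }
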
diff --git a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala
deleted file mode 100644
index f77488ef3d44..000000000000
--- a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import org.apache.commons.lang3.SystemUtils
-import org.slf4j.Logger
-import sun.misc.{Signal, SignalHandler}
-
-/**
- * Used to log signals received. This can be very useful in debugging crashes or kills.
- *
- * Inspired by Colin Patrick McCabe's similar class from Hadoop.
- */
-private[spark] object SignalLogger {
-
- private var registered = false
-
- /** Register a signal handler to log signals on UNIX-like systems. */
- def register(log: Logger): Unit = synchronized {
- if (SystemUtils.IS_OS_UNIX) {
- require(!registered, "Can't re-install the signal handlers")
- registered = true
-
- val signals = Seq("TERM", "HUP", "INT")
- for (signal <- signals) {
- try {
- new SignalLoggerHandler(signal, log)
- } catch {
- case e: Exception => log.warn("Failed to register signal handler " + signal, e)
- }
- }
- log.info("Registered signal handlers for [" + signals.mkString(", ") + "]")
- }
- }
-}
-
-private sealed class SignalLoggerHandler(name: String, log: Logger) extends SignalHandler {
-
- val prevHandler = Signal.handle(new Signal(name), this)
-
- override def handle(signal: Signal): Unit = {
- log.error("RECEIVED SIGNAL " + signal.getNumber() + ": SIG" + signal.getName())
- prevHandler.handle(signal)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/SignalUtils.scala b/core/src/main/scala/org/apache/spark/util/SignalUtils.scala
new file mode 100644
index 000000000000..5a24965170ce
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SignalUtils.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.Collections
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.lang3.SystemUtils
+import org.slf4j.Logger
+import sun.misc.{Signal, SignalHandler}
+
+import org.apache.spark.internal.Logging
+
+/**
+ * Contains utilities for working with posix signals.
+ */
+private[spark] object SignalUtils extends Logging {
+
+ /** A flag to make sure we only register the logger once. */
+ private var loggerRegistered = false
+
+ /** Register a signal handler to log signals on UNIX-like systems. */
+ def registerLogger(log: Logger): Unit = synchronized {
+ if (!loggerRegistered) {
+ Seq("TERM", "HUP", "INT").foreach { sig =>
+ SignalUtils.register(sig) {
+ log.error("RECEIVED SIGNAL " + sig)
+ false
+ }
+ }
+ loggerRegistered = true
+ }
+ }
+
+ /**
+ * Adds an action to be run when a given signal is received by this process.
+ *
+ * Note that signals are only supported on unix-like operating systems and work on a best-effort
+ * basis: if a signal is not available or cannot be intercepted, only a warning is emitted.
+ *
+ * All actions for a given signal are run in a separate thread.
+ */
+ def register(signal: String)(action: => Boolean): Unit = synchronized {
+ if (SystemUtils.IS_OS_UNIX) {
+ try {
+ val handler = handlers.getOrElseUpdate(signal, {
+ logInfo("Registered signal handler for " + signal)
+ new ActionHandler(new Signal(signal))
+ })
+ handler.register(action)
+ } catch {
+ case ex: Exception => logWarning(s"Failed to register signal handler for " + signal, ex)
+ }
+ }
+ }
+
+ /**
+ * A handler for the given signal that runs a collection of actions.
+ */
+ private class ActionHandler(signal: Signal) extends SignalHandler {
+
+ /**
+ * List of actions upon the signal; the callbacks should return true if the signal is "handled",
+ * i.e. should not escalate to the next callback.
+ */
+ private val actions = Collections.synchronizedList(new java.util.LinkedList[() => Boolean])
+
+ // original signal handler, before this handler was attached
+ private val prevHandler: SignalHandler = Signal.handle(signal, this)
+
+ /**
+ * Called when this handler's signal is received. Note that if the same signal is received
+ * before this method returns, it is escalated to the previous handler.
+ */
+ override def handle(sig: Signal): Unit = {
+ // register old handler, will receive incoming signals while this handler is running
+ Signal.handle(signal, prevHandler)
+
+ // Run all actions, escalate to parent handler if no action catches the signal
+ // (i.e. all actions return false). Note that calling `map` is to ensure that
+ // all actions are run, `forall` is short-circuited and will stop evaluating
+ // after reaching a first false predicate.
+ val escalate = actions.asScala.map(action => action()).forall(_ == false)
+ if (escalate) {
+ prevHandler.handle(sig)
+ }
+
+ // re-register this handler
+ Signal.handle(signal, this)
+ }
+
+ /**
+ * Adds an action to be run by this handler.
+ * @param action An action to be run when a signal is received. Return true if the signal
+ * should be stopped with this handler, false if it should be escalated.
+ */
+ def register(action: => Boolean): Unit = actions.add(() => action)
+ }
+
+ /** Mapping from signal to their respective handlers. */
+ private val handlers = new scala.collection.mutable.HashMap[String, ActionHandler]
+}
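
A hedged usage sketch for the new SignalUtils.register (not part of the patch); the object name is invented and the placement in org.apache.spark.util is only needed because SignalUtils is private[spark]. Returning false lets the signal escalate to the previously installed handler.

  package org.apache.spark.util

  object SignalSketch {
    def main(args: Array[String]): Unit = {
      SignalUtils.register("TERM") {
        println("TERM received, starting a graceful stop")
        false  // not "handled": escalate to the previous TERM handler afterwards
      }
      Thread.sleep(60000)  // keep the process alive long enough to send it a signal
    }
  }
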
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index 23ee4eff0881..3bfdf95db84c 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -17,20 +17,34 @@
package org.apache.spark.util
-import com.google.common.collect.MapMaker
-
import java.lang.management.ManagementFactory
import java.lang.reflect.{Field, Modifier}
import java.util.{IdentityHashMap, Random}
-import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable.ArrayBuffer
import scala.runtime.ScalaRunTime
-import org.apache.spark.Logging
+import com.google.common.collect.MapMaker
+
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
import org.apache.spark.util.collection.OpenHashSet
+/**
+ * A trait that allows a class to give [[SizeEstimator]] more accurate size estimation.
+ * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first.
+ * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size
+ * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work.
+ * The difference between a [[KnownSizeEstimation]] and
+ * [[org.apache.spark.util.collection.SizeTracker]] is that, a
+ * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to
+ * estimate the size. However, a [[KnownSizeEstimation]] can provide a better estimation without
+ * using [[SizeEstimator]].
+ */
+private[spark] trait KnownSizeEstimation {
+ def estimatedSize: Long
+}
+
/**
* :: DeveloperApi ::
* Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in
@@ -137,13 +151,12 @@ object SizeEstimator extends Logging {
// TODO: We could use reflection on the VMOption returned ?
getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
} catch {
- case e: Exception => {
+ case e: Exception =>
// Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
val guessInWords = if (guess) "yes" else "not"
logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords)
return guess
- }
}
}
@@ -194,15 +207,23 @@ object SizeEstimator extends Logging {
val cls = obj.getClass
if (cls.isArray) {
visitArray(obj, cls, state)
+ } else if (cls.getName.startsWith("scala.reflect")) {
+ // Many objects in the scala.reflect package reference global reflection objects which, in
+ // turn, reference many other large global objects. Do nothing in this case.
} else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) {
// Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses
// the size estimator since it references the whole REPL. Do nothing in this case. In
// general all ClassLoaders and Classes will be shared between objects anyway.
} else {
- val classInfo = getClassInfo(cls)
- state.size += alignSize(classInfo.shellSize)
- for (field <- classInfo.pointerFields) {
- state.enqueue(field.get(obj))
+ obj match {
+ case s: KnownSizeEstimation =>
+ state.size += s.estimatedSize
+ case _ =>
+ val classInfo = getClassInfo(cls)
+ state.size += alignSize(classInfo.shellSize)
+ for (field <- classInfo.pointerFields) {
+ state.enqueue(field.get(obj))
+ }
}
}
}
@@ -234,7 +255,7 @@ object SizeEstimator extends Logging {
} else {
// Estimate the size of a large array by sampling elements without replacement.
// To exclude the shared objects that the array elements may link, sample twice
- // and use the min one to caculate array size.
+ // and use the min one to calculate array size.
val rand = new Random(42)
val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE)
val s1 = sampleArray(array, state, rand, drawn, length)
@@ -329,7 +350,7 @@ object SizeEstimator extends Logging {
// 3. consistent fields layouts throughout the hierarchy: This means we should layout
// superclass first. And we can use superclass's shellSize as a starting point to layout the
// other fields in this class.
- // 4. class alignment: HotSpot rounds field blocks up to to HeapOopSize not 4 bytes, confirmed
+ // 4. class alignment: HotSpot rounds field blocks up to HeapOopSize not 4 bytes, confirmed
// with Aleksey. see https://bugs.openjdk.java.net/browse/CODETOOLS-7901322
//
// The real world field layout is much more complicated. There are three kinds of fields
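
An illustrative sketch (not part of the patch) of the new KnownSizeEstimation hook: a class that already knows its footprint can report it directly, and SizeEstimator.estimate uses that value instead of reflective traversal. The wrapper class and its overhead constant are made up.

  package org.apache.spark.util

  object KnownSizeSketch {
    // Hypothetical wrapper that reports its own size, skipping field-by-field estimation.
    private class PreSizedBuffer(bytes: Array[Byte]) extends KnownSizeEstimation {
      override def estimatedSize: Long = bytes.length.toLong + 16L  // payload plus a rough header guess
    }

    def main(args: Array[String]): Unit = {
      val buf = new PreSizedBuffer(new Array[Byte](1024))
      // Should print roughly 1040, i.e. whatever estimatedSize returned.
      println(SizeEstimator.estimate(buf))
    }
  }
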
diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
index 724818724733..95bf3f58bc77 100644
--- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
+++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
@@ -17,7 +17,7 @@
package org.apache.spark.util
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
/**
* The default uncaught exception handler for Executors terminates the whole process, to avoid
@@ -29,7 +29,11 @@ private[spark] object SparkUncaughtExceptionHandler
override def uncaughtException(thread: Thread, exception: Throwable) {
try {
- logError("Uncaught exception in thread " + thread, exception)
+ // Make it explicit that uncaught exceptions are thrown when the container is shutting down.
+ // It will help users when they analyze the executor logs.
+ val inShutdownMsg = if (ShutdownHookManager.inShutdown()) "[Container in shutdown] " else ""
+ val errMsg = "Uncaught exception in thread "
+ logError(inShutdownMsg + errMsg + thread, exception)
// We may have been called from a shutdown hook. If so, we must not call System.exit().
// (If we do, we will deadlock.)
diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala
index 8586da1996cf..1e02638591f8 100644
--- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala
+++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala
@@ -17,11 +17,13 @@
package org.apache.spark.util
+import org.apache.spark.annotation.Since
+
/**
* A class for tracking the statistics of a set of numbers (count, mean and variance) in a
* numerically robust way. Includes support for merging two StatCounters. Based on Welford
- * and Chan's [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance algorithms]]
- * for running variance.
+ * and Chan's
+ * algorithms for running variance.
*
* @constructor Initialize the StatCounter with the given values.
*/
@@ -104,8 +106,14 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
def min: Double = minValue
- /** Return the variance of the values. */
- def variance: Double = {
+ /** Return the population variance of the values. */
+ def variance: Double = popVariance
+
+ /**
+ * Return the population variance of the values.
+ */
+ @Since("2.1.0")
+ def popVariance: Double = {
if (n == 0) {
Double.NaN
} else {
@@ -125,8 +133,14 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
}
}
- /** Return the standard deviation of the values. */
- def stdev: Double = math.sqrt(variance)
+ /** Return the population standard deviation of the values. */
+ def stdev: Double = popStdev
+
+ /**
+ * Return the population standard deviation of the values.
+ */
+ @Since("2.1.0")
+ def popStdev: Double = math.sqrt(popVariance)
/**
* Return the sample standard deviation of the values, which corrects for bias in estimating the
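
A small worked example (not part of the patch) contrasting the new population methods with the existing sample ones: for 1.0, 2.0, 3.0, 4.0 the squared deviations sum to 5, so popVariance is 5/4 = 1.25 while sampleVariance is 5/3.

  import org.apache.spark.util.StatCounter

  object StatCounterSketch {
    def main(args: Array[String]): Unit = {
      val stats = StatCounter(Seq(1.0, 2.0, 3.0, 4.0))
      println(stats.popVariance)     // 1.25  (divides by n)
      println(stats.sampleVariance)  // ~1.6667 (divides by n - 1)
      println(stats.popStdev)        // sqrt(1.25) ~= 1.118
    }
  }
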
diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
deleted file mode 100644
index c1b8bf052c0c..000000000000
--- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import java.util.EventListener
-
-import org.apache.spark.TaskContext
-import org.apache.spark.annotation.DeveloperApi
-
-/**
- * :: DeveloperApi ::
- *
- * Listener providing a callback function to invoke when a task's execution completes.
- */
-@DeveloperApi
-trait TaskCompletionListener extends EventListener {
- def onTaskCompletion(context: TaskContext)
-}
diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala
deleted file mode 100644
index f64e069cd172..000000000000
--- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-/**
- * Exception thrown when there is an exception in
- * executing the callback in TaskCompletionListener.
- */
-private[spark]
-class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception {
-
- override def getMessage: String = {
- if (errorMessages.size == 1) {
- errorMessages.head
- } else {
- errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
index d4e0ad93b966..b1217980faf1 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
@@ -24,4 +24,8 @@ private[spark] case class ThreadStackTrace(
threadId: Long,
threadName: String,
threadState: Thread.State,
- stackTrace: String)
+ stackTrace: String,
+ blockedByThreadId: Option[Long],
+ blockedByLock: String,
+ holdingLocks: Seq[String])
+
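
One plausible way (purely a sketch, not what the patch itself does; the real population logic lives elsewhere, e.g. in the thread-dump code) to fill the new fields from a java.lang.management.ThreadInfo obtained with locked-monitor information:

  package org.apache.spark.util

  import java.lang.management.ThreadInfo

  object ThreadStackTraceSketch {
    // Hypothetical helper: maps a ThreadInfo onto the extended case class.
    def fromThreadInfo(info: ThreadInfo): ThreadStackTrace = {
      ThreadStackTrace(
        threadId = info.getThreadId,
        threadName = info.getThreadName,
        threadState = info.getThreadState,
        stackTrace = info.getStackTrace.mkString("\n"),
        blockedByThreadId = if (info.getLockOwnerId >= 0) Some(info.getLockOwnerId) else None,
        blockedByLock = Option(info.getLockInfo).map(_.toString).getOrElse(""),
        holdingLocks = info.getLockedMonitors.map(_.toString).toSeq)
    }
  }
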
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 53283448c87b..81aaf79db0c1 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -19,11 +19,15 @@ package org.apache.spark.util
import java.util.concurrent._
-import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.duration.Duration
+import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
import scala.util.control.NonFatal
import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
+import org.apache.spark.SparkException
+
private[spark] object ThreadUtils {
private val sameThreadExecutionContext =
@@ -56,10 +60,18 @@ private[spark] object ThreadUtils {
* Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names
* are formatted as prefix-ID, where ID is a unique, sequentially assigned integer.
*/
- def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = {
+ def newDaemonCachedThreadPool(
+ prefix: String, maxThreadNumber: Int, keepAliveSeconds: Int = 60): ThreadPoolExecutor = {
val threadFactory = namedThreadFactory(prefix)
- new ThreadPoolExecutor(
- 0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory)
+ val threadPool = new ThreadPoolExecutor(
+ maxThreadNumber, // corePoolSize: the max number of threads to create before queuing the tasks
+ maxThreadNumber, // maximumPoolSize: because we use an unbounded LinkedBlockingQueue, this one is not used
+ keepAliveSeconds,
+ TimeUnit.SECONDS,
+ new LinkedBlockingQueue[Runnable],
+ threadFactory)
+ threadPool.allowCoreThreadTimeOut(true)
+ threadPool
}
/**
@@ -148,4 +160,71 @@ private[spark] object ThreadUtils {
result
}
}
+
+ /**
+ * Construct a new Scala ForkJoinPool with a specified max parallelism and name prefix.
+ */
+ def newForkJoinPool(prefix: String, maxThreadNumber: Int): SForkJoinPool = {
+ // Custom factory to set thread names
+ val factory = new SForkJoinPool.ForkJoinWorkerThreadFactory {
+ override def newThread(pool: SForkJoinPool) =
+ new SForkJoinWorkerThread(pool) {
+ setName(prefix + "-" + super.getName)
+ }
+ }
+ new SForkJoinPool(maxThreadNumber, factory,
+ null, // handler
+ false // asyncMode
+ )
+ }
+
+ // scalastyle:off awaitresult
+ /**
+ * Preferred alternative to `Await.result()`.
+ *
+ * This method wraps and re-throws any exceptions thrown by the underlying `Await` call, ensuring
+ * that this thread's stack trace appears in logs.
+ *
+ * In addition, it calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s
+ * `BlockingContext`. Code running in the user's thread may be in a thread of Scala's ForkJoinPool.
+ * As concurrent executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this
+ * method basically prevents ForkJoinPool from running other tasks in the current waiting thread.
+ * In general, we should use this method because many places in Spark use [[ThreadLocal]] and it's
+ * hard to debug when [[ThreadLocal]]s leak to other tasks.
+ */
+ @throws(classOf[SparkException])
+ def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
+ try {
+ // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
+ // See SPARK-13747.
+ val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
+ awaitable.result(atMost)(awaitPermission)
+ } catch {
+ // TimeoutException is thrown in the current thread, so there is no need to wrap the exception.
+ case NonFatal(t) if !t.isInstanceOf[TimeoutException] =>
+ throw new SparkException("Exception thrown in awaitResult: ", t)
+ }
+ }
+ // scalastyle:on awaitresult
+
+ // scalastyle:off awaitready
+ /**
+ * Preferred alternative to `Await.ready()`.
+ *
+ * @see [[awaitResult]]
+ */
+ @throws(classOf[SparkException])
+ def awaitReady[T](awaitable: Awaitable[T], atMost: Duration): awaitable.type = {
+ try {
+ // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
+ // See SPARK-13747.
+ val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
+ awaitable.ready(atMost)(awaitPermission)
+ } catch {
+ // TimeoutException is thrown in the current thread, so there is no need to wrap the exception.
+ case NonFatal(t) if !t.isInstanceOf[TimeoutException] =>
+ throw new SparkException("Exception thrown in awaitResult: ", t)
+ }
+ }
+ // scalastyle:on awaitready
}
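
A hedged usage sketch (not part of the patch) of awaitResult together with the reworked cached pool; the names are invented and the code sits in org.apache.spark.util only because ThreadUtils is private[spark]. Failures come back wrapped in SparkException so the waiting thread's stack trace shows up in logs.

  package org.apache.spark.util

  import scala.concurrent.{ExecutionContext, Future}
  import scala.concurrent.duration._

  object AwaitSketch {
    def main(args: Array[String]): Unit = {
      val pool = ThreadUtils.newDaemonCachedThreadPool("await-sketch", maxThreadNumber = 4)
      implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)

      val f = Future { Thread.sleep(100); 42 }
      // Preferred over Await.result: avoids ForkJoinPool's BlockingContext and
      // rethrows failures as SparkException from this thread.
      println(ThreadUtils.awaitResult(f, 10.seconds))
      pool.shutdown()
    }
  }
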
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index d7e5143c3095..32af0127bbf3 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -17,14 +17,14 @@
package org.apache.spark.util
-import java.util.Set
import java.util.Map.Entry
+import java.util.Set
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
import scala.collection.mutable
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
private[spark] case class TimeStampedValue[V](value: V, timestamp: Long)
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala
deleted file mode 100644
index 65efeb1f4c19..000000000000
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import java.util.concurrent.ConcurrentHashMap
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable.Set
-
-private[spark] class TimeStampedHashSet[A] extends Set[A] {
- val internalMap = new ConcurrentHashMap[A, Long]()
-
- def contains(key: A): Boolean = {
- internalMap.contains(key)
- }
-
- def iterator: Iterator[A] = {
- val jIterator = internalMap.entrySet().iterator()
- jIterator.asScala.map(_.getKey)
- }
-
- override def + (elem: A): Set[A] = {
- val newSet = new TimeStampedHashSet[A]
- newSet ++= this
- newSet += elem
- newSet
- }
-
- override def - (elem: A): Set[A] = {
- val newSet = new TimeStampedHashSet[A]
- newSet ++= this
- newSet -= elem
- newSet
- }
-
- override def += (key: A): this.type = {
- internalMap.put(key, currentTime)
- this
- }
-
- override def -= (key: A): this.type = {
- internalMap.remove(key)
- this
- }
-
- override def empty: Set[A] = new TimeStampedHashSet[A]()
-
- override def size(): Int = internalMap.size()
-
- override def foreach[U](f: (A) => U): Unit = {
- val iterator = internalMap.entrySet().iterator()
- while(iterator.hasNext) {
- f(iterator.next.getKey)
- }
- }
-
- /**
- * Removes old values that have timestamp earlier than `threshTime`
- */
- def clearOldValues(threshTime: Long) {
- val iterator = internalMap.entrySet().iterator()
- while(iterator.hasNext) {
- val entry = iterator.next()
- if (entry.getValue < threshTime) {
- iterator.remove()
- }
- }
- }
-
- private def currentTime: Long = System.currentTimeMillis()
-}
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
deleted file mode 100644
index 310c0c109416..000000000000
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
+++ /dev/null
@@ -1,171 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import java.lang.ref.WeakReference
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable
-import scala.language.implicitConversions
-
-import org.apache.spark.Logging
-
-/**
- * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped.
- *
- * If the value is garbage collected and the weak reference is null, get() will return a
- * non-existent value. These entries are removed from the map periodically (every N inserts), as
- * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are
- * older than a particular threshold can be removed using the clearOldValues method.
- *
- * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it
- * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap,
- * so all operations on this HashMap are thread-safe.
- *
- * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed.
- */
-private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false)
- extends mutable.Map[A, B]() with Logging {
-
- import TimeStampedWeakValueHashMap._
-
- private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet)
- private val insertCount = new AtomicInteger(0)
-
- /** Return a map consisting only of entries whose values are still strongly reachable. */
- private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null }
-
- def get(key: A): Option[B] = internalMap.get(key)
-
- def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator
-
- override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
- val newMap = new TimeStampedWeakValueHashMap[A, B1]
- val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]]
- newMap.internalMap.putAll(oldMap.toMap)
- newMap.internalMap += kv
- newMap
- }
-
- override def - (key: A): mutable.Map[A, B] = {
- val newMap = new TimeStampedWeakValueHashMap[A, B]
- newMap.internalMap.putAll(nonNullReferenceMap.toMap)
- newMap.internalMap -= key
- newMap
- }
-
- override def += (kv: (A, B)): this.type = {
- internalMap += kv
- if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) {
- clearNullValues()
- }
- this
- }
-
- override def -= (key: A): this.type = {
- internalMap -= key
- this
- }
-
- override def update(key: A, value: B): Unit = this += ((key, value))
-
- override def apply(key: A): B = internalMap.apply(key)
-
- override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p)
-
- override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]()
-
- override def size: Int = internalMap.size
-
- override def foreach[U](f: ((A, B)) => U): Unit = nonNullReferenceMap.foreach(f)
-
- def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value)
-
- def toMap: Map[A, B] = iterator.toMap
-
- /** Remove old key-value pairs with timestamps earlier than `threshTime`. */
- def clearOldValues(threshTime: Long): Unit = internalMap.clearOldValues(threshTime)
-
- /** Remove entries with values that are no longer strongly reachable. */
- def clearNullValues() {
- val it = internalMap.getEntrySet.iterator
- while (it.hasNext) {
- val entry = it.next()
- if (entry.getValue.value.get == null) {
- logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.")
- it.remove()
- }
- }
- }
-
- // For testing
-
- def getTimestamp(key: A): Option[Long] = {
- internalMap.getTimeStampedValue(key).map(_.timestamp)
- }
-
- def getReference(key: A): Option[WeakReference[B]] = {
- internalMap.getTimeStampedValue(key).map(_.value)
- }
-}
-
-/**
- * Helper methods for converting to and from WeakReferences.
- */
-private object TimeStampedWeakValueHashMap {
-
- // Number of inserts after which entries with null references are removed
- val CLEAR_NULL_VALUES_INTERVAL = 100
-
- /* Implicit conversion methods to WeakReferences. */
-
- implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v)
-
- implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = {
- kv match { case (k, v) => (k, toWeakReference(v)) }
- }
-
- implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = {
- (kv: (K, WeakReference[V])) => p(kv)
- }
-
- /* Implicit conversion methods from WeakReferences. */
-
- implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get
-
- implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = {
- v match {
- case Some(ref) => Option(fromWeakReference(ref))
- case None => None
- }
- }
-
- implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = {
- kv match { case (k, v) => (k, fromWeakReference(v)) }
- }
-
- implicit def fromWeakReferenceIterator[K, V](
- it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = {
- it.map(fromWeakReferenceTuple)
- }
-
- implicit def fromWeakReferenceMap[K, V](
- map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = {
- mutable.Map(map.mapValues(fromWeakReference).toSeq: _*)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
new file mode 100644
index 000000000000..27922b31949b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import javax.annotation.concurrent.GuardedBy
+
+/**
+ * A special Thread that provides "runUninterruptibly" to allow running code without being
+ * interrupted by `Thread.interrupt()`. If `Thread.interrupt()` is called while runUninterruptibly
+ * is running, it won't set the interrupted status. Instead, setting the interrupted status is
+ * deferred until the thread returns from "runUninterruptibly".
+ *
+ * Note: "runUninterruptibly" should be called only in `this` thread.
+ */
+private[spark] class UninterruptibleThread(
+ target: Runnable,
+ name: String) extends Thread(target, name) {
+
+ def this(name: String) {
+ this(null, name)
+ }
+
+ /** A monitor to protect "uninterruptible" and "interrupted" */
+ private val uninterruptibleLock = new Object
+
+ /**
+ * Indicates if `this` thread is in the uninterruptible status. If so, interrupting
+ * "this" will be deferred until `this` re-enters the interruptible status.
+ */
+ @GuardedBy("uninterruptibleLock")
+ private var uninterruptible = false
+
+ /**
+ * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
+ */
+ @GuardedBy("uninterruptibleLock")
+ private var shouldInterruptThread = false
+
+ /**
+ * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning
+ * from `f`.
+ *
+ * If this method finds that `interrupt` is called before calling `f` and it's not inside another
+ * `runUninterruptibly`, it will throw `InterruptedException`.
+ *
+ * Note: this method should be called only in `this` thread.
+ */
+ def runUninterruptibly[T](f: => T): T = {
+ if (Thread.currentThread() != this) {
+ throw new IllegalStateException(s"Call runUninterruptibly in a wrong thread. " +
+ s"Expected: $this but was ${Thread.currentThread()}")
+ }
+
+ if (uninterruptibleLock.synchronized { uninterruptible }) {
+ // We are already in the uninterruptible status. So just run "f" and return
+ return f
+ }
+
+ uninterruptibleLock.synchronized {
+ // Clear the interrupted status if it's set.
+ if (Thread.interrupted() || shouldInterruptThread) {
+ shouldInterruptThread = false
+ // Since it's interrupted, we don't need to run `f` which may be a long computation.
+ // Throw InterruptedException as we don't have a T to return.
+ throw new InterruptedException()
+ }
+ uninterruptible = true
+ }
+ try {
+ f
+ } finally {
+ uninterruptibleLock.synchronized {
+ uninterruptible = false
+ if (shouldInterruptThread) {
+ // Recover the interrupted status
+ super.interrupt()
+ shouldInterruptThread = false
+ }
+ }
+ }
+ }
+
+ /**
+ * Interrupt `this` thread if possible. If `this` is in the uninterruptible status, it won't be
+ * interrupted until it re-enters the interruptible status.
+ */
+ override def interrupt(): Unit = {
+ uninterruptibleLock.synchronized {
+ if (uninterruptible) {
+ shouldInterruptThread = true
+ } else {
+ super.interrupt()
+ }
+ }
+ }
+}
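
A rough sketch (not part of the patch) of the deferred-interrupt behaviour; the anonymous subclass and timings are made up. An interrupt() delivered while the block is running only takes effect once runUninterruptibly returns.

  package org.apache.spark.util

  object UninterruptibleSketch {
    def main(args: Array[String]): Unit = {
      val t = new UninterruptibleThread("sketch") {
        override def run(): Unit = {
          runUninterruptibly {
            Thread.sleep(1000)  // an interrupt arriving here is deferred, so the sleep completes
          }
          // The deferred interrupt has now been re-applied as the interrupted status.
          println("interrupted after block: " + Thread.currentThread().isInterrupted)
        }
      }
      t.start()
      Thread.sleep(100)
      t.interrupt()  // lands while the thread is inside runUninterruptibly
      t.join()
    }
  }
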
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 5a976ee839b1..626b65679a27 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,22 +18,35 @@
package org.apache.spark.util
import java.io._
-import java.lang.management.ManagementFactory
+import java.lang.{Byte => JByte}
+import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
+import java.math.{MathContext, RoundingMode}
import java.net._
import java.nio.ByteBuffer
-import java.util.{Properties, Locale, Random, UUID}
+import java.nio.channels.{Channels, FileChannel}
+import java.nio.charset.StandardCharsets
+import java.nio.file.{Files, Paths}
+import java.security.SecureRandom
+import java.util.{Locale, Properties, Random, UUID}
import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.zip.GZIPInputStream
import javax.net.ssl.HttpsURLConnection
+import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.reflect.ClassTag
-import scala.util.{Failure, Success, Try}
+import scala.util.Try
import scala.util.control.{ControlThrowable, NonFatal}
+import scala.util.matching.Regex
-import com.google.common.io.{ByteStreams, Files}
+import _root_.io.netty.channel.unix.Errors.NativeIoException
+import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import com.google.common.hash.HashCodes
+import com.google.common.io.{ByteStreams, Files => GFiles}
import com.google.common.net.InetAddresses
import org.apache.commons.lang3.SystemUtils
import org.apache.hadoop.conf.Configuration
@@ -42,12 +55,12 @@ import org.apache.hadoop.security.UserGroupInformation
import org.apache.log4j.PropertyConfigurator
import org.eclipse.jetty.util.MultiException
import org.json4s._
-
-import tachyon.TachyonURI
-import tachyon.client.{TachyonFS, TachyonFile}
+import org.slf4j.Logger
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
@@ -57,6 +70,7 @@ private[spark] case class CallSite(shortForm: String, longForm: String)
private[spark] object CallSite {
val SHORT_FORM = "callSite.short"
val LONG_FORM = "callSite.long"
+ val empty = CallSite("", "")
}
/**
@@ -74,6 +88,52 @@ private[spark] object Utils extends Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
@volatile private var localRootDirs: Array[String] = null
+ /**
+ * The performance overhead of creating and logging strings for wide schemas can be large. To
+ * limit the impact, we bound the number of fields to include by default. This can be overridden
+ * by setting the 'spark.debug.maxToStringFields' conf in SparkEnv.
+ */
+ val DEFAULT_MAX_TO_STRING_FIELDS = 25
+
+ private def maxNumToStringFields = {
+ if (SparkEnv.get != null) {
+ SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS)
+ } else {
+ DEFAULT_MAX_TO_STRING_FIELDS
+ }
+ }
+
+ /** Whether we have warned about plan string truncation yet. */
+ private val truncationWarningPrinted = new AtomicBoolean(false)
+
+ /**
+ * Format a sequence with semantics similar to calling .mkString(). Any elements beyond
+ * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder.
+ *
+ * @return the trimmed and formatted string.
+ */
+ def truncatedString[T](
+ seq: Seq[T],
+ start: String,
+ sep: String,
+ end: String,
+ maxNumFields: Int = maxNumToStringFields): String = {
+ if (seq.length > maxNumFields) {
+ if (truncationWarningPrinted.compareAndSet(false, true)) {
+ logWarning(
+ "Truncated the string representation of a plan since it was too large. This " +
+ "behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.")
+ }
+ val numFields = math.max(0, maxNumFields - 1)
+ seq.take(numFields).mkString(
+ start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end)
+ } else {
+ seq.mkString(start, sep, end)
+ }
+ }
+
+ /** Shorthand for calling truncatedString() without start or end strings. */
+ def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "")
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
@@ -177,13 +237,30 @@ private[spark] object Utils extends Logging {
/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
*/
- def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = {
+ def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
if (bb.hasArray) {
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
} else {
+ val originalPosition = bb.position()
val bbval = new Array[Byte](bb.remaining())
bb.get(bbval)
out.write(bbval)
+ bb.position(originalPosition)
+ }
+ }
+
+ /**
+ * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]]
+ */
+ def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
+ if (bb.hasArray) {
+ out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+ } else {
+ val originalPosition = bb.position()
+ val bbval = new Array[Byte](bb.remaining())
+ bb.get(bbval)
+ out.write(bbval)
+ bb.position(originalPosition)
}
}
@@ -239,45 +316,27 @@ private[spark] object Utils extends Logging {
dir
}
- /** Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream
- * copying is disabled by default unless explicitly set transferToEnabled as true,
- * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false].
- */
- def copyStream(in: InputStream,
- out: OutputStream,
- closeStreams: Boolean = false,
- transferToEnabled: Boolean = false): Long =
- {
- var count = 0L
+ /**
+ * Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream
+ * copying is disabled by default unless explicitly set transferToEnabled as true,
+ * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false].
+ */
+ def copyStream(
+ in: InputStream,
+ out: OutputStream,
+ closeStreams: Boolean = false,
+ transferToEnabled: Boolean = false): Long = {
tryWithSafeFinally {
if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]
&& transferToEnabled) {
// When both streams are File stream, use transferTo to improve copy performance.
val inChannel = in.asInstanceOf[FileInputStream].getChannel()
val outChannel = out.asInstanceOf[FileOutputStream].getChannel()
- val initialPos = outChannel.position()
val size = inChannel.size()
-
- // In case transferTo method transferred less data than we have required.
- while (count < size) {
- count += inChannel.transferTo(count, size - count, outChannel)
- }
-
- // Check the position after transferTo loop to see if it is in the right position and
- // give user information if not.
- // Position will not be increased to the expected length after calling transferTo in
- // kernel version 2.6.32, this issue can be seen in
- // https://bugs.openjdk.java.net/browse/JDK-7052359
- // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
- val finalPos = outChannel.position()
- assert(finalPos == initialPos + size,
- s"""
- |Current position $finalPos do not equal to expected position ${initialPos + size}
- |after transferTo, please check your kernel version to see if it is 2.6.32,
- |this is a kernel bug which will lead to unexpected behavior when using transferTo.
- |You can set spark.file.transferTo = false to disable this NIO feature.
- """.stripMargin)
+ copyFileStreamNIO(inChannel, outChannel, 0, size)
+ size
} else {
+ var count = 0L
val buf = new Array[Byte](8192)
var n = 0
while (n != -1) {
@@ -287,8 +346,8 @@ private[spark] object Utils extends Logging {
count += n
}
}
+ count
}
- count
} {
if (closeStreams) {
try {
@@ -300,6 +359,37 @@ private[spark] object Utils extends Logging {
}
}
+ def copyFileStreamNIO(
+ input: FileChannel,
+ output: FileChannel,
+ startPosition: Long,
+ bytesToCopy: Long): Unit = {
+ val initialPos = output.position()
+ var count = 0L
+ // In case the transferTo method transferred less data than required.
+ while (count < bytesToCopy) {
+ count += input.transferTo(count + startPosition, bytesToCopy - count, output)
+ }
+ assert(count == bytesToCopy,
+ s"request to copy $bytesToCopy bytes, but actually copied $count bytes.")
+
+ // Check the position after transferTo loop to see if it is in the right position and
+ // give user information if not.
+ // Position will not be increased to the expected length after calling transferTo in
+ // kernel version 2.6.32, this issue can be seen in
+ // https://bugs.openjdk.java.net/browse/JDK-7052359
+ // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
+ val finalPos = output.position()
+ val expectedPos = initialPos + bytesToCopy
+ assert(finalPos == expectedPos,
+ s"""
+ |Current position $finalPos do not equal to expected position $expectedPos
+ |after transferTo, please check your kernel version to see if it is 2.6.32,
+ |this is a kernel bug which will lead to unexpected behavior when using transferTo.
+ |You can set spark.file.transferTo = false to disable this NIO feature.
+ """.stripMargin)
+ }
+
/**
* Construct a URI container information used for authentication.
* This also sets the default authenticator to properly negotiation the
@@ -317,6 +407,30 @@ private[spark] object Utils extends Logging {
}
/**
+ * A file name may contain some invalid URI characters, such as " ". This method will convert the
+ * file name to a raw path accepted by `java.net.URI(String)`.
+ *
+ * Note: the file name must not contain "/" or "\"
+ */
+ def encodeFileNameToURIRawPath(fileName: String): String = {
+ require(!fileName.contains("/") && !fileName.contains("\\"))
+ // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as
+ // scheme or host. The prefix "/" is required because URI doesn't accept a relative path.
+ // We should remove it after we get the raw path.
+ new URI("file", null, "localhost", -1, "/" + fileName, null, null).getRawPath.substring(1)
+ }
+
+ /**
+ * Get the file name from uri's raw path and decode it. If the raw path of uri ends with "/",
+ * return the name before the last "/".
+ */
+ def decodeFileNameInURI(uri: URI): String = {
+ val rawPath = uri.getRawPath
+ val rawFileName = rawPath.split("/").last
+ new URI("file:///" + rawFileName).getPath.substring(1)
+ }
+
+ /**
* Download a file or directory to target directory. Supports fetching the file in a variety of
* ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based
* on the URL parameter. Fetching directories is only supported from Hadoop-compatible
@@ -337,7 +451,7 @@ private[spark] object Utils extends Logging {
hadoopConf: Configuration,
timestamp: Long,
useCache: Boolean) {
- val fileName = url.split("/").last
+ val fileName = decodeFileNameInURI(new URI(url))
val targetFile = new File(targetDir, fileName)
val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true)
if (useCache && fetchCacheEnabled) {
@@ -479,7 +593,7 @@ private[spark] object Utils extends Logging {
// The file does not exist in the target directory. Copy or move it there.
if (removeSourceFile) {
- Files.move(sourceFile, destFile)
+ Files.move(sourceFile.toPath, destFile.toPath)
} else {
logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}")
copyRecursive(sourceFile, destFile)
@@ -497,7 +611,7 @@ private[spark] object Utils extends Logging {
case (f1, f2) => filesEqualRecursive(f1, f2)
}
} else if (file1.isFile && file2.isFile) {
- Files.equal(file1, file2)
+ GFiles.equal(file1, file2)
} else {
false
}
@@ -511,7 +625,7 @@ private[spark] object Utils extends Logging {
val subfiles = source.listFiles()
subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName)))
} else {
- Files.copy(source, dest)
+ Files.copy(source.toPath, dest.toPath)
}
}
@@ -535,6 +649,14 @@ private[spark] object Utils extends Logging {
val uri = new URI(url)
val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
Option(uri.getScheme).getOrElse("file") match {
+ case "spark" =>
+ if (SparkEnv.get == null) {
+ throw new IllegalStateException(
+ "Cannot retrieve files with 'spark' scheme without an active SparkEnv.")
+ }
+ val source = SparkEnv.get.rpcEnv.openChannel(url)
+ val is = Channels.newInputStream(source)
+ downloadFile(url, is, targetFile, fileOverwrite)
case "http" | "https" | "ftp" =>
var uc: URLConnection = null
if (securityMgr.isAuthenticationEnabled()) {
@@ -599,6 +721,26 @@ private[spark] object Utils extends Logging {
}
}
+ /**
+ * Validate that a given URI is actually a valid URL as well.
+ * @param uri The URI to validate
+ */
+ @throws[MalformedURLException]("when the URI is an invalid URL")
+ def validateURL(uri: URI): Unit = {
+ Option(uri.getScheme).getOrElse("file") match {
+ case "http" | "https" | "ftp" =>
+ try {
+ uri.toURL
+ } catch {
+ case e: MalformedURLException =>
+ val ex = new MalformedURLException(s"URI (${uri.toString}) is not a valid URL.")
+ ex.initCause(e)
+ throw ex
+ }
+ case _ => // will not be turned into a URL anyway
+ }
+ }
+
/**
* Get the path of a temporary directory. Spark's local directories can be configured through
* multiple settings, which are used with the following precedence:
@@ -612,14 +754,16 @@ private[spark] object Utils extends Logging {
* always return a single directory.
*/
def getLocalDir(conf: SparkConf): String = {
- getOrCreateLocalRootDirs(conf)(0)
+ getOrCreateLocalRootDirs(conf).headOption.getOrElse {
+ val configuredLocalDirs = getConfiguredLocalDirs(conf)
+ throw new IOException(
+ s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].")
+ }
}
private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = {
// These environment variables are set by YARN.
- // For Hadoop 0.23.X, we check for YARN_LOCAL_DIRS (we use this below in getYarnLocalDirs())
- // For Hadoop 2.X, we check for CONTAINER_ID.
- conf.getenv("CONTAINER_ID") != null || conf.getenv("YARN_LOCAL_DIRS") != null
+ conf.getenv("CONTAINER_ID") != null
}
/**
@@ -695,17 +839,12 @@ private[spark] object Utils extends Logging {
logError(s"Failed to create local root dir in $root. Ignoring this directory.")
None
}
- }.toArray
+ }
}
/** Get the Yarn approved local directories. */
private def getYarnLocalDirs(conf: SparkConf): String = {
- // Hadoop 0.23 and 2.x have different Environment variable names for the
- // local dirs, so lets check both. We assume one of the 2 is set.
- // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
- val localDirs = Option(conf.getenv("YARN_LOCAL_DIRS"))
- .getOrElse(Option(conf.getenv("LOCAL_DIRS"))
- .getOrElse(""))
+ val localDirs = Option(conf.getenv("LOCAL_DIRS")).getOrElse("")
if (localDirs.isEmpty) {
throw new Exception("Yarn Local dirs can't be empty")
@@ -733,7 +872,7 @@ private[spark] object Utils extends Logging {
*/
def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
for (i <- (arr.length - 1) to 1 by -1) {
- val j = rand.nextInt(i)
+ val j = rand.nextInt(i + 1)
val tmp = arr(j)
arr(j) = arr(i)
arr(i) = tmp
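Note: the `i + 1` bound turns this into a proper Fisher–Yates shuffle; with the old `rand.nextInt(i)` the element at position `i` was always swapped away, biasing the permutation. A quick, self-contained check of that bias (illustrative only, not part of the patch):

    import scala.util.Random

    // Shuffle with a configurable bound so the old and new behaviour can be compared.
    def shuffle(arr: Array[Int], inclusiveBound: Boolean, rand: Random): Array[Int] = {
      for (i <- (arr.length - 1) to 1 by -1) {
        val j = rand.nextInt(if (inclusiveBound) i + 1 else i)
        val tmp = arr(j); arr(j) = arr(i); arr(i) = tmp
      }
      arr
    }

    val rand = new Random(42)
    val biased = (1 to 10000).count(_ => shuffle(Array(0, 1, 2), false, rand).last == 2) // old bound
    val fixed  = (1 to 10000).count(_ => shuffle(Array(0, 1, 2), true, rand).last == 2)  // new bound
    // biased == 0 (the last element can never stay in place); fixed is roughly 1/3 of the trials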
@@ -899,28 +1038,11 @@ private[spark] object Utils extends Logging {
}
}
- /**
- * Delete a file or directory and its contents recursively.
- */
- def deleteRecursively(dir: TachyonFile, client: TachyonFS) {
- if (!client.delete(new TachyonURI(dir.getPath()), true)) {
- throw new IOException("Failed to delete the tachyon dir: " + dir)
- }
- }
-
/**
* Check to see if file is a symbolic link.
*/
def isSymlink(file: File): Boolean = {
- if (file == null) throw new NullPointerException("File must not be null")
- if (isWindows) return false
- val fileInCanonicalDir = if (file.getParent() == null) {
- file
- } else {
- new File(file.getParentFile().getCanonicalFile(), file.getName())
- }
-
- !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile())
+ return Files.isSymbolicLink(Paths.get(file.toURI))
}
/**
@@ -1007,26 +1129,39 @@ private[spark] object Utils extends Logging {
/**
* Convert a quantity in bytes to a human-readable string such as "4.0 MB".
*/
- def bytesToString(size: Long): String = {
+ def bytesToString(size: Long): String = bytesToString(BigInt(size))
+
+ def bytesToString(size: BigInt): String = {
+ val EB = 1L << 60
+ val PB = 1L << 50
val TB = 1L << 40
val GB = 1L << 30
val MB = 1L << 20
val KB = 1L << 10
- val (value, unit) = {
- if (size >= 2*TB) {
- (size.asInstanceOf[Double] / TB, "TB")
- } else if (size >= 2*GB) {
- (size.asInstanceOf[Double] / GB, "GB")
- } else if (size >= 2*MB) {
- (size.asInstanceOf[Double] / MB, "MB")
- } else if (size >= 2*KB) {
- (size.asInstanceOf[Double] / KB, "KB")
- } else {
- (size.asInstanceOf[Double], "B")
+ if (size >= BigInt(1L << 11) * EB) {
+ // The number is too large, show it in scientific notation.
+ BigDecimal(size, new MathContext(3, RoundingMode.HALF_UP)).toString() + " B"
+ } else {
+ val (value, unit) = {
+ if (size >= 2 * EB) {
+ (BigDecimal(size) / EB, "EB")
+ } else if (size >= 2 * PB) {
+ (BigDecimal(size) / PB, "PB")
+ } else if (size >= 2 * TB) {
+ (BigDecimal(size) / TB, "TB")
+ } else if (size >= 2 * GB) {
+ (BigDecimal(size) / GB, "GB")
+ } else if (size >= 2 * MB) {
+ (BigDecimal(size) / MB, "MB")
+ } else if (size >= 2 * KB) {
+ (BigDecimal(size) / KB, "KB")
+ } else {
+ (BigDecimal(size), "B")
+ }
}
+ "%.1f %s".formatLocal(Locale.US, value, unit)
}
- "%.1f %s".formatLocal(Locale.US, value, unit)
}
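Note: a condensed, self-contained sketch of the new scaling ladder (names are illustrative; the patch's version is the `bytesToString` overloads above). Taking `BigInt` means sizes beyond `Long.MaxValue`, e.g. aggregated cluster-wide byte counts, no longer overflow, and anything at or above 2^71 bytes falls back to scientific notation:

    import java.math.{MathContext, RoundingMode}
    import java.util.Locale

    def humanReadable(size: BigInt): String = {
      val EB = BigInt(1L << 60)
      if (size >= EB * 2048) {
        // Too large for a unit suffix; fall back to scientific notation.
        BigDecimal(size, new MathContext(3, RoundingMode.HALF_UP)).toString + " B"
      } else {
        val units = Seq("EB" -> EB, "PB" -> BigInt(1L << 50), "TB" -> BigInt(1L << 40),
          "GB" -> BigInt(1L << 30), "MB" -> BigInt(1L << 20), "KB" -> BigInt(1L << 10))
        val (unit, factor) = units.find { case (_, f) => size >= f * 2 }.getOrElse(("B", BigInt(1)))
        "%.1f %s".formatLocal(Locale.US, (BigDecimal(size) / BigDecimal(factor)).toDouble, unit)
      }
    }

    humanReadable(BigInt(Long.MaxValue))         // "8.0 EB"
    humanReadable(BigInt(3) * (BigInt(1) << 75)) // scientific notation, e.g. "1.13E+23 B"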
/**
@@ -1087,9 +1222,9 @@ private[spark] object Utils extends Logging {
extraEnvironment: Map[String, String] = Map.empty,
redirectStderr: Boolean = true): String = {
val process = executeCommand(command, workingDir, extraEnvironment, redirectStderr)
- val output = new StringBuffer
+ val output = new StringBuilder
val threadName = "read stdout for " + command(0)
- def appendToOutput(s: String): Unit = output.append(s)
+ def appendToOutput(s: String): Unit = output.append(s).append("\n")
val stdoutThread = processStreamByLine(threadName, process.getInputStream, appendToOutput)
val exitCode = process.waitFor()
stdoutThread.join() // Wait for it to finish reading output
@@ -1135,7 +1270,7 @@ private[spark] object Utils extends Logging {
}
/**
- * Execute a block of code that evaluates to Unit, stop SparkContext is there is any uncaught
+ * Execute a block of code that evaluates to Unit, stop SparkContext if there is any uncaught
* exception
*
* NOTE: This method is to be called by the driver-side components to avoid stopping the
@@ -1151,7 +1286,7 @@ private[spark] object Utils extends Logging {
val currentThreadName = Thread.currentThread().getName
if (sc != null) {
logError(s"uncaught error in thread $currentThreadName, stopping SparkContext", t)
- sc.stop()
+ sc.stopInNewThread()
}
if (!NonFatal(t)) {
logError(s"throw uncaught fatal error in thread $currentThreadName", t)
@@ -1160,21 +1295,6 @@ private[spark] object Utils extends Logging {
}
}
- /**
- * Execute a block of code that evaluates to Unit, re-throwing any non-fatal uncaught
- * exceptions as IOException. This is used when implementing Externalizable and Serializable's
- * read and write methods, since Java's serializer will not report non-IOExceptions properly;
- * see SPARK-4080 for more context.
- */
- def tryOrIOException(block: => Unit) {
- try {
- block
- } catch {
- case e: IOException => throw e
- case NonFatal(t) => throw new IOException(t)
- }
- }
-
/**
* Execute a block of code that returns a value, re-throwing any non-fatal uncaught
* exceptions as IOException. This is used when implementing Externalizable and Serializable's
@@ -1185,8 +1305,12 @@ private[spark] object Utils extends Logging {
try {
block
} catch {
- case e: IOException => throw e
- case NonFatal(t) => throw new IOException(t)
+ case e: IOException =>
+ logError("Exception encountered", e)
+ throw e
+ case NonFatal(e) =>
+ logError("Exception encountered", e)
+ throw new IOException(e)
}
}
@@ -1211,7 +1335,6 @@ private[spark] object Utils extends Logging {
* exception from the original `out.write` call.
*/
def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
- // It would be nice to find a method on Try that did this
var originalThrowable: Throwable = null
try {
block
@@ -1225,14 +1348,55 @@ private[spark] object Utils extends Logging {
try {
finallyBlock
} catch {
- case t: Throwable =>
- if (originalThrowable != null) {
- originalThrowable.addSuppressed(t)
- logWarning(s"Suppressing exception in finally: " + t.getMessage, t)
- throw originalThrowable
- } else {
- throw t
- }
+ case t: Throwable if (originalThrowable != null && originalThrowable != t) =>
+ originalThrowable.addSuppressed(t)
+ logWarning(s"Suppressing exception in finally: ${t.getMessage}", t)
+ throw originalThrowable
+ }
+ }
+ }
+
+ /**
+ * Execute a block of code and call the failure callbacks in the catch block. If exceptions occur
+ * in either the catch or the finally block, they are added as suppressed exceptions to the
+ * original exception, which is then rethrown.

+ *
+ * This is primarily an issue with `catch { abort() }` or `finally { out.close() }` blocks,
+ * where the abort/close needs to be called to clean up `out`, but if an exception happened
+ * in `out.write`, it's likely `out` may be corrupted and `abort` or `out.close` will
+ * fail as well. This would then suppress the original/likely more meaningful
+ * exception from the original `out.write` call.
+ */
+ def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)
+ (catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = {
+ var originalThrowable: Throwable = null
+ try {
+ block
+ } catch {
+ case cause: Throwable =>
+ // Purposefully not using NonFatal, because we don't want our finallyBlock
+ // to suppress even fatal exceptions
+ originalThrowable = cause
+ try {
+ logError("Aborting task", originalThrowable)
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable)
+ catchBlock
+ } catch {
+ case t: Throwable =>
+ if (originalThrowable != t) {
+ originalThrowable.addSuppressed(t)
+ logWarning(s"Suppressing exception in catch: ${t.getMessage}", t)
+ }
+ }
+ throw originalThrowable
+ } finally {
+ try {
+ finallyBlock
+ } catch {
+ case t: Throwable if (originalThrowable != null && originalThrowable != t) =>
+ originalThrowable.addSuppressed(t)
+ logWarning(s"Suppressing exception in finally: ${t.getMessage}", t)
+ throw originalThrowable
}
}
}
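Note: a minimal, self-contained sketch of the suppression behaviour both helpers above rely on (stream class and messages are made up): the exception from the main block wins, and the failure from the cleanup block is attached via addSuppressed instead of masking it.

    import java.io.{IOException, OutputStream}

    // A stream whose write() and close() both fail, purely for illustration.
    class FlakyStream extends OutputStream {
      override def write(b: Int): Unit = throw new IOException("write failed")
      override def close(): Unit = throw new IOException("close failed")
    }

    // Simplified restatement of tryWithSafeFinally, so the snippet runs standalone.
    def safeFinally[T](block: => T)(finallyBlock: => Unit): T = {
      var original: Throwable = null
      try block catch { case t: Throwable => original = t; throw t } finally {
        try finallyBlock catch {
          case t: Throwable if original != null && original != t =>
            original.addSuppressed(t)
            throw original
        }
      }
    }

    val out = new FlakyStream
    try {
      safeFinally { out.write(1) } { out.close() }
    } catch {
      case e: IOException =>
        assert(e.getMessage == "write failed")                     // original exception preserved
        assert(e.getSuppressed.head.getMessage == "close failed")  // cleanup failure attached
    }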
@@ -1286,8 +1450,12 @@ private[spark] object Utils extends Logging {
}
callStack(0) = ste.toString // Put last Spark method on top of the stack trace.
} else {
- firstUserLine = ste.getLineNumber
- firstUserFile = ste.getFileName
+ if (ste.getFileName != null) {
+ firstUserFile = ste.getFileName
+ if (ste.getLineNumber >= 0) {
+ firstUserLine = ste.getLineNumber
+ }
+ }
callStack += ste.toString
insideSpark = false
}
@@ -1311,14 +1479,77 @@ private[spark] object Utils extends Logging {
CallSite(shortForm, longForm)
}
+ private val UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE_CONF =
+ "spark.worker.ui.compressedLogFileLengthCacheSize"
+ private val DEFAULT_UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE = 100
+ private var compressedLogFileLengthCache: LoadingCache[String, java.lang.Long] = null
+ private def getCompressedLogFileLengthCache(
+ sparkConf: SparkConf): LoadingCache[String, java.lang.Long] = this.synchronized {
+ if (compressedLogFileLengthCache == null) {
+ val compressedLogFileLengthCacheSize = sparkConf.getInt(
+ UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE_CONF,
+ DEFAULT_UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE)
+ compressedLogFileLengthCache = CacheBuilder.newBuilder()
+ .maximumSize(compressedLogFileLengthCacheSize)
+ .build[String, java.lang.Long](new CacheLoader[String, java.lang.Long]() {
+ override def load(path: String): java.lang.Long = {
+ Utils.getCompressedFileLength(new File(path))
+ }
+ })
+ }
+ compressedLogFileLengthCache
+ }
+
+ /**
+ * Return the file length; if the file is compressed, return its uncompressed length.
+ * The uncompressed file size is cached to avoid repeated decompression. The cache size is
+ * read from workConf.
+ */
+ def getFileLength(file: File, workConf: SparkConf): Long = {
+ if (file.getName.endsWith(".gz")) {
+ getCompressedLogFileLengthCache(workConf).get(file.getAbsolutePath)
+ } else {
+ file.length
+ }
+ }
+
+ /** Return uncompressed file length of a compressed file. */
+ private def getCompressedFileLength(file: File): Long = {
+ var gzInputStream: GZIPInputStream = null
+ try {
+ // Uncompress .gz file to determine file size.
+ var fileSize = 0L
+ gzInputStream = new GZIPInputStream(new FileInputStream(file))
+ val bufSize = 1024
+ val buf = new Array[Byte](bufSize)
+ var numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize)
+ while (numBytes > 0) {
+ fileSize += numBytes
+ numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize)
+ }
+ fileSize
+ } catch {
+ case e: Throwable =>
+ logError(s"Cannot get file length of ${file}", e)
+ throw e
+ } finally {
+ if (gzInputStream != null) {
+ gzInputStream.close()
+ }
+ }
+ }
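Note: an illustrative check of the ".gz" special case (file contents and sizes made up): for a compressed log, getFileLength reports the uncompressed size, while File.length() only gives the on-disk size.

    import java.io.{File, FileOutputStream}
    import java.util.zip.GZIPOutputStream

    val log = File.createTempFile("stdout", ".gz")
    val out = new GZIPOutputStream(new FileOutputStream(log))
    out.write(Array.fill[Byte](1 << 20)('a'.toByte))   // 1 MiB of highly compressible data
    out.close()

    log.length()                                  // a few KiB on disk
    // Utils.getFileLength(log, new SparkConf())  // 1048576, the uncompressed length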
+
/** Return a string containing part of a file from byte 'start' to 'end'. */
- def offsetBytes(path: String, start: Long, end: Long): String = {
+ def offsetBytes(path: String, length: Long, start: Long, end: Long): String = {
val file = new File(path)
- val length = file.length()
val effectiveEnd = math.min(length, end)
val effectiveStart = math.max(0, start)
val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt)
- val stream = new FileInputStream(file)
+ val stream = if (path.endsWith(".gz")) {
+ new GZIPInputStream(new FileInputStream(file))
+ } else {
+ new FileInputStream(file)
+ }
try {
ByteStreams.skipFully(stream, effectiveStart)
@@ -1334,8 +1565,8 @@ private[spark] object Utils extends Logging {
* and `endIndex` is based on the cumulative size of all the files take in
* the given order. See figure below for more details.
*/
- def offsetBytes(files: Seq[File], start: Long, end: Long): String = {
- val fileLengths = files.map { _.length }
+ def offsetBytes(files: Seq[File], fileLengths: Seq[Long], start: Long, end: Long): String = {
+ assert(files.length == fileLengths.length)
val startIndex = math.max(start, 0)
val endIndex = math.min(end, fileLengths.sum)
val fileToLength = files.zip(fileLengths).toMap
@@ -1343,7 +1574,7 @@ private[spark] object Utils extends Logging {
val stringBuffer = new StringBuffer((endIndex - startIndex).toInt)
var sum = 0L
- for (file <- files) {
+ files.zip(fileLengths).foreach { case (file, fileLength) =>
val startIndexOfFile = sum
val endIndexOfFile = sum + fileToLength(file)
logDebug(s"Processing file $file, " +
@@ -1362,19 +1593,19 @@ private[spark] object Utils extends Logging {
if (startIndex <= startIndexOfFile && endIndex >= endIndexOfFile) {
// Case C: read the whole file
- stringBuffer.append(offsetBytes(file.getAbsolutePath, 0, fileToLength(file)))
+ stringBuffer.append(offsetBytes(file.getAbsolutePath, fileLength, 0, fileToLength(file)))
} else if (startIndex > startIndexOfFile && startIndex < endIndexOfFile) {
// Case A and B: read from [start of required range] to [end of file / end of range]
val effectiveStartIndex = startIndex - startIndexOfFile
val effectiveEndIndex = math.min(endIndex - startIndexOfFile, fileToLength(file))
stringBuffer.append(Utils.offsetBytes(
- file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex))
+ file.getAbsolutePath, fileLength, effectiveStartIndex, effectiveEndIndex))
} else if (endIndex > startIndexOfFile && endIndex < endIndexOfFile) {
// Case D: read from [start of file] to [end of required range]
val effectiveStartIndex = math.max(startIndex - startIndexOfFile, 0)
val effectiveEndIndex = endIndex - startIndexOfFile
stringBuffer.append(Utils.offsetBytes(
- file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex))
+ file.getAbsolutePath, fileLength, effectiveStartIndex, effectiveEndIndex))
}
sum += fileToLength(file)
logDebug(s"After processing file $file, string built is ${stringBuffer.toString}")
@@ -1461,7 +1692,7 @@ private[spark] object Utils extends Logging {
rawMod + (if (rawMod < 0) mod else 0)
}
- // Handles idiosyncracies with hash (add more as required)
+ // Handles idiosyncrasies with hash (add more as required)
// This method should be kept in sync with
// org.apache.spark.network.util.JavaUtils#nonNegativeHash().
def nonNegativeHash(obj: AnyRef): Int = {
@@ -1478,8 +1709,8 @@ private[spark] object Utils extends Logging {
}
/**
- * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared
- * according to semantics where NaN == NaN and NaN > any non-NaN double.
+ * NaN-safe version of `java.lang.Double.compare()` which allows NaN values to be compared
+ * according to semantics where NaN == NaN and NaN is greater than any non-NaN double.
*/
def nanSafeCompareDoubles(x: Double, y: Double): Int = {
val xIsNan: Boolean = java.lang.Double.isNaN(x)
@@ -1492,8 +1723,8 @@ private[spark] object Utils extends Logging {
}
/**
- * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared
- * according to semantics where NaN == NaN and NaN > any non-NaN float.
+ * NaN-safe version of `java.lang.Float.compare()` which allows NaN values to be compared
+ * according to semantics where NaN == NaN and NaN is greater than any non-NaN float.
*/
def nanSafeCompareFloats(x: Float, y: Float): Int = {
val xIsNan: Boolean = java.lang.Float.isNaN(x)
@@ -1505,9 +1736,11 @@ private[spark] object Utils extends Logging {
else -1
}
- /** Returns the system properties map that is thread-safe to iterator over. It gets the
- * properties which have been set explicitly, as well as those for which only a default value
- * has been defined. */
+ /**
+ * Returns the system properties map that is thread-safe to iterator over. It gets the
+ * properties which have been set explicitly, as well as those for which only a default value
+ * has been defined.
+ */
def getSystemProperties: Map[String, String] = {
System.getProperties.stringPropertyNames().asScala
.map(key => (key, System.getProperty(key))).toMap
@@ -1527,11 +1760,12 @@ private[spark] object Utils extends Logging {
/**
* Timing method based on iterations that permit JVM JIT optimization.
+ *
* @param numIters number of iterations
* @param f function to be executed. If prepare is not None, the running time of each call to f
* must be an order of magnitude longer than one millisecond for accurate timing.
* @param prepare function to be executed before each call to f. Its running time doesn't count.
- * @return the total time across all iterations (not couting preparation time)
+ * @return the total time across all iterations (not counting preparation time)
*/
def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = {
if (prepare.isEmpty) {
@@ -1567,30 +1801,35 @@ private[spark] object Utils extends Logging {
}
/**
- * Creates a symlink. Note jdk1.7 has Files.createSymbolicLink but not used here
- * for jdk1.6 support. Supports windows by doing copy, everything else uses "ln -sf".
+ * Generate a zipWithIndex iterator that uses a Long index, avoiding the Int overflow
+ * problem of Scala's zipWithIndex.
+ */
+ def getIteratorZipWithIndex[T](iterator: Iterator[T], startIndex: Long): Iterator[(T, Long)] = {
+ new Iterator[(T, Long)] {
+ require(startIndex >= 0, "startIndex should be >= 0.")
+ var index: Long = startIndex - 1L
+ def hasNext: Boolean = iterator.hasNext
+ def next(): (T, Long) = {
+ index += 1L
+ (iterator.next(), index)
+ }
+ }
+ }
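Note: an illustrative use of the helper above (assumed to be in scope): because the index is a Long and can start from an arbitrary offset, partitions deep into a very large RDD do not overflow the way Iterator.zipWithIndex (which uses an Int index) would.

    val it = getIteratorZipWithIndex(Iterator("a", "b", "c"), Int.MaxValue.toLong + 1)
    it.toList
    // List(("a", 2147483648L), ("b", 2147483649L), ("c", 2147483650L))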
+
+ /**
+ * Creates a symlink.
+ *
* @param src absolute path to the source
* @param dst relative path for the destination
*/
- def symlink(src: File, dst: File) {
+ def symlink(src: File, dst: File): Unit = {
if (!src.isAbsolute()) {
throw new IOException("Source must be absolute")
}
if (dst.isAbsolute()) {
throw new IOException("Destination must be relative")
}
- var cmdSuffix = ""
- val linkCmd = if (isWindows) {
- // refer to http://technet.microsoft.com/en-us/library/cc771254.aspx
- cmdSuffix = " /s /e /k /h /y /i"
- "cmd /c xcopy "
- } else {
- cmdSuffix = ""
- "ln -sf "
- }
- import scala.sys.process._
- (linkCmd + src.getAbsolutePath() + " " + dst.getPath() + cmdSuffix) lines_!
- ProcessLogger(line => logInfo(line))
+ Files.createSymbolicLink(dst.toPath, src.toPath)
}
@@ -1663,26 +1902,30 @@ private[spark] object Utils extends Logging {
}
/**
- * Wait for a process to terminate for at most the specified duration.
- * Return whether the process actually terminated after the given timeout.
- */
- def waitForProcess(process: Process, timeoutMs: Long): Boolean = {
- var terminated = false
- val startTime = System.currentTimeMillis
- while (!terminated) {
+ * Terminates a process waiting for at most the specified duration.
+ *
+ * @return the process exit value if it was successfully terminated, else None
+ */
+ def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = {
+ // Politely destroy first
+ process.destroy()
+ if (process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) {
+ // Successful exit
+ Option(process.exitValue())
+ } else {
try {
- process.exitValue()
- terminated = true
+ process.destroyForcibly()
} catch {
- case e: IllegalThreadStateException =>
- // Process not terminated yet
- if (System.currentTimeMillis - startTime > timeoutMs) {
- return false
- }
- Thread.sleep(100)
+ case NonFatal(e) => logWarning("Exception when attempting to kill process", e)
+ }
+ // Wait, again, although this really should return almost immediately
+ if (process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) {
+ Option(process.exitValue())
+ } else {
+ logWarning("Timed out waiting to forcibly kill process")
+ None
}
}
- true
}
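Note: a usage sketch for the new terminateProcess (Unix-only example command, values assumed): the child gets the timeout to exit after the polite destroy() before destroyForcibly() is attempted.

    val child = new ProcessBuilder("sleep", "60").start()
    terminateProcess(child, timeoutMs = 5000) match {
      case Some(code) => println(s"child exited with code $code")
      case None       => println("child did not terminate within the timeout")
    }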
/**
@@ -1690,7 +1933,7 @@ private[spark] object Utils extends Logging {
* If the process does not terminate within the specified timeout, return None.
*/
def getStderr(process: Process, timeoutMs: Long): Option[String] = {
- val terminated = Utils.waitForProcess(process, timeoutMs)
+ val terminated = process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)
if (terminated) {
Some(Source.fromInputStream(process.getErrorStream).getLines().mkString("\n"))
} else {
@@ -1732,22 +1975,17 @@ private[spark] object Utils extends Logging {
/** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */
def isFatalError(e: Throwable): Boolean = {
e match {
- case NonFatal(_) | _: InterruptedException | _: NotImplementedError | _: ControlThrowable =>
+ case NonFatal(_) |
+ _: InterruptedException |
+ _: NotImplementedError |
+ _: ControlThrowable |
+ _: LinkageError =>
false
case _ =>
true
}
}
- lazy val isInInterpreter: Boolean = {
- try {
- val interpClass = classForName("org.apache.spark.repl.Main")
- interpClass.getMethod("interp").invoke(null) != null
- } catch {
- case _: ClassNotFoundException => false
- }
- }
-
/**
* Return a well-formed URI for the file described by a user input string.
*
@@ -1778,7 +2016,7 @@ private[spark] object Utils extends Logging {
if (paths == null || paths.trim.isEmpty) {
""
} else {
- paths.split(",").map { p => Utils.resolveURI(p) }.mkString(",")
+ paths.split(",").filter(_.trim.nonEmpty).map { p => Utils.resolveURI(p) }.mkString(",")
}
}
@@ -1818,13 +2056,27 @@ private[spark] object Utils extends Logging {
path
}
+ /**
+ * Updates the Spark config with the given key/value properties.
+ * The provided properties take the highest priority.
+ */
+ def updateSparkConfigFromProperties(
+ conf: SparkConf,
+ properties: Map[String, String]) : Unit = {
+ properties.filter { case (k, v) =>
+ k.startsWith("spark.")
+ }.foreach { case (k, v) =>
+ conf.set(k, v)
+ }
+ }
+
/** Load properties present in the given file. */
def getPropertiesFromFile(filename: String): Map[String, String] = {
val file = new File(filename)
require(file.exists(), s"Properties file $file does not exist")
require(file.isFile(), s"Properties file $file is not a normal file")
- val inReader = new InputStreamReader(new FileInputStream(file), "UTF-8")
+ val inReader = new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8)
try {
val properties = new Properties()
properties.load(inReader)
@@ -1863,18 +2115,62 @@ private[spark] object Utils extends Logging {
}
}
+ private implicit class Lock(lock: LockInfo) {
+ def lockString: String = {
+ lock match {
+ case monitor: MonitorInfo =>
+ s"Monitor(${lock.getClassName}@${lock.getIdentityHashCode}})"
+ case _ =>
+ s"Lock(${lock.getClassName}@${lock.getIdentityHashCode}})"
+ }
+ }
+ }
+
/** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */
def getThreadDump(): Array[ThreadStackTrace] = {
// We need to filter out null values here because dumpAllThreads() may return null array
// elements for threads that are dead / don't exist.
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
- threadInfos.sortBy(_.getThreadId).map { case threadInfo =>
- val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n")
- ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName,
- threadInfo.getThreadState, stackTrace)
+ threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace)
+ }
+
+ def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = {
+ if (threadId <= 0) {
+ None
+ } else {
+ // The Int.MaxValue here requests the entire untruncated stack trace of the thread:
+ val threadInfo =
+ Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue))
+ threadInfo.map(threadInfoToThreadStackTrace)
}
}
+ private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = {
+ val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap
+ val stackTrace = threadInfo.getStackTrace.map { frame =>
+ monitors.get(frame) match {
+ case Some(monitor) =>
+ monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}"
+ case None =>
+ frame.toString
+ }
+ }.mkString("\n")
+
+ // use a set to dedup re-entrant locks that are held at multiple places
+ val heldLocks =
+ (threadInfo.getLockedSynchronizers ++ threadInfo.getLockedMonitors).map(_.lockString).toSet
+
+ ThreadStackTrace(
+ threadId = threadInfo.getThreadId,
+ threadName = threadInfo.getThreadName,
+ threadState = threadInfo.getThreadState,
+ stackTrace = stackTrace,
+ blockedByThreadId =
+ if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId),
+ blockedByLock = Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""),
+ holdingLocks = heldLocks.toSeq)
+ }
+
/**
* Convert all spark properties set in the given SparkConf to a sequence of java options.
*/
@@ -1897,6 +2193,14 @@ private[spark] object Utils extends Logging {
}
}
+ /**
+ * Returns the user port to try when trying to bind a service. Handles wrapping and skipping
+ * privileged ports.
+ */
+ def userPort(base: Int, offset: Int): Int = {
+ (base + offset - 1024) % (65536 - 1024) + 1024
+ }
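Note: worked examples of the wrap-around, plugging numbers into the formula above:

    userPort(8080, 3)    // (8080 + 3 - 1024) % 64512 + 1024 == 8083, the usual "next port" case
    userPort(65535, 1)   // wraps past 65535 and skips ports 0-1023, yielding 1024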
+
/**
* Attempt to start a service on the given port, or fail after a number of attempts.
* Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0).
@@ -1924,8 +2228,7 @@ private[spark] object Utils extends Logging {
val tryPort = if (startPort == 0) {
startPort
} else {
- // If the new port wraps around, do not try a privilege port
- ((startPort + offset - 1024) % (65536 - 1024)) + 1024
+ userPort(startPort, offset)
}
try {
val (service, port) = startService(tryPort)
@@ -1934,15 +2237,32 @@ private[spark] object Utils extends Logging {
} catch {
case e: Exception if isBindCollision(e) =>
if (offset >= maxRetries) {
- val exceptionMessage =
- s"${e.getMessage}: Service$serviceString failed after $maxRetries retries!"
+ val exceptionMessage = if (startPort == 0) {
+ s"${e.getMessage}: Service$serviceString failed after " +
+ s"$maxRetries retries (on a random free port)! " +
+ s"Consider explicitly setting the appropriate binding address for " +
+ s"the service$serviceString (for example spark.driver.bindAddress " +
+ s"for SparkDriver) to the correct binding address."
+ } else {
+ s"${e.getMessage}: Service$serviceString failed after " +
+ s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " +
+ s"the appropriate port for the service$serviceString (for example spark.ui.port " +
+ s"for SparkUI) to an available port or increasing spark.port.maxRetries."
+ }
val exception = new BindException(exceptionMessage)
// restore original stack trace
exception.setStackTrace(e.getStackTrace)
throw exception
}
- logWarning(s"Service$serviceString could not bind on port $tryPort. " +
- s"Attempting port ${tryPort + 1}.")
+ if (startPort == 0) {
+ // startPort 0 requests a random free port, so a bind failure here most likely means
+ // the binding address is not correct.
+ logWarning(s"Service$serviceString could not bind on a random free port. " +
+ "You may want to check whether an appropriate binding address is configured.")
+ } else {
+ logWarning(s"Service$serviceString could not bind on port $tryPort. " +
+ s"Attempting port ${tryPort + 1}.")
+ }
}
}
// Should never happen
@@ -1961,6 +2281,9 @@ private[spark] object Utils extends Logging {
isBindCollision(e.getCause)
case e: MultiException =>
e.getThrowables.asScala.exists(isBindCollision)
+ case e: NativeIoException =>
+ (e.getMessage != null && e.getMessage.startsWith("bind() failed: ")) ||
+ isBindCollision(e.getCause)
case e: Exception => isBindCollision(e.getCause)
case _ => false
}
@@ -2071,8 +2394,9 @@ private[spark] object Utils extends Logging {
* A spark url (`spark://host:port`) is a special URI that its scheme is `spark` and only contains
* host and port.
*
- * @throws SparkException if `sparkUrl` is invalid.
+ * @throws org.apache.spark.SparkException if sparkUrl is invalid.
*/
+ @throws(classOf[SparkException])
def extractHostPortFromSparkUrl(sparkUrl: String): (String, Int) = {
try {
val uri = new java.net.URI(sparkUrl)
@@ -2103,6 +2427,25 @@ private[spark] object Utils extends Logging {
.getOrElse(UserGroupInformation.getCurrentUser().getShortUserName())
}
+ val EMPTY_USER_GROUPS = Set[String]()
+
+ // Returns the groups to which the current user belongs.
+ def getCurrentUserGroups(sparkConf: SparkConf, username: String): Set[String] = {
+ val groupProviderClassName = sparkConf.get("spark.user.groups.mapping",
+ "org.apache.spark.security.ShellBasedGroupsMappingProvider")
+ if (groupProviderClassName != "") {
+ try {
+ val groupMappingServiceProvider = classForName(groupProviderClassName).newInstance.
+ asInstanceOf[org.apache.spark.security.GroupMappingServiceProvider]
+ val currentUserGroups = groupMappingServiceProvider.getGroups(username)
+ return currentUserGroups
+ } catch {
+ case e: Exception => logError(s"Error getting groups for user=$username", e)
+ }
+ }
+ EMPTY_USER_GROUPS
+ }
+
/**
* Split the comma delimited string of master URLs into a list.
* For instance, "spark://abc,def" becomes [spark://abc, spark://def].
@@ -2140,6 +2483,7 @@ private[spark] object Utils extends Logging {
/**
* Return whether the specified file is a parent directory of the child file.
*/
+ @tailrec
def isInDirectory(parent: File, child: File): Boolean = {
if (child == null || parent == null) {
return false
@@ -2153,21 +2497,282 @@ private[spark] object Utils extends Logging {
isInDirectory(parent, child.getParentFile)
}
+
+ /**
+ * @return whether the master URL in the given conf indicates local mode
+ */
+ def isLocalMaster(conf: SparkConf): Boolean = {
+ val master = conf.get("spark.master", "")
+ master == "local" || master.startsWith("local[")
+ }
+
/**
- * Return whether dynamic allocation is enabled in the given conf
- * Dynamic allocation and explicitly setting the number of executors are inherently
- * incompatible. In environments where dynamic allocation is turned on by default,
- * the latter should override the former (SPARK-9092).
+ * Return whether dynamic allocation is enabled in the given conf.
*/
def isDynamicAllocationEnabled(conf: SparkConf): Boolean = {
- conf.getBoolean("spark.dynamicAllocation.enabled", false) &&
- conf.getInt("spark.executor.instances", 0) == 0
+ val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false)
+ dynamicAllocationEnabled &&
+ (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false))
+ }
+
+ /**
+ * Return the initial number of executors for dynamic allocation.
+ */
+ def getDynamicAllocationInitialExecutors(conf: SparkConf): Int = {
+ if (conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS) < conf.get(DYN_ALLOCATION_MIN_EXECUTORS)) {
+ logWarning(s"${DYN_ALLOCATION_INITIAL_EXECUTORS.key} less than " +
+ s"${DYN_ALLOCATION_MIN_EXECUTORS.key} is invalid, ignoring its setting, " +
+ "please update your configs.")
+ }
+
+ if (conf.get(EXECUTOR_INSTANCES).getOrElse(0) < conf.get(DYN_ALLOCATION_MIN_EXECUTORS)) {
+ logWarning(s"${EXECUTOR_INSTANCES.key} less than " +
+ s"${DYN_ALLOCATION_MIN_EXECUTORS.key} is invalid, ignoring its setting, " +
+ "please update your configs.")
+ }
+
+ val initialExecutors = Seq(
+ conf.get(DYN_ALLOCATION_MIN_EXECUTORS),
+ conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS),
+ conf.get(EXECUTOR_INSTANCES).getOrElse(0)).max
+
+ logInfo(s"Using initial executors = $initialExecutors, max of " +
+ s"${DYN_ALLOCATION_INITIAL_EXECUTORS.key}, ${DYN_ALLOCATION_MIN_EXECUTORS.key} and " +
+ s"${EXECUTOR_INSTANCES.key}")
+ initialExecutors
}
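Note: for example, with the following (made-up) settings the initial count is the maximum of the three values, and the out-of-range initialExecutors only triggers the warning above. The keys are written out literally here; the patch reads them through the DYN_ALLOCATION_* config entries.

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.dynamicAllocation.minExecutors", "2")
      .set("spark.dynamicAllocation.initialExecutors", "1")  // below the minimum: warned about
      .set("spark.executor.instances", "5")

    // getDynamicAllocationInitialExecutors(conf) == Seq(2, 1, 5).max == 5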
def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = {
val resource = createResource
try f.apply(resource) finally resource.close()
}
+
+ /**
+ * Returns the path of a temporary file in the same directory as `path`.
+ */
+ def tempFileWith(path: File): File = {
+ new File(path.getAbsolutePath + "." + UUID.randomUUID())
+ }
+
+ /**
+ * Returns the name of this JVM process. This is OS dependent, but on typical platforms
+ * (macOS, Linux, Windows) it is formatted as PID@hostname.
+ */
+ def getProcessName(): String = {
+ ManagementFactory.getRuntimeMXBean().getName()
+ }
+
+ /**
+ * Utility function that should be called early in `main()` for daemons to set up some common
+ * diagnostic state.
+ */
+ def initDaemon(log: Logger): Unit = {
+ log.info(s"Started daemon with process name: ${Utils.getProcessName()}")
+ SignalUtils.registerLogger(log)
+ }
+
+ /**
+ * Unions two comma-separated lists of files and filters out empty strings.
+ */
+ def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = {
+ var allFiles = Set[String]()
+ leftList.foreach { value => allFiles ++= value.split(",") }
+ rightList.foreach { value => allFiles ++= value.split(",") }
+ allFiles.filter { _.nonEmpty }
+ }
+
+ /**
+ * Return the jar files pointed to by the "spark.jars" property. Spark will internally
+ * distribute these jars through the file server. In YARN mode, it returns an empty list,
+ * since YARN has its own mechanism to distribute jars.
+ */
+ def getUserJars(conf: SparkConf): Seq[String] = {
+ val sparkJars = conf.getOption("spark.jars")
+ sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten
+ }
+
+ /**
+ * Return the local jar files which will be added to the REPL's classpath. These jar files are
+ * specified by --jars (spark.jars) or --packages; remote jars are first downloaded to the
+ * local file system by SparkSubmit.
+ */
+ def getLocalUserJarsForShell(conf: SparkConf): Seq[String] = {
+ val localJars = conf.getOption("spark.repl.local.jars")
+ localJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten
+ }
+
+ private[spark] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)"
+
+ /**
+ * Redact the sensitive values in the given map. If a map key matches the redaction pattern then
+ * its value is replaced with placeholder text.
+ */
+ def redact(conf: SparkConf, kvs: Seq[(String, String)]): Seq[(String, String)] = {
+ val redactionPattern = conf.get(SECRET_REDACTION_PATTERN)
+ redact(redactionPattern, kvs)
+ }
+
+ /**
+ * Redact the sensitive information in the given string.
+ */
+ def redact(conf: SparkConf, text: String): String = {
+ if (text == null || text.isEmpty || conf == null || !conf.contains(STRING_REDACTION_PATTERN)) {
+ text
+ } else {
+ val regex = conf.get(STRING_REDACTION_PATTERN).get
+ regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT)
+ }
+ }
+
+ /**
+ * Redact the sensitive values in the given map. If a map key matches the redaction pattern then
+ * its value is replaced with placeholder text.
+ */
+ def redact(regex: Option[Regex], kvs: Seq[(String, String)]): Seq[(String, String)] = {
+ regex match {
+ case None => kvs
+ case Some(r) => redact(r, kvs)
+ }
+ }
+
+ private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = {
+ // If the sensitive information regex matches with either the key or the value, redact the value
+ // While the original intent was to only redact the value if the key matched with the regex,
+ // we've found that especially in verbose mode, the value of the property may contain sensitive
+ // information like so:
+ // "sun.java.command":"org.apache.spark.deploy.SparkSubmit ... \
+ // --conf spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password ...
+ //
+ // And, in such cases, simply searching for the sensitive information regex in the key name is
+ // not sufficient. The values themselves have to be searched as well and redacted if matched.
+ // This does mean we may redact more false positives: for example, if the value of an
+ // arbitrary property contains the term 'password', we may redact that value from the UI and
+ // logs. To work around this, the user would have to make the spark.redaction.regex property
+ // more specific.
+ kvs.map { case (key, value) =>
+ redactionPattern.findFirstIn(key)
+ .orElse(redactionPattern.findFirstIn(value))
+ .map { _ => (key, REDACTION_REPLACEMENT_TEXT) }
+ .getOrElse((key, value))
+ }
+ }
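Note: an illustrative run of the key-or-value matching (the default spark.redaction.regex is assumed to be "(?i)secret|password"; keys and values here are made up):

    val pattern = "(?i)secret|password".r
    val props = Seq(
      "spark.app.name" -> "my-app",
      "spark.executorEnv.DB_PASSWORD" -> "hunter2",                     // key matches
      "sun.java.command" -> "SparkSubmit --conf spark.foo.password=x")  // value matches

    redact(Some(pattern), props)
    // Seq(("spark.app.name", "my-app"),
    //     ("spark.executorEnv.DB_PASSWORD", "*********(redacted)"),
    //     ("sun.java.command", "*********(redacted)"))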
+
+ /**
+ * Looks up the redaction regex from within the key value pairs and uses it to redact the rest
+ * of the key value pairs. No care is taken to make sure the redaction property itself is not
+ * redacted. So theoretically, the property itself could be configured to redact its own value
+ * when printing.
+ */
+ def redact(kvs: Map[String, String]): Seq[(String, String)] = {
+ val redactionPattern = kvs.getOrElse(
+ SECRET_REDACTION_PATTERN.key,
+ SECRET_REDACTION_PATTERN.defaultValueString
+ ).r
+ redact(redactionPattern, kvs.toArray)
+ }
+
+ def createSecret(conf: SparkConf): String = {
+ val bits = conf.get(AUTH_SECRET_BIT_LENGTH)
+ val rnd = new SecureRandom()
+ val secretBytes = new Array[Byte](bits / JByte.SIZE)
+ rnd.nextBytes(secretBytes)
+ HashCodes.fromBytes(secretBytes).toString()
+ }
+
+}
+
+private[util] object CallerContext extends Logging {
+ val callerContextSupported: Boolean = {
+ SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
+ try {
+ Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+ Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
+ true
+ } catch {
+ case _: ClassNotFoundException =>
+ false
+ case NonFatal(e) =>
+ logWarning("Fail to load the CallerContext class", e)
+ false
+ }
+ }
+ }
+}
+
+/**
+ * A utility class used to set up Spark caller contexts for HDFS and YARN. The `context` is
+ * constructed from the parameters passed in.
+ * When Spark applications run on YARN and HDFS, their caller contexts are written into the YARN
+ * RM audit log and hdfs-audit.log. This can help users better diagnose and understand how
+ * specific applications impact parts of the Hadoop system and what problems they may be
+ * creating (e.g. overloading the NameNode). As noted in HDFS-9184, for a given HDFS operation it
+ * is very helpful to track which upper-level job issued it.
+ *
+ * @param from who sets up the caller context (TASK, CLIENT, APPMASTER)
+ *
+ * The parameters below are optional:
+ * @param upstreamCallerContext caller context the upstream application passes in
+ * @param appId id of the app this task belongs to
+ * @param appAttemptId attempt id of the app this task belongs to
+ * @param jobId id of the job this task belongs to
+ * @param stageId id of the stage this task belongs to
+ * @param stageAttemptId attempt id of the stage this task belongs to
+ * @param taskId task id
+ * @param taskAttemptNumber task attempt number
+ */
+private[spark] class CallerContext(
+ from: String,
+ upstreamCallerContext: Option[String] = None,
+ appId: Option[String] = None,
+ appAttemptId: Option[String] = None,
+ jobId: Option[Int] = None,
+ stageId: Option[Int] = None,
+ stageAttemptId: Option[Int] = None,
+ taskId: Option[Long] = None,
+ taskAttemptNumber: Option[Int] = None) extends Logging {
+
+ private val context = prepareContext("SPARK_" +
+ from +
+ appId.map("_" + _).getOrElse("") +
+ appAttemptId.map("_" + _).getOrElse("") +
+ jobId.map("_JId_" + _).getOrElse("") +
+ stageId.map("_SId_" + _).getOrElse("") +
+ stageAttemptId.map("_" + _).getOrElse("") +
+ taskId.map("_TId_" + _).getOrElse("") +
+ taskAttemptNumber.map("_" + _).getOrElse("") +
+ upstreamCallerContext.map("_" + _).getOrElse(""))
+
+ private def prepareContext(context: String): String = {
+ // The default max size of Hadoop caller context is 128
+ lazy val len = SparkHadoopUtil.get.conf.getInt("hadoop.caller.context.max.size", 128)
+ if (context == null || context.length <= len) {
+ context
+ } else {
+ val finalContext = context.substring(0, len)
+ logWarning(s"Truncated Spark caller context from $context to $finalContext")
+ finalContext
+ }
+ }
+
+ /**
+ * Set up the caller context [[context]] by invoking the Hadoop CallerContext API
+ * ([[org.apache.hadoop.ipc.CallerContext]]), which was added in Hadoop 2.8.
+ */
+ def setCurrentContext(): Unit = {
+ if (CallerContext.callerContextSupported) {
+ try {
+ val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+ val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
+ val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
+ val hdfsContext = builder.getMethod("build").invoke(builderInst)
+ callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Fail to set Spark caller context", e)
+ }
+ }
+ }
}
/**
@@ -2207,29 +2812,24 @@ private[spark] class RedirectThread(
* the toString method.
*/
private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream {
- var pos: Int = 0
- var buffer = new Array[Int](sizeInBytes)
+ private var pos: Int = 0
+ private var isBufferFull = false
+ private val buffer = new Array[Byte](sizeInBytes)
- def write(i: Int): Unit = {
- buffer(pos) = i
+ def write(input: Int): Unit = {
+ buffer(pos) = input.toByte
pos = (pos + 1) % buffer.length
+ isBufferFull = isBufferFull || (pos == 0)
}
override def toString: String = {
- val (end, start) = buffer.splitAt(pos)
- val input = new java.io.InputStream {
- val iterator = (start ++ end).iterator
-
- def read(): Int = if (iterator.hasNext) iterator.next() else -1
+ if (!isBufferFull) {
+ return new String(buffer, 0, pos, StandardCharsets.UTF_8)
}
- val reader = new BufferedReader(new InputStreamReader(input))
- val stringBuilder = new StringBuilder
- var line = reader.readLine()
- while (line != null) {
- stringBuilder.append(line)
- stringBuilder.append("\n")
- line = reader.readLine()
- }
- stringBuilder.toString()
+
+ val nonCircularBuffer = new Array[Byte](sizeInBytes)
+ System.arraycopy(buffer, pos, nonCircularBuffer, 0, buffer.length - pos)
+ System.arraycopy(buffer, 0, nonCircularBuffer, buffer.length - pos, pos)
+ new String(nonCircularBuffer, StandardCharsets.UTF_8)
}
}
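Note: a usage sketch for the rewritten buffer: once more than sizeInBytes bytes have been written, toString returns only the most recent sizeInBytes bytes, oldest first.

    import java.nio.charset.StandardCharsets

    val buf = new CircularBuffer(sizeInBytes = 4)
    buf.write("abcdef".getBytes(StandardCharsets.UTF_8))   // OutputStream.write(Array[Byte])
    buf.toString                                           // "cdef"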
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
deleted file mode 100644
index 2ed827eab46d..000000000000
--- a/core/src/main/scala/org/apache/spark/util/Vector.scala
+++ /dev/null
@@ -1,158 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import scala.language.implicitConversions
-import scala.util.Random
-
-import org.apache.spark.util.random.XORShiftRandom
-
-@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0")
-class Vector(val elements: Array[Double]) extends Serializable {
- def length: Int = elements.length
-
- def apply(index: Int): Double = elements(index)
-
- def + (other: Vector): Vector = {
- if (length != other.length) {
- throw new IllegalArgumentException("Vectors of different length")
- }
- Vector(length, i => this(i) + other(i))
- }
-
- def add(other: Vector): Vector = this + other
-
- def - (other: Vector): Vector = {
- if (length != other.length) {
- throw new IllegalArgumentException("Vectors of different length")
- }
- Vector(length, i => this(i) - other(i))
- }
-
- def subtract(other: Vector): Vector = this - other
-
- def dot(other: Vector): Double = {
- if (length != other.length) {
- throw new IllegalArgumentException("Vectors of different length")
- }
- var ans = 0.0
- var i = 0
- while (i < length) {
- ans += this(i) * other(i)
- i += 1
- }
- ans
- }
-
- /**
- * return (this + plus) dot other, but without creating any intermediate storage
- * @param plus
- * @param other
- * @return
- */
- def plusDot(plus: Vector, other: Vector): Double = {
- if (length != other.length) {
- throw new IllegalArgumentException("Vectors of different length")
- }
- if (length != plus.length) {
- throw new IllegalArgumentException("Vectors of different length")
- }
- var ans = 0.0
- var i = 0
- while (i < length) {
- ans += (this(i) + plus(i)) * other(i)
- i += 1
- }
- ans
- }
-
- def += (other: Vector): Vector = {
- if (length != other.length) {
- throw new IllegalArgumentException("Vectors of different length")
- }
- var i = 0
- while (i < length) {
- elements(i) += other(i)
- i += 1
- }
- this
- }
-
- def addInPlace(other: Vector): Vector = this +=other
-
- def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
-
- def multiply (d: Double): Vector = this * d
-
- def / (d: Double): Vector = this * (1 / d)
-
- def divide (d: Double): Vector = this / d
-
- def unary_- : Vector = this * -1
-
- def sum: Double = elements.reduceLeft(_ + _)
-
- def squaredDist(other: Vector): Double = {
- var ans = 0.0
- var i = 0
- while (i < length) {
- ans += (this(i) - other(i)) * (this(i) - other(i))
- i += 1
- }
- ans
- }
-
- def dist(other: Vector): Double = math.sqrt(squaredDist(other))
-
- override def toString: String = elements.mkString("(", ", ", ")")
-}
-
-object Vector {
- def apply(elements: Array[Double]): Vector = new Vector(elements)
-
- def apply(elements: Double*): Vector = new Vector(elements.toArray)
-
- def apply(length: Int, initializer: Int => Double): Vector = {
- val elements: Array[Double] = Array.tabulate(length)(initializer)
- new Vector(elements)
- }
-
- def zeros(length: Int): Vector = new Vector(new Array[Double](length))
-
- def ones(length: Int): Vector = Vector(length, _ => 1)
-
- /**
- * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers
- * between 0.0 and 1.0. Optional scala.util.Random number generator can be provided.
- */
- def random(length: Int, random: Random = new XORShiftRandom()): Vector =
- Vector(length, _ => random.nextDouble())
-
- class Multiplier(num: Double) {
- def * (vec: Vector): Vector = vec * num
- }
-
- implicit def doubleToMultiplier(num: Double): Multiplier = new Multiplier(num)
-
- implicit object VectorAccumParam extends org.apache.spark.AccumulatorParam[Vector] {
- def addInPlace(t1: Vector, t2: Vector): Vector = t1 + t2
-
- def zero(initialValue: Vector): Vector = Vector.zeros(initialValue.length)
- }
-
-}
diff --git a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala
new file mode 100644
index 000000000000..828153b86842
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * Utilities for working with Spark version strings
+ */
+private[spark] object VersionUtils {
+
+ private val majorMinorRegex = """^(\d+)\.(\d+)(\..*)?$""".r
+
+ /**
+ * Given a Spark version string, return the major version number.
+ * E.g., for 2.0.1-SNAPSHOT, return 2.
+ */
+ def majorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._1
+
+ /**
+ * Given a Spark version string, return the minor version number.
+ * E.g., for 2.0.1-SNAPSHOT, return 0.
+ */
+ def minorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._2
+
+ /**
+ * Given a Spark version string, return the (major version number, minor version number).
+ * E.g., for 2.0.1-SNAPSHOT, return (2, 0).
+ */
+ def majorMinorVersion(sparkVersion: String): (Int, Int) = {
+ majorMinorRegex.findFirstMatchIn(sparkVersion) match {
+ case Some(m) =>
+ (m.group(1).toInt, m.group(2).toInt)
+ case None =>
+ throw new IllegalArgumentException(s"Spark tried to parse '$sparkVersion' as a Spark" +
+ s" version string, but it could not find the major and minor version numbers.")
+ }
+ }
+}
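Note: a usage sketch for the new helper:

    import org.apache.spark.util.VersionUtils

    VersionUtils.majorMinorVersion("2.0.1-SNAPSHOT")   // (2, 0)
    VersionUtils.majorVersion("1.6.3")                 // 1
    VersionUtils.majorMinorVersion("not-a-version")    // throws IllegalArgumentException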
diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
index 4c1e16155462..bcb95b416dd2 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -17,7 +17,7 @@
package org.apache.spark.util.collection
-import java.util.{Arrays, Comparator}
+import java.util.Comparator
import com.google.common.hash.Hashing
@@ -140,16 +140,16 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64)
var i = 1
while (true) {
val curKey = data(2 * pos)
- if (k.eq(curKey) || k.equals(curKey)) {
- val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
- data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
- return newValue
- } else if (curKey.eq(null)) {
+ if (curKey.eq(null)) {
val newValue = updateFunc(false, null.asInstanceOf[V])
data(2 * pos) = k
data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
incrementSize()
return newValue
+ } else if (k.eq(curKey) || k.equals(curKey)) {
+ val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
+ data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+ return newValue
} else {
val delta = i
pos = (pos + delta) & mask
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
index 85c5bdbfcebc..e63e0e3e1f68 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -17,21 +17,16 @@
package org.apache.spark.util.collection
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
-
-import org.apache.spark.util.{Utils => UUtils}
-
+import java.util.Arrays
/**
* A simple, fixed-size bit set implementation. This implementation is fast because it avoids
* safety/bound checking.
*/
-class BitSet(private[this] var numBits: Int) extends Externalizable {
-
- private var words = new Array[Long](bit2words(numBits))
- private def numWords = words.length
+class BitSet(numBits: Int) extends Serializable {
- def this() = this(0)
+ private val words = new Array[Long](bit2words(numBits))
+ private val numWords = words.length
/**
* Compute the capacity (number of bits) that can be represented
@@ -42,21 +37,14 @@ class BitSet(private[this] var numBits: Int) extends Externalizable {
/**
* Clear all set bits.
*/
- def clear(): Unit = {
- var i = 0
- while (i < numWords) {
- words(i) = 0L
- i += 1
- }
- }
+ def clear(): Unit = Arrays.fill(words, 0)
/**
* Set all the bits up to a given index
*/
- def setUntil(bitIndex: Int) {
+ def setUntil(bitIndex: Int): Unit = {
val wordIndex = bitIndex >> 6 // divide by 64
- var i = 0
- while(i < wordIndex) { words(i) = -1; i += 1 }
+ Arrays.fill(words, 0, wordIndex, -1)
if(wordIndex < words.length) {
// Set the remaining bits (note that the mask could still be zero)
val mask = ~(-1L << (bitIndex & 0x3f))
@@ -64,6 +52,19 @@ class BitSet(private[this] var numBits: Int) extends Externalizable {
}
}
+ /**
+ * Clear all the bits up to a given index
+ */
+ def clearUntil(bitIndex: Int): Unit = {
+ val wordIndex = bitIndex >> 6 // divide by 64
+ Arrays.fill(words, 0, wordIndex, 0)
+ if(wordIndex < words.length) {
+ // Clear the remaining bits
+ val mask = -1L << (bitIndex & 0x3f)
+ words(wordIndex) &= mask
+ }
+ }
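Note: a worked example of the mask arithmetic above (bitIndex = 70, purely illustrative):

    val bitIndex = 70
    val wordIndex = bitIndex >> 6         // 1: word 0 (bits 0-63) is cleared entirely by Arrays.fill
    val mask = -1L << (bitIndex & 0x3f)   // -1L << 6: keeps bits 6-63 of word 1, clearing absolute bits 64-69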
+
/**
* Compute the bit-wise AND of the two sets returning the
* result.
@@ -237,19 +238,4 @@ class BitSet(private[this] var numBits: Int) extends Externalizable {
/** Return the number of longs it would take to hold numBits. */
private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
-
- override def writeExternal(out: ObjectOutput): Unit = UUtils.tryOrIOException {
- out.writeInt(numBits)
- words.foreach(out.writeLong(_))
- }
-
- override def readExternal(in: ObjectInput): Unit = UUtils.tryOrIOException {
- numBits = in.readInt()
- words = new Array[Long](bit2words(numBits))
- var index = 0
- while (index < words.length) {
- words(index) = in.readLong()
- index += 1
- }
- }
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index f6d81ee5bf05..8aafda5e45d5 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -26,14 +26,15 @@ import scala.collection.mutable.ArrayBuffer
import com.google.common.io.ByteStreams
-import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
-import org.apache.spark.serializer.{DeserializationStream, Serializer}
+import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockManager}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
-import org.apache.spark.executor.ShuffleWriteMetrics
/**
* :: DeveloperApi ::
@@ -58,11 +59,12 @@ class ExternalAppendOnlyMap[K, V, C](
mergeCombiners: (C, C) => C,
serializer: Serializer = SparkEnv.get.serializer,
blockManager: BlockManager = SparkEnv.get.blockManager,
- context: TaskContext = TaskContext.get())
- extends Iterable[(K, C)]
+ context: TaskContext = TaskContext.get(),
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager)
+ extends Spillable[SizeTracker](context.taskMemoryManager())
with Serializable
with Logging
- with Spillable[SizeTracker] {
+ with Iterable[(K, C)] {
if (context == null) {
throw new IllegalStateException(
@@ -79,9 +81,7 @@ class ExternalAppendOnlyMap[K, V, C](
this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get())
}
- override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()
-
- private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
+ @volatile private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
private val sparkConf = SparkEnv.get.conf
private val diskBlockManager = blockManager.diskBlockManager
@@ -105,8 +105,8 @@ class ExternalAppendOnlyMap[K, V, C](
private val fileBufferSize =
sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
- // Write metrics for current spill
- private var curWriteMetrics: ShuffleWriteMetrics = _
+ // Write metrics
+ private val writeMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics()
// Peak size of the in-memory map observed so far, in bytes
private var _peakMemoryUsedBytes: Long = 0L
@@ -115,6 +115,8 @@ class ExternalAppendOnlyMap[K, V, C](
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()
+ @volatile private var readingIterator: SpillableIterator = null
+
/**
* Number of files this map has spilled so far.
* Exposed for testing.
@@ -180,9 +182,38 @@ class ExternalAppendOnlyMap[K, V, C](
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
override protected[this] def spill(collection: SizeTracker): Unit = {
+ val inMemoryIterator = currentMap.destructiveSortedIterator(keyComparator)
+ val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator)
+ spilledMaps += diskMapIterator
+ }
+
+ /**
+ * Force the current in-memory collection to spill to disk to release memory.
+ * It will be called by TaskMemoryManager when there is not enough memory for the task.
+ */
+ override protected[this] def forceSpill(): Boolean = {
+ if (readingIterator != null) {
+ val isSpilled = readingIterator.spill()
+ if (isSpilled) {
+ currentMap = null
+ }
+ isSpilled
+ } else if (currentMap.size > 0) {
+ spill(currentMap)
+ currentMap = new SizeTrackingAppendOnlyMap[K, C]
+ true
+ } else {
+ false
+ }
+ }
+
+ /**
+ * Spill the in-memory Iterator to a temporary file on disk.
+ */
+ private[this] def spillMemoryIteratorToDisk(inMemoryIterator: Iterator[(K, C)])
+ : DiskMapIterator = {
val (blockId, file) = diskBlockManager.createTempLocalBlock()
- curWriteMetrics = new ShuffleWriteMetrics()
- var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
+ val writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics)
var objectsWritten = 0
// List of batch sizes (bytes) in the order they are written to disk
@@ -190,43 +221,35 @@ class ExternalAppendOnlyMap[K, V, C](
// Flush the disk writer's contents to disk, and update relevant variables
def flush(): Unit = {
- val w = writer
- writer = null
- w.commitAndClose()
- _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
- batchSizes.append(curWriteMetrics.shuffleBytesWritten)
+ val segment = writer.commitAndGet()
+ batchSizes += segment.length
+ _diskBytesSpilled += segment.length
objectsWritten = 0
}
var success = false
try {
- val it = currentMap.destructiveSortedIterator(keyComparator)
- while (it.hasNext) {
- val kv = it.next()
+ while (inMemoryIterator.hasNext) {
+ val kv = inMemoryIterator.next()
writer.write(kv._1, kv._2)
objectsWritten += 1
if (objectsWritten == serializerBatchSize) {
flush()
- curWriteMetrics = new ShuffleWriteMetrics()
- writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
}
}
if (objectsWritten > 0) {
flush()
- } else if (writer != null) {
- val w = writer
- writer = null
- w.revertPartialWritesAndClose()
+ writer.close()
+ } else {
+ writer.revertPartialWritesAndClose()
}
success = true
} finally {
if (!success) {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
- if (writer != null) {
- writer.revertPartialWritesAndClose()
- }
+ writer.revertPartialWritesAndClose()
if (file.exists()) {
if (!file.delete()) {
logWarning(s"Error deleting ${file}")
@@ -235,7 +258,17 @@ class ExternalAppendOnlyMap[K, V, C](
}
}
- spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
+ new DiskMapIterator(file, blockId, batchSizes)
+ }
+
+ /**
+ * Returns a destructive iterator for iterating over the entries of this map.
+ * If this iterator is forced to spill to disk to release memory when there is not enough memory,
+ * it returns pairs from an on-disk map.
+ */
+ def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = {
+ readingIterator = new SpillableIterator(inMemoryIterator)
+ readingIterator
}
/**
@@ -248,15 +281,18 @@ class ExternalAppendOnlyMap[K, V, C](
"ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
}
if (spilledMaps.isEmpty) {
- CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())
+ CompletionIterator[(K, C), Iterator[(K, C)]](
+ destructiveIterator(currentMap.iterator), freeCurrentMap())
} else {
new ExternalIterator()
}
}
private def freeCurrentMap(): Unit = {
- currentMap = null // So that the memory can be garbage-collected
- releaseMemory()
+ if (currentMap != null) {
+ currentMap = null // So that the memory can be garbage-collected
+ releaseMemory()
+ }
}
/**
@@ -270,8 +306,8 @@ class ExternalAppendOnlyMap[K, V, C](
// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
- private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
- currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
+ private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
+ currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
inputStreams.foreach { it =>
@@ -338,14 +374,14 @@ class ExternalAppendOnlyMap[K, V, C](
/**
* Return true if there exists an input stream that still has unvisited pairs.
*/
- override def hasNext: Boolean = mergeHeap.length > 0
+ override def hasNext: Boolean = mergeHeap.nonEmpty
/**
* Select a key with the minimum hash, then combine all values with the same key from all
* input streams.
*/
override def next(): (K, C) = {
- if (mergeHeap.length == 0) {
+ if (mergeHeap.isEmpty) {
throw new NoSuchElementException
}
// Select a key from the StreamBuffer that holds the lowest key hash
@@ -360,7 +396,7 @@ class ExternalAppendOnlyMap[K, V, C](
// For all other streams that may have this key (i.e. have the same minimum key hash),
// merge in the corresponding value (if any) from that stream
val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer)
- while (mergeHeap.length > 0 && mergeHeap.head.minKeyHash == minHash) {
+ while (mergeHeap.nonEmpty && mergeHeap.head.minKeyHash == minHash) {
val newBuffer = mergeHeap.dequeue()
minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer)
mergedBuffers += newBuffer
@@ -457,8 +493,8 @@ class ExternalAppendOnlyMap[K, V, C](
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
- val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
- ser.deserializeStream(compressedStream)
+ val wrappedStream = serializerManager.wrapStream(blockId, bufferedStream)
+ ser.deserializeStream(wrappedStream)
} else {
// No more batches left
cleanup()
@@ -530,8 +566,56 @@ class ExternalAppendOnlyMap[K, V, C](
context.addTaskCompletionListener(context => cleanup())
}
+ private[this] class SpillableIterator(var upstream: Iterator[(K, C)])
+ extends Iterator[(K, C)] {
+
+ private val SPILL_LOCK = new Object()
+
+ private var nextUpstream: Iterator[(K, C)] = null
+
+ private var cur: (K, C) = readNext()
+
+ private var hasSpilled: Boolean = false
+
+ def spill(): Boolean = SPILL_LOCK.synchronized {
+ if (hasSpilled) {
+ false
+ } else {
+ logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
+ s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
+ nextUpstream = spillMemoryIteratorToDisk(upstream)
+ hasSpilled = true
+ true
+ }
+ }
+
+ def readNext(): (K, C) = SPILL_LOCK.synchronized {
+ if (nextUpstream != null) {
+ upstream = nextUpstream
+ nextUpstream = null
+ }
+ if (upstream.hasNext) {
+ upstream.next()
+ } else {
+ null
+ }
+ }
+
+ override def hasNext(): Boolean = cur != null
+
+ override def next(): (K, C) = {
+ val r = cur
+ cur = readNext()
+ r
+ }
+ }
+
/** Convenience function to hash the given (K, C) pair by the key. */
private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1)
+
+ override def toString(): String = {
+ this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode())
+ }
}
private[spark] object ExternalAppendOnlyMap {
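
The force-spill path added above centers on SpillableIterator: reads and spills synchronize on a single lock, and a successful spill swaps the in-memory upstream iterator for one backed by the spilled data, so memory pressure cannot corrupt an in-progress read. A minimal, self-contained Scala sketch of that pattern (hypothetical SwappableIterator name, not Spark's API; "spilling" here just re-buffers the remainder in memory):

    import scala.collection.mutable.ArrayBuffer

    class SwappableIterator[T](private var upstream: Iterator[T],
                               spillTo: Iterator[T] => Iterator[T]) extends Iterator[T] {
      private val lock = new Object()
      private var spilled = false

      // Called by the "memory manager" in this sketch; returns true if a spill happened.
      def spill(): Boolean = lock.synchronized {
        if (spilled) {
          false
        } else {
          upstream = spillTo(upstream)  // drain the remaining elements and reread them
          spilled = true
          true
        }
      }

      override def hasNext: Boolean = lock.synchronized(upstream.hasNext)
      override def next(): T = lock.synchronized(upstream.next())
    }

    // "Spilling" simply re-materializes the rest of the iterator through an ArrayBuffer.
    val it = new SwappableIterator[Int](Iterator(1, 2, 3, 4), rest => {
      val buf = new ArrayBuffer[Int](); rest.foreach(buf += _); buf.iterator
    })
    it.next()   // 1, served from the in-memory upstream
    it.spill()  // true; the remaining elements now come from the "spilled" copy
    it.toList   // List(2, 3, 4)
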
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index a44e72b7c16d..176f84fa2a0d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -20,16 +20,15 @@ package org.apache.spark.util.collection
import java.io._
import java.util.Comparator
-import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
-import com.google.common.annotations.VisibleForTesting
import com.google.common.io.ByteStreams
import org.apache.spark._
-import org.apache.spark.memory.TaskMemoryManager
-import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.internal.Logging
+import org.apache.spark.serializer._
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
/**
@@ -68,35 +67,33 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
*
* At a high level, this class works internally as follows:
*
- * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if
- * we want to combine by key, or a PartitionedPairBuffer if we don't.
- * Inside these buffers, we sort elements by partition ID and then possibly also by key.
- * To avoid calling the partitioner multiple times with each key, we store the partition ID
- * alongside each record.
+ * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if
+ * we want to combine by key, or a PartitionedPairBuffer if we don't.
+ * Inside these buffers, we sort elements by partition ID and then possibly also by key.
+ * To avoid calling the partitioner multiple times with each key, we store the partition ID
+ * alongside each record.
*
- * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
- * by partition ID and possibly second by key or by hash code of the key, if we want to do
- * aggregation. For each file, we track how many objects were in each partition in memory, so we
- * don't have to write out the partition ID for every element.
+ * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
+ * by partition ID and possibly second by key or by hash code of the key, if we want to do
+ * aggregation. For each file, we track how many objects were in each partition in memory, so we
+ * don't have to write out the partition ID for every element.
*
- * - When the user requests an iterator or file output, the spilled files are merged, along with
- * any remaining in-memory data, using the same sort order defined above (unless both sorting
- * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering
- * from the ordering parameter, or read the keys with the same hash code and compare them with
- * each other for equality to merge values.
+ * - When the user requests an iterator or file output, the spilled files are merged, along with
+ * any remaining in-memory data, using the same sort order defined above (unless both sorting
+ * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering
+ * from the ordering parameter, or read the keys with the same hash code and compare them with
+ * each other for equality to merge values.
*
- * - Users are expected to call stop() at the end to delete all the intermediate files.
+ * - Users are expected to call stop() at the end to delete all the intermediate files.
*/
private[spark] class ExternalSorter[K, V, C](
context: TaskContext,
aggregator: Option[Aggregator[K, V, C]] = None,
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
- serializer: Option[Serializer] = None)
- extends Logging
- with Spillable[WritablePartitionedPairCollection[K, C]] {
-
- override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()
+ serializer: Serializer = SparkEnv.get.serializer)
+ extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager())
+ with Logging {
private val conf = SparkEnv.get.conf
@@ -108,8 +105,8 @@ private[spark] class ExternalSorter[K, V, C](
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
- private val ser = Serializer.getSerializer(serializer)
- private val serInstance = ser.newInstance()
+ private val serializerManager = SparkEnv.get.serializerManager
+ private val serInstance = serializer.newInstance()
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
@@ -126,8 +123,8 @@ private[spark] class ExternalSorter[K, V, C](
// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
- private var map = new PartitionedAppendOnlyMap[K, C]
- private var buffer = new PartitionedPairBuffer[K, C]
+ @volatile private var map = new PartitionedAppendOnlyMap[K, C]
+ @volatile private var buffer = new PartitionedPairBuffer[K, C]
// Total spilling statistics
private var _diskBytesSpilled = 0L
@@ -137,6 +134,10 @@ private[spark] class ExternalSorter[K, V, C](
private var _peakMemoryUsedBytes: Long = 0L
def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
+ @volatile private var isShuffleSort: Boolean = true
+ private val forceSpillFiles = new ArrayBuffer[SpilledFile]
+ @volatile private var readingIterator: SpillableIterator = null
+
// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided through by the
// user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
@@ -235,6 +236,34 @@ private[spark] class ExternalSorter[K, V, C](
* @param collection whichever collection we're using (map or buffer)
*/
override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
+ val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
+ val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
+ spills += spillFile
+ }
+
+ /**
+ * Force spilling the current in-memory collection to disk to release memory.
+ * It will be called by TaskMemoryManager when there is not enough memory for the task.
+ */
+ override protected[this] def forceSpill(): Boolean = {
+ if (isShuffleSort) {
+ false
+ } else {
+ assert(readingIterator != null)
+ val isSpilled = readingIterator.spill()
+ if (isSpilled) {
+ map = null
+ buffer = null
+ }
+ isSpilled
+ }
+ }
+
+ /**
+ * Spill the contents of the in-memory iterator to a temporary file on disk.
+ */
+ private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
+ : SpilledFile = {
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more context.
@@ -242,14 +271,9 @@ private[spark] class ExternalSorter[K, V, C](
// These variables are reset after each flush
var objectsWritten: Long = 0
- var spillMetrics: ShuffleWriteMetrics = null
- var writer: DiskBlockObjectWriter = null
- def openWriter(): Unit = {
- assert (writer == null && spillMetrics == null)
- spillMetrics = new ShuffleWriteMetrics
- writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
- }
- openWriter()
+ val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
+ val writer: DiskBlockObjectWriter =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
// List of batch sizes (bytes) in the order they are written to disk
val batchSizes = new ArrayBuffer[Long]
@@ -258,48 +282,41 @@ private[spark] class ExternalSorter[K, V, C](
val elementsPerPartition = new Array[Long](numPartitions)
// Flush the disk writer's contents to disk, and update relevant variables.
- // The writer is closed at the end of this process, and cannot be reused.
+ // The writer is committed at the end of this process.
def flush(): Unit = {
- val w = writer
- writer = null
- w.commitAndClose()
- _diskBytesSpilled += spillMetrics.shuffleBytesWritten
- batchSizes.append(spillMetrics.shuffleBytesWritten)
- spillMetrics = null
+ val segment = writer.commitAndGet()
+ batchSizes += segment.length
+ _diskBytesSpilled += segment.length
objectsWritten = 0
}
var success = false
try {
- val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
- while (it.hasNext) {
- val partitionId = it.nextPartition()
+ while (inMemoryIterator.hasNext) {
+ val partitionId = inMemoryIterator.nextPartition()
require(partitionId >= 0 && partitionId < numPartitions,
s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
- it.writeNext(writer)
+ inMemoryIterator.writeNext(writer)
elementsPerPartition(partitionId) += 1
objectsWritten += 1
if (objectsWritten == serializerBatchSize) {
flush()
- openWriter()
}
}
if (objectsWritten > 0) {
flush()
- } else if (writer != null) {
- val w = writer
- writer = null
- w.revertPartialWritesAndClose()
+ } else {
+ writer.revertPartialWritesAndClose()
}
success = true
} finally {
- if (!success) {
+ if (success) {
+ writer.close()
+ } else {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
- if (writer != null) {
- writer.revertPartialWritesAndClose()
- }
+ writer.revertPartialWritesAndClose()
if (file.exists()) {
if (!file.delete()) {
logWarning(s"Error deleting ${file}")
@@ -308,7 +325,7 @@ private[spark] class ExternalSorter[K, V, C](
}
}
- spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
+ SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
}
/**
@@ -504,8 +521,9 @@ private[spark] class ExternalSorter[K, V, C](
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
- val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream)
- serInstance.deserializeStream(compressedStream)
+
+ val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream)
+ serInstance.deserializeStream(wrappedStream)
} else {
// No more batches left
cleanup()
@@ -593,12 +611,28 @@ private[spark] class ExternalSorter[K, V, C](
val ds = deserializeStream
deserializeStream = null
fileStream = null
- ds.close()
+ if (ds != null) {
+ ds.close()
+ }
// NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop().
// This should also be fixed in ExternalAppendOnlyMap.
}
}
+ /**
+ * Returns a destructive iterator for iterating over the entries of this map.
+ * If this iterator is forced to spill to disk to release memory when there is not enough memory,
+ * it returns pairs from an on-disk map.
+ */
+ def destructiveIterator(memoryIterator: Iterator[((Int, K), C)]): Iterator[((Int, K), C)] = {
+ if (isShuffleSort) {
+ memoryIterator
+ } else {
+ readingIterator = new SpillableIterator(memoryIterator)
+ readingIterator
+ }
+ }
+
/**
* Return an iterator over all the data written to this object, grouped by partition and
* aggregated by the requested aggregator. For each partition we then have an iterator over its
@@ -608,8 +642,8 @@ private[spark] class ExternalSorter[K, V, C](
*
 * For now, we just merge all the spilled files in one pass, but this can be modified to
* support hierarchical merging.
+ * Exposed for testing.
*/
- @VisibleForTesting
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
val usingMap = aggregator.isDefined
val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
@@ -618,28 +652,32 @@ private[spark] class ExternalSorter[K, V, C](
// we don't even need to sort by anything other than partition ID
if (!ordering.isDefined) {
// The user hasn't requested sorted keys, so only sort by partition ID, not key
- groupByPartition(collection.partitionedDestructiveSortedIterator(None))
+ groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
} else {
// We do need to sort by both partition ID and key
- groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))
+ groupByPartition(destructiveIterator(
+ collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
}
} else {
// Merge spilled and in-memory data
- merge(spills, collection.partitionedDestructiveSortedIterator(comparator))
+ merge(spills, destructiveIterator(
+ collection.partitionedDestructiveSortedIterator(comparator)))
}
}
/**
* Return an iterator over all the data written to this object, aggregated by our aggregator.
*/
- def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)
+ def iterator: Iterator[Product2[K, C]] = {
+ isShuffleSort = false
+ partitionedIterator.flatMap(pair => pair._2)
+ }
/**
* Write all the data added into this ExternalSorter into a file in the disk store. This is
* called by the SortShuffleWriter.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
- * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
def writePartitionedFile(
@@ -648,52 +686,52 @@ private[spark] class ExternalSorter[K, V, C](
// Track location of each range in the output file
val lengths = new Array[Long](numPartitions)
+ val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
+ context.taskMetrics().shuffleWriteMetrics)
if (spills.isEmpty) {
// Case where we only have in-memory data
val collection = if (aggregator.isDefined) map else buffer
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
- val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
- context.taskMetrics.shuffleWriteMetrics.get)
val partitionId = it.nextPartition()
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(writer)
}
- writer.commitAndClose()
- val segment = writer.fileSegment()
+ val segment = writer.commitAndGet()
lengths(partitionId) = segment.length
}
} else {
// We must perform merge-sort; get an iterator by partition and write everything directly.
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
- val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
- context.taskMetrics.shuffleWriteMetrics.get)
for (elem <- elements) {
writer.write(elem._1, elem._2)
}
- writer.commitAndClose()
- val segment = writer.fileSegment()
+ val segment = writer.commitAndGet()
lengths(id) = segment.length
}
}
}
+ writer.close()
context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
- context.internalMetricsToAccumulators(
- InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes)
+ context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
lengths
}
def stop(): Unit = {
- map = null // So that the memory can be garbage-collected
- buffer = null // So that the memory can be garbage-collected
spills.foreach(s => s.file.delete())
spills.clear()
- releaseMemory()
+ forceSpillFiles.foreach(s => s.file.delete())
+ forceSpillFiles.clear()
+ if (map != null || buffer != null) {
+ map = null // So that the memory can be garbage-collected
+ buffer = null // So that the memory can be garbage-collected
+ releaseMemory()
+ }
}
/**
@@ -727,4 +765,66 @@ private[spark] class ExternalSorter[K, V, C](
(elem._1._2, elem._2)
}
}
+
+ private[this] class SpillableIterator(var upstream: Iterator[((Int, K), C)])
+ extends Iterator[((Int, K), C)] {
+
+ private val SPILL_LOCK = new Object()
+
+ private var nextUpstream: Iterator[((Int, K), C)] = null
+
+ private var cur: ((Int, K), C) = readNext()
+
+ private var hasSpilled: Boolean = false
+
+ def spill(): Boolean = SPILL_LOCK.synchronized {
+ if (hasSpilled) {
+ false
+ } else {
+ val inMemoryIterator = new WritablePartitionedIterator {
+ private[this] var cur = if (upstream.hasNext) upstream.next() else null
+
+ def writeNext(writer: DiskBlockObjectWriter): Unit = {
+ writer.write(cur._1._2, cur._2)
+ cur = if (upstream.hasNext) upstream.next() else null
+ }
+
+ def hasNext(): Boolean = cur != null
+
+ def nextPartition(): Int = cur._1._1
+ }
+ logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
+ s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
+ val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
+ forceSpillFiles += spillFile
+ val spillReader = new SpillReader(spillFile)
+ nextUpstream = (0 until numPartitions).iterator.flatMap { p =>
+ val iterator = spillReader.readNextPartition()
+ iterator.map(cur => ((p, cur._1), cur._2))
+ }
+ hasSpilled = true
+ true
+ }
+ }
+
+ def readNext(): ((Int, K), C) = SPILL_LOCK.synchronized {
+ if (nextUpstream != null) {
+ upstream = nextUpstream
+ nextUpstream = null
+ }
+ if (upstream.hasNext) {
+ upstream.next()
+ } else {
+ null
+ }
+ }
+
+ override def hasNext(): Boolean = cur != null
+
+ override def next(): ((Int, K), C) = {
+ val r = cur
+ cur = readNext()
+ r
+ }
+ }
}
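
The writePartitionedFile change above keeps one disk writer open for the whole output file and calls commitAndGet() once per partition, so each partition's length falls out of the commit. A standalone sketch of that bookkeeping, with hypothetical AppendWriter and Segment types standing in for DiskBlockObjectWriter and FileSegment:

    case class Segment(offset: Long, length: Long)

    class AppendWriter {
      private val out = new java.io.ByteArrayOutputStream()
      private var committed = 0L
      def write(bytes: Array[Byte]): Unit = out.write(bytes, 0, bytes.length)
      // Return the byte range written since the previous commit.
      def commitAndGet(): Segment = {
        val total = out.size().toLong
        val segment = Segment(committed, total - committed)
        committed = total
        segment
      }
      def close(): Unit = out.close()
    }

    // One writer for the whole file; lengths(p) records the size of partition p's range.
    val writer = new AppendWriter
    val lengths = new Array[Long](3)
    val byPartition = Seq(0 -> "aa", 0 -> "bbb", 2 -> "c").groupBy(_._1)
    for (p <- 0 until 3; records <- byPartition.get(p)) {
      records.foreach { case (_, v) => writer.write(v.getBytes("UTF-8")) }
      lengths(p) = writer.commitAndGet().length
    }
    writer.close()
    // lengths == Array(5L, 0L, 1L)
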
diff --git a/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala
new file mode 100644
index 000000000000..6e57c3c5bee8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.PriorityQueue
+
+/**
+ * MedianHeap is designed to be used to quickly track the median of a group of numbers
+ * that may contain duplicates. Inserting a new number has O(log n) time complexity and
+ * determining the median has O(1) time complexity.
+ * The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. The smallerHalf
+ * stores the smaller half of all numbers while the largerHalf stores the larger half.
+ * The sizes of the two heaps are rebalanced each time a new number is inserted so
+ * that they never differ by more than 1. Therefore each time findMedian() is called
+ * we check whether the two heaps have the same size. If they do, we return the
+ * average of the two top values. Otherwise we return the top of the heap which has
+ * one more element.
+ */
+private[spark] class MedianHeap(implicit val ord: Ordering[Double]) {
+
+ /**
+ * Stores all the numbers less than the current median in smallerHalf,
+ * i.e. the median is the maximum element, at the root.
+ */
+ private[this] var smallerHalf = PriorityQueue.empty[Double](ord)
+
+ /**
+ * Stores all the numbers greater than the current median in largerHalf,
+ * i.e. the median is the minimum element, at the root.
+ */
+ private[this] var largerHalf = PriorityQueue.empty[Double](ord.reverse)
+
+ def isEmpty(): Boolean = {
+ smallerHalf.isEmpty && largerHalf.isEmpty
+ }
+
+ def size(): Int = {
+ smallerHalf.size + largerHalf.size
+ }
+
+ def insert(x: Double): Unit = {
+ // If both heaps are empty, we insert the number into either heap; arbitrarily pick the largerHalf.
+ if (isEmpty) {
+ largerHalf.enqueue(x)
+ } else {
+ // If the number is larger than the current median, it should be inserted into largerHalf,
+ // otherwise into smallerHalf.
+ if (x > median) {
+ largerHalf.enqueue(x)
+ } else {
+ smallerHalf.enqueue(x)
+ }
+ }
+ rebalance()
+ }
+
+ private[this] def rebalance(): Unit = {
+ if (largerHalf.size - smallerHalf.size > 1) {
+ smallerHalf.enqueue(largerHalf.dequeue())
+ }
+ if (smallerHalf.size - largerHalf.size > 1) {
+ largerHalf.enqueue(smallerHalf.dequeue())
+ }
+ }
+
+ def median: Double = {
+ if (isEmpty) {
+ throw new NoSuchElementException("MedianHeap is empty.")
+ }
+ if (largerHalf.size == smallerHalf.size) {
+ (largerHalf.head + smallerHalf.head) / 2.0
+ } else if (largerHalf.size > smallerHalf.size) {
+ largerHalf.head
+ } else {
+ smallerHalf.head
+ }
+ }
+}
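
A short usage sketch of the MedianHeap added above; the results follow from the two-heap invariant (top of the larger heap for an odd count, average of the two tops for an even count):

    val heap = new MedianHeap()
    Seq(3.0, 1.0, 4.0, 1.0, 5.0).foreach(heap.insert)
    heap.size()      // 5
    heap.median      // 3.0: odd count, top of the heap holding one more element
    heap.insert(9.0)
    heap.median      // 3.5: even count, average of the two middle values
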
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index efc2482c74dd..10ab0b3f8996 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -19,17 +19,16 @@ package org.apache.spark.util.collection
import scala.reflect.ClassTag
-import org.apache.spark.annotation.DeveloperApi
-
/**
- * :: DeveloperApi ::
* A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
* but not deletions. This map is about 5X faster than java.util.HashMap, while using much less
* space overhead.
*
* Under the hood, it uses our OpenHashSet implementation.
+ *
+ * NOTE: when using a numeric type as the value type, the user of this class should be careful to
+ * distinguish between 0/0.0/0L and a non-existent value.
*/
-@DeveloperApi
private[spark]
class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
initialCapacity: Int)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 60bf4dd7469f..60f6f537c1d5 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util.collection
import scala.reflect._
+
import com.google.common.hash.Hashing
import org.apache.spark.annotation.Private
@@ -47,7 +48,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
require(initialCapacity <= OpenHashSet.MAX_CAPACITY,
s"Can't make capacity bigger than ${OpenHashSet.MAX_CAPACITY} elements")
- require(initialCapacity >= 1, "Invalid initial capacity")
+ require(initialCapacity >= 0, "Invalid initial capacity")
require(loadFactor < 1.0, "Load factor must be less than 1.0")
require(loadFactor > 0.0, "Load factor must be greater than 0.0")
@@ -270,8 +271,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
private def hashcode(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt()
private def nextPowerOf2(n: Int): Int = {
- val highBit = Integer.highestOneBit(n)
- if (highBit == n) n else highBit << 1
+ if (n == 0) {
+ 1
+ } else {
+ val highBit = Integer.highestOneBit(n)
+ if (highBit == n) n else highBit << 1
+ }
}
}
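
The relaxed capacity check goes together with the new n == 0 case in nextPowerOf2; a quick standalone check of the rounding behaviour, using a copy of the logic shown above:

    def nextPowerOf2(n: Int): Int = {
      if (n == 0) {
        1
      } else {
        val highBit = Integer.highestOneBit(n)
        if (highBit == n) n else highBit << 1
      }
    }
    Seq(0, 1, 3, 4, 5).map(nextPowerOf2)   // List(1, 1, 4, 4, 8)
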
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 9e002621a690..8183f825592c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -17,14 +17,16 @@
package org.apache.spark.util.collection
-import org.apache.spark.memory.TaskMemoryManager
-import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager}
/**
* Spills contents of an in-memory collection to disk when the memory threshold
* has been exceeded.
*/
-private[spark] trait Spillable[C] extends Logging {
+private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
+ extends MemoryConsumer(taskMemoryManager) with Logging {
/**
* Spills the current in-memory collection to disk, and releases the memory.
*
@@ -32,6 +34,12 @@ private[spark] trait Spillable[C] extends Logging {
*/
protected def spill(collection: C): Unit
+ /**
+ * Force spilling the current in-memory collection to disk to release memory.
+ * It will be called by TaskMemoryManager when there is not enough memory for the task.
+ */
+ protected def forceSpill(): Boolean
+
// Number of elements read from input since last spill
protected def elementsRead: Long = _elementsRead
@@ -39,9 +47,6 @@ private[spark] trait Spillable[C] extends Logging {
// It's used for checking spilling frequency
protected def addElementsRead(): Unit = { _elementsRead += 1 }
- // Memory manager that can be used to acquire/release memory
- protected[this] def taskMemoryManager: TaskMemoryManager
-
// Initial threshold for the size of a collection before we start tracking its memory usage
// For testing only
private[this] val initialMemoryThreshold: Long =
@@ -54,13 +59,13 @@ private[spark] trait Spillable[C] extends Logging {
// Threshold for this collection's size in bytes before we start tracking its memory usage
// To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
- private[this] var myMemoryThreshold = initialMemoryThreshold
+ @volatile private[this] var myMemoryThreshold = initialMemoryThreshold
// Number of elements read from input since last spill
private[this] var _elementsRead = 0L
// Number of bytes spilled in total
- private[this] var _memoryBytesSpilled = 0L
+ @volatile private[this] var _memoryBytesSpilled = 0L
// Number of spills
private[this] var _spillCount = 0
@@ -78,7 +83,7 @@ private[spark] trait Spillable[C] extends Logging {
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
- val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest, null)
+ val granted = acquireMemory(amountToRequest)
myMemoryThreshold += granted
// If we were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold), spill the current collection
@@ -97,6 +102,26 @@ private[spark] trait Spillable[C] extends Logging {
shouldSpill
}
+ /**
+ * Spill some data to disk to release memory. This will be called by TaskMemoryManager
+ * when there is not enough memory for the task.
+ */
+ override def spill(size: Long, trigger: MemoryConsumer): Long = {
+ if (trigger != this && taskMemoryManager.getTungstenMemoryMode == MemoryMode.ON_HEAP) {
+ val isSpilled = forceSpill()
+ if (!isSpilled) {
+ 0L
+ } else {
+ val freeMemory = myMemoryThreshold - initialMemoryThreshold
+ _memoryBytesSpilled += freeMemory
+ releaseMemory()
+ freeMemory
+ }
+ } else {
+ 0L
+ }
+ }
+
/**
* @return number of bytes spilled in total
*/
@@ -106,8 +131,7 @@ private[spark] trait Spillable[C] extends Logging {
* Release our memory back to the execution pool so that other tasks can grab it.
*/
def releaseMemory(): Unit = {
- // The amount we requested does not include the initial memory tracking threshold
- taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold, null)
+ freeMemory(myMemoryThreshold - initialMemoryThreshold)
myMemoryThreshold = initialMemoryThreshold
}
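
For context, maybeSpill (unchanged here apart from the switch to acquireMemory) decides whether to spill by trying, every 32 elements, to raise the tracking threshold to twice the collection's current size and spilling when the pool grants too little. A standalone sketch of just that check, with a hypothetical SpillDecider class and a pool function standing in for acquireMemory:

    class SpillDecider(initialThreshold: Long, pool: Long => Long) {
      private var threshold = initialThreshold
      private var elementsRead = 0L

      // Mirrors the shape of maybeSpill: check every 32 elements once over the threshold.
      def maybeSpill(currentMemory: Long): Boolean = {
        elementsRead += 1
        if (elementsRead % 32 == 0 && currentMemory >= threshold) {
          val amountToRequest = 2 * currentMemory - threshold
          val granted = pool(amountToRequest)   // what the execution pool hands back
          threshold += granted
          currentMemory >= threshold            // still over the raised threshold: spill
        } else {
          false
        }
      }
    }

    // With a pool that grants nothing, the 32nd element past the threshold triggers a spill.
    val decider = new SpillDecider(initialThreshold = 5L << 20, pool = (_: Long) => 0L)
    (1 to 32).map(_ => decider.maybeSpill(currentMemory = 6L << 20)).last   // true
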
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index 38848e9018c6..5232c2bd8d6f 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -23,9 +23,10 @@ import org.apache.spark.storage.DiskBlockObjectWriter
/**
* A common interface for size-tracking collections of key-value pairs that
- * - Have an associated partition for each key-value pair.
- * - Support a memory-efficient sorted iterator
- * - Support a WritablePartitionedIterator for writing the contents directly as bytes.
+ *
+ * - Have an associated partition for each key-value pair.
+ * - Support a memory-efficient sorted iterator
+ * - Support a WritablePartitionedIterator for writing the contents directly as bytes.
*/
private[spark] trait WritablePartitionedPairCollection[K, V] {
/**
diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
deleted file mode 100644
index daac6f971eb2..000000000000
--- a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.io
-
-import java.io.OutputStream
-
-import scala.collection.mutable.ArrayBuffer
-
-
-/**
- * An OutputStream that writes to fixed-size chunks of byte arrays.
- *
- * @param chunkSize size of each chunk, in bytes.
- */
-private[spark]
-class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
-
- private val chunks = new ArrayBuffer[Array[Byte]]
-
- /** Index of the last chunk. Starting with -1 when the chunks array is empty. */
- private var lastChunkIndex = -1
-
- /**
- * Next position to write in the last chunk.
- *
- * If this equals chunkSize, it means for next write we need to allocate a new chunk.
- * This can also never be 0.
- */
- private var position = chunkSize
-
- override def write(b: Int): Unit = {
- allocateNewChunkIfNeeded()
- chunks(lastChunkIndex)(position) = b.toByte
- position += 1
- }
-
- override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
- var written = 0
- while (written < len) {
- allocateNewChunkIfNeeded()
- val thisBatch = math.min(chunkSize - position, len - written)
- System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch)
- written += thisBatch
- position += thisBatch
- }
- }
-
- @inline
- private def allocateNewChunkIfNeeded(): Unit = {
- if (position == chunkSize) {
- chunks += new Array[Byte](chunkSize)
- lastChunkIndex += 1
- position = 0
- }
- }
-
- def toArrays: Array[Array[Byte]] = {
- if (lastChunkIndex == -1) {
- new Array[Array[Byte]](0)
- } else {
- // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk.
- // An alternative would have been returning an array of ByteBuffers, with the last buffer
- // bounded to only the last chunk's position. However, given our use case in Spark (to put
- // the chunks in block manager), only limiting the view bound of the buffer would still
- // require the block manager to store the whole chunk.
- val ret = new Array[Array[Byte]](chunks.size)
- for (i <- 0 until chunks.size - 1) {
- ret(i) = chunks(i)
- }
- if (position == chunkSize) {
- ret(lastChunkIndex) = chunks(lastChunkIndex)
- } else {
- ret(lastChunkIndex) = new Array[Byte](position)
- System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position)
- }
- ret
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
new file mode 100644
index 000000000000..2f905c8af0f6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import java.io.InputStream
+import java.nio.ByteBuffer
+import java.nio.channels.WritableByteChannel
+
+import com.google.common.primitives.UnsignedBytes
+import io.netty.buffer.{ByteBuf, Unpooled}
+
+import org.apache.spark.network.util.ByteArrayWritableChannel
+import org.apache.spark.storage.StorageUtils
+
+/**
+ * Read-only byte buffer which is physically stored as multiple chunks rather than a single
+ * contiguous array.
+ *
+ * @param chunks an array of [[ByteBuffer]]s. Each buffer in this array must have position == 0.
+ * Ownership of these buffers is transferred to the ChunkedByteBuffer, so if these
+ * buffers may also be used elsewhere then the caller is responsible for copying
+ * them as needed.
+ */
+private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
+ require(chunks != null, "chunks must not be null")
+ require(chunks.forall(_.position() == 0), "chunks' positions must be 0")
+
+ private[this] var disposed: Boolean = false
+
+ /**
+ * The size of this buffer, in bytes.
+ */
+ val size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum
+
+ def this(byteBuffer: ByteBuffer) = {
+ this(Array(byteBuffer))
+ }
+
+ /**
+ * Write this buffer to a channel.
+ */
+ def writeFully(channel: WritableByteChannel): Unit = {
+ for (bytes <- getChunks()) {
+ while (bytes.remaining > 0) {
+ channel.write(bytes)
+ }
+ }
+ }
+
+ /**
+ * Wrap this buffer to view it as a Netty ByteBuf.
+ */
+ def toNetty: ByteBuf = {
+ Unpooled.wrappedBuffer(getChunks(): _*)
+ }
+
+ /**
+ * Copy this buffer into a new byte array.
+ *
+ * @throws UnsupportedOperationException if this buffer's size exceeds the maximum array size.
+ */
+ def toArray: Array[Byte] = {
+ if (size >= Integer.MAX_VALUE) {
+ throw new UnsupportedOperationException(
+ s"cannot call toArray because buffer size ($size bytes) exceeds maximum array size")
+ }
+ val byteChannel = new ByteArrayWritableChannel(size.toInt)
+ writeFully(byteChannel)
+ byteChannel.close()
+ byteChannel.getData
+ }
+
+ /**
+ * Convert this buffer to a ByteBuffer. If this buffer is backed by a single chunk, its underlying
+ * data will not be copied. Instead, it will be duplicated. If this buffer is backed by multiple
+ * chunks, the data underlying this buffer will be copied into a new byte buffer. As a result, it
+ * is suggested to use this method only if the caller does not need to manage the memory
+ * underlying this buffer.
+ *
+ * @throws UnsupportedOperationException if this buffer's size exceeds the max ByteBuffer size.
+ */
+ def toByteBuffer: ByteBuffer = {
+ if (chunks.length == 1) {
+ chunks.head.duplicate()
+ } else {
+ ByteBuffer.wrap(toArray)
+ }
+ }
+
+ /**
+ * Creates an input stream to read data from this ChunkedByteBuffer.
+ *
+ * @param dispose if true, [[dispose()]] will be called at the end of the stream
+ * in order to close any memory-mapped files which back this buffer.
+ */
+ def toInputStream(dispose: Boolean = false): InputStream = {
+ new ChunkedByteBufferInputStream(this, dispose)
+ }
+
+ /**
+ * Get duplicates of the ByteBuffers backing this ChunkedByteBuffer.
+ */
+ def getChunks(): Array[ByteBuffer] = {
+ chunks.map(_.duplicate())
+ }
+
+ /**
+ * Make a copy of this ChunkedByteBuffer, copying all of the backing data into new buffers.
+ * The new buffer will share no resources with the original buffer.
+ *
+ * @param allocator a method for allocating byte buffers
+ */
+ def copy(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
+ val copiedChunks = getChunks().map { chunk =>
+ val newChunk = allocator(chunk.limit())
+ newChunk.put(chunk)
+ newChunk.flip()
+ newChunk
+ }
+ new ChunkedByteBuffer(copiedChunks)
+ }
+
+ /**
+ * Attempt to clean up any ByteBuffer in this ChunkedByteBuffer which is direct or memory-mapped.
+ * See [[StorageUtils.dispose]] for more information.
+ */
+ def dispose(): Unit = {
+ if (!disposed) {
+ chunks.foreach(StorageUtils.dispose)
+ disposed = true
+ }
+ }
+
+}
+
+/**
+ * Reads data from a ChunkedByteBuffer.
+ *
+ * @param dispose if true, `ChunkedByteBuffer.dispose()` will be called at the end of the stream
+ * in order to close any memory-mapped files which back the buffer.
+ */
+private[spark] class ChunkedByteBufferInputStream(
+ var chunkedByteBuffer: ChunkedByteBuffer,
+ dispose: Boolean)
+ extends InputStream {
+
+ private[this] var chunks = chunkedByteBuffer.getChunks().iterator
+ private[this] var currentChunk: ByteBuffer = {
+ if (chunks.hasNext) {
+ chunks.next()
+ } else {
+ null
+ }
+ }
+
+ override def read(): Int = {
+ if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
+ currentChunk = chunks.next()
+ }
+ if (currentChunk != null && currentChunk.hasRemaining) {
+ UnsignedBytes.toInt(currentChunk.get())
+ } else {
+ close()
+ -1
+ }
+ }
+
+ override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
+ if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
+ currentChunk = chunks.next()
+ }
+ if (currentChunk != null && currentChunk.hasRemaining) {
+ val amountToGet = math.min(currentChunk.remaining(), length)
+ currentChunk.get(dest, offset, amountToGet)
+ amountToGet
+ } else {
+ close()
+ -1
+ }
+ }
+
+ override def skip(bytes: Long): Long = {
+ if (currentChunk != null) {
+ val amountToSkip = math.min(bytes, currentChunk.remaining).toInt
+ currentChunk.position(currentChunk.position + amountToSkip)
+ if (currentChunk.remaining() == 0) {
+ if (chunks.hasNext) {
+ currentChunk = chunks.next()
+ } else {
+ close()
+ }
+ }
+ amountToSkip
+ } else {
+ 0L
+ }
+ }
+
+ override def close(): Unit = {
+ if (chunkedByteBuffer != null && dispose) {
+ chunkedByteBuffer.dispose()
+ }
+ chunkedByteBuffer = null
+ chunks = null
+ currentChunk = null
+ }
+}
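
A quick usage sketch of the new ChunkedByteBuffer, written as if called from inside the org.apache.spark package (the class is private[spark]); two 4-byte heap chunks behave as one 8-byte read-only buffer:

    import java.nio.ByteBuffer
    import org.apache.spark.util.io.ChunkedByteBuffer

    val chunks = Array(
      ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)),
      ByteBuffer.wrap(Array[Byte](5, 6, 7, 8)))
    val buffer = new ChunkedByteBuffer(chunks)
    buffer.size                       // 8
    buffer.toArray.toSeq              // bytes 1 through 8, in order
    val in = buffer.toInputStream()   // dispose = false: these are plain heap buffers
    (in.read(), in.read())            // (1, 2)
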
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
new file mode 100644
index 000000000000..a625b3289538
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import java.io.OutputStream
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.storage.StorageUtils
+
+/**
+ * An OutputStream that writes to fixed-size chunks of byte buffers.
+ *
+ * @param chunkSize size of each chunk, in bytes.
+ */
+private[spark] class ChunkedByteBufferOutputStream(
+ chunkSize: Int,
+ allocator: Int => ByteBuffer)
+ extends OutputStream {
+
+ private[this] var toChunkedByteBufferWasCalled = false
+
+ private val chunks = new ArrayBuffer[ByteBuffer]
+
+ /** Index of the last chunk. Starting with -1 when the chunks array is empty. */
+ private[this] var lastChunkIndex = -1
+
+ /**
+ * Next position to write in the last chunk.
+ *
+ * If this equals chunkSize, it means for next write we need to allocate a new chunk.
+ * This can also never be 0.
+ */
+ private[this] var position = chunkSize
+ private[this] var _size = 0
+ private[this] var closed: Boolean = false
+
+ def size: Long = _size
+
+ override def close(): Unit = {
+ if (!closed) {
+ super.close()
+ closed = true
+ }
+ }
+
+ override def write(b: Int): Unit = {
+ require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
+ allocateNewChunkIfNeeded()
+ chunks(lastChunkIndex).put(b.toByte)
+ position += 1
+ _size += 1
+ }
+
+ override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+ require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
+ var written = 0
+ while (written < len) {
+ allocateNewChunkIfNeeded()
+ val thisBatch = math.min(chunkSize - position, len - written)
+ chunks(lastChunkIndex).put(bytes, written + off, thisBatch)
+ written += thisBatch
+ position += thisBatch
+ }
+ _size += len
+ }
+
+ @inline
+ private def allocateNewChunkIfNeeded(): Unit = {
+ if (position == chunkSize) {
+ chunks += allocator(chunkSize)
+ lastChunkIndex += 1
+ position = 0
+ }
+ }
+
+ def toChunkedByteBuffer: ChunkedByteBuffer = {
+ require(closed, "cannot call toChunkedByteBuffer() unless close() has been called")
+ require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
+ toChunkedByteBufferWasCalled = true
+ if (lastChunkIndex == -1) {
+ new ChunkedByteBuffer(Array.empty[ByteBuffer])
+ } else {
+ // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk.
+ // An alternative would have been returning an array of ByteBuffers, with the last buffer
+ // bounded to only the last chunk's position. However, given our use case in Spark (to put
+ // the chunks in block manager), only limiting the view bound of the buffer would still
+ // require the block manager to store the whole chunk.
+ val ret = new Array[ByteBuffer](chunks.size)
+ for (i <- 0 until chunks.size - 1) {
+ ret(i) = chunks(i)
+ ret(i).flip()
+ }
+ if (position == chunkSize) {
+ ret(lastChunkIndex) = chunks(lastChunkIndex)
+ ret(lastChunkIndex).flip()
+ } else {
+ ret(lastChunkIndex) = allocator(position)
+ chunks(lastChunkIndex).flip()
+ ret(lastChunkIndex).put(chunks(lastChunkIndex))
+ ret(lastChunkIndex).flip()
+ StorageUtils.dispose(chunks(lastChunkIndex))
+ }
+ new ChunkedByteBuffer(ret)
+ }
+ }
+}
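
And a minimal round trip through the new ChunkedByteBufferOutputStream (same private[spark] caveat as above): write a payload larger than one chunk, close the stream, then collect the result:

    import java.nio.ByteBuffer
    import org.apache.spark.util.io.ChunkedByteBufferOutputStream

    val out = new ChunkedByteBufferOutputStream(chunkSize = 4, allocator = (n: Int) => ByteBuffer.allocate(n))
    out.write(Array[Byte](1, 2, 3, 4, 5, 6))   // spans two 4-byte chunks
    out.close()                                // required before toChunkedByteBuffer
    val result = out.toChunkedByteBuffer
    (result.size, result.getChunks().length)   // (6, 2)
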
diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
index 14b6ba4af489..fdb1495899bc 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
@@ -17,9 +17,10 @@
package org.apache.spark.util.logging
-import java.io.{File, FileOutputStream, InputStream}
+import java.io.{File, FileOutputStream, InputStream, IOException}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
import org.apache.spark.util.{IntParam, Utils}
/**
@@ -29,7 +30,6 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi
extends Logging {
@volatile private var outputStream: FileOutputStream = null
@volatile private var markedForStop = false // has the appender been asked to stopped
- @volatile private var stopped = false // has the appender stopped
// Thread that reads the input stream and writes to file
private val writingThread = new Thread("File appending thread for " + file) {
@@ -47,11 +47,7 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi
* or because of any error in appending
*/
def awaitTermination() {
- synchronized {
- if (!stopped) {
- wait()
- }
- }
+ writingThread.join()
}
/** Stop the appender */
@@ -63,24 +59,28 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi
protected def appendStreamToFile() {
try {
logDebug("Started appending thread")
- openFile()
- val buf = new Array[Byte](bufferSize)
- var n = 0
- while (!markedForStop && n != -1) {
- n = inputStream.read(buf)
- if (n != -1) {
- appendToFile(buf, n)
+ Utils.tryWithSafeFinally {
+ openFile()
+ val buf = new Array[Byte](bufferSize)
+ var n = 0
+ while (!markedForStop && n != -1) {
+ try {
+ n = inputStream.read(buf)
+ } catch {
+ // An InputStream can throw IOException during read if the stream is closed
+ // asynchronously, so once the appender has been flagged to stop, these will be ignored
+ case _: IOException if markedForStop => // do nothing and proceed to stop appending
+ }
+ if (n > 0) {
+ appendToFile(buf, n)
+ }
}
+ } {
+ closeFile()
}
} catch {
case e: Exception =>
logError(s"Error writing stream to file $file", e)
- } finally {
- closeFile()
- synchronized {
- stopped = true
- notifyAll()
- }
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala
index 1e8476c4a047..5d8cec8447b5 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala
@@ -17,11 +17,13 @@
package org.apache.spark.util.logging
-import java.io.{File, FileFilter, InputStream}
+import java.io._
+import java.util.zip.GZIPOutputStream
import com.google.common.io.Files
+import org.apache.commons.io.IOUtils
+
import org.apache.spark.SparkConf
-import RollingFileAppender._
/**
* Continuously appends data from input stream into the given file, and rolls
@@ -39,10 +41,13 @@ private[spark] class RollingFileAppender(
activeFile: File,
val rollingPolicy: RollingPolicy,
conf: SparkConf,
- bufferSize: Int = DEFAULT_BUFFER_SIZE
+ bufferSize: Int = RollingFileAppender.DEFAULT_BUFFER_SIZE
) extends FileAppender(inputStream, activeFile, bufferSize) {
+ import RollingFileAppender._
+
private val maxRetainedFiles = conf.getInt(RETAINED_FILES_PROPERTY, -1)
+ private val enableCompression = conf.getBoolean(ENABLE_COMPRESSION, false)
/** Stop the appender */
override def stop() {
@@ -74,6 +79,33 @@ private[spark] class RollingFileAppender(
}
}
+ // Roll the log file and compress if enableCompression is true.
+ private def rotateFile(activeFile: File, rolloverFile: File): Unit = {
+ if (enableCompression) {
+ val gzFile = new File(rolloverFile.getAbsolutePath + GZIP_LOG_SUFFIX)
+ var gzOutputStream: GZIPOutputStream = null
+ var inputStream: InputStream = null
+ try {
+ inputStream = new FileInputStream(activeFile)
+ gzOutputStream = new GZIPOutputStream(new FileOutputStream(gzFile))
+ IOUtils.copy(inputStream, gzOutputStream)
+ inputStream.close()
+ gzOutputStream.close()
+ activeFile.delete()
+ } finally {
+ IOUtils.closeQuietly(inputStream)
+ IOUtils.closeQuietly(gzOutputStream)
+ }
+ } else {
+ Files.move(activeFile, rolloverFile)
+ }
+ }
+
+ // Check if the rollover file already exists.
+ private def rolloverFileExist(file: File): Boolean = {
+ file.exists || new File(file.getAbsolutePath + GZIP_LOG_SUFFIX).exists
+ }
+
/** Move the active log file to a new rollover file */
private def moveFile() {
val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix()
@@ -81,8 +113,8 @@ private[spark] class RollingFileAppender(
activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile
logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile")
if (activeFile.exists) {
- if (!rolloverFile.exists) {
- Files.move(activeFile, rolloverFile)
+ if (!rolloverFileExist(rolloverFile)) {
+ rotateFile(activeFile, rolloverFile)
logInfo(s"Rolled over $activeFile to $rolloverFile")
} else {
// In case the rollover file name clashes, make a unique file name.
@@ -95,11 +127,11 @@ private[spark] class RollingFileAppender(
altRolloverFile = new File(activeFile.getParent,
s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile
i += 1
- } while (i < 10000 && altRolloverFile.exists)
+ } while (i < 10000 && rolloverFileExist(altRolloverFile))
logWarning(s"Rollover file $rolloverFile already exists, " +
s"rolled over $activeFile to file $altRolloverFile")
- Files.move(activeFile, altRolloverFile)
+ rotateFile(activeFile, altRolloverFile)
}
} else {
logWarning(s"File $activeFile does not exist")
@@ -115,7 +147,7 @@ private[spark] class RollingFileAppender(
}
}).sorted
val filesToBeDeleted = rolledoverFiles.take(
- math.max(0, rolledoverFiles.size - maxRetainedFiles))
+ math.max(0, rolledoverFiles.length - maxRetainedFiles))
filesToBeDeleted.foreach { file =>
logInfo(s"Deleting file executor log file ${file.getAbsolutePath}")
file.delete()
@@ -140,6 +172,9 @@ private[spark] object RollingFileAppender {
val SIZE_DEFAULT = (1024 * 1024).toString
val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles"
val DEFAULT_BUFFER_SIZE = 8192
+ val ENABLE_COMPRESSION = "spark.executor.logs.rolling.enableCompression"
+
+ val GZIP_LOG_SUFFIX = ".gz"
/**
* Get the sorted list of rolled over files. This assumes that the all the rolled
@@ -156,6 +191,6 @@ private[spark] object RollingFileAppender {
val file = new File(directory, activeFileName).getAbsoluteFile
if (file.exists) Some(file) else None
}
- rolledOverFiles ++ activeFile
+ rolledOverFiles.sortBy(_.getName.stripSuffix(GZIP_LOG_SUFFIX)) ++ activeFile
}
}
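
Note on the compression path added above: it amounts to streaming the active log through a GZIPOutputStream and deleting the source only after a successful copy. A minimal standalone sketch of that idea (method and variable names here are illustrative, not part of the patch), assuming commons-io is on the classpath:

import java.io.{File, FileInputStream, FileOutputStream, InputStream}
import java.util.zip.GZIPOutputStream

import org.apache.commons.io.IOUtils

// Copy `src` into a gzip-compressed `dest`, then remove `src`.
// The delete happens only after both streams closed cleanly, so a failed
// copy never loses the original log file.
def gzipRotate(src: File, dest: File): Unit = {
  var in: InputStream = null
  var out: GZIPOutputStream = null
  try {
    in = new FileInputStream(src)
    out = new GZIPOutputStream(new FileOutputStream(dest))
    IOUtils.copy(in, out)
    in.close()
    out.close()
    src.delete()
  } finally {
    IOUtils.closeQuietly(in)
    IOUtils.closeQuietly(out)
  }
}
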
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
index d7b7219e179d..1f263df57c85 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
@@ -18,9 +18,9 @@
package org.apache.spark.util.logging
import java.text.SimpleDateFormat
-import java.util.Calendar
+import java.util.{Calendar, Locale}
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
/**
* Defines the policy based on which [[org.apache.spark.util.logging.RollingFileAppender]] will
@@ -32,10 +32,10 @@ private[spark] trait RollingPolicy {
def shouldRollover(bytesToBeWritten: Long): Boolean
/** Notify that rollover has occurred */
- def rolledOver()
+ def rolledOver(): Unit
/** Notify that bytes have been written */
- def bytesWritten(bytes: Long)
+ def bytesWritten(bytes: Long): Unit
/** Get the desired name of the rollover file */
def generateRolledOverFileSuffix(): String
@@ -59,7 +59,7 @@ private[spark] class TimeBasedRollingPolicy(
}
@volatile private var nextRolloverTime = calculateNextRolloverTime()
- private val formatter = new SimpleDateFormat(rollingFileSuffixPattern)
+ private val formatter = new SimpleDateFormat(rollingFileSuffixPattern, Locale.US)
/** Should rollover if current time has exceeded next rollover time */
def shouldRollover(bytesToBeWritten: Long): Boolean = {
@@ -109,11 +109,11 @@ private[spark] class SizeBasedRollingPolicy(
}
@volatile private var bytesWrittenSinceRollover = 0L
- val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS")
+ val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS", Locale.US)
/** Should rollover if the next set of bytes is going to exceed the size limit */
def shouldRollover(bytesToBeWritten: Long): Boolean = {
- logInfo(s"$bytesToBeWritten + $bytesWrittenSinceRollover > $rolloverSizeBytes")
+ logDebug(s"$bytesToBeWritten + $bytesWrittenSinceRollover > $rolloverSizeBytes")
bytesToBeWritten + bytesWrittenSinceRollover > rolloverSizeBytes
}
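
A quick illustration of why the formatters above are pinned to Locale.US: the rollover suffix becomes part of the log file name, so it must render identically regardless of the JVM's default locale. A hedged sketch (the formatted value is illustrative):

import java.text.SimpleDateFormat
import java.util.{Date, Locale}

// Same pattern as the size-based policy above; the explicit locale keeps the
// generated suffix (digits and separators) stable on every executor.
val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS", Locale.US)
val rolloverSuffix = formatter.format(new Date())  // e.g. "--2016-01-01--12-00-00--0000"
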
diff --git a/core/src/main/scala/org/apache/spark/util/package-info.java b/core/src/main/scala/org/apache/spark/util/package-info.java
index 819f54ee41a7..4c5d33d88d2b 100644
--- a/core/src/main/scala/org/apache/spark/util/package-info.java
+++ b/core/src/main/scala/org/apache/spark/util/package-info.java
@@ -18,4 +18,4 @@
/**
* Spark utilities.
*/
-package org.apache.spark.util;
\ No newline at end of file
+package org.apache.spark.util;
diff --git a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
index 70f3dd62b9b1..41f28f6e511e 100644
--- a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
@@ -26,5 +26,5 @@ import org.apache.spark.annotation.DeveloperApi
@DeveloperApi
trait Pseudorandom {
/** Set random seed. */
- def setSeed(seed: Long)
+ def setSeed(seed: Long): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index c156b03cdb7c..ea99a7e5b484 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -20,7 +20,6 @@ package org.apache.spark.util.random
import java.util.Random
import scala.reflect.ClassTag
-import scala.collection.mutable.ArrayBuffer
import org.apache.commons.math3.distribution.PoissonDistribution
@@ -39,7 +38,14 @@ import org.apache.spark.annotation.DeveloperApi
trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable {
/** take a random sample */
- def sample(items: Iterator[T]): Iterator[U]
+ def sample(items: Iterator[T]): Iterator[U] =
+ items.filter(_ => sample > 0).asInstanceOf[Iterator[U]]
+
+ /**
+ * Whether to sample the next item or not.
+ * Return how many times the next item will be sampled. Return 0 if it is not sampled.
+ */
+ def sample(): Int
/** return a copy of the RandomSampler object */
override def clone: RandomSampler[T, U] =
@@ -54,7 +60,7 @@ object RandomSampler {
/**
* Default maximum gap-sampling fraction.
* For sampling fractions <= this value, the gap sampling optimization will be applied.
- * Above this value, it is assumed that "tradtional" Bernoulli sampling is faster. The
+ * Above this value, it is assumed that "traditional" Bernoulli sampling is faster. The
* optimal value for this will depend on the RNG. More expensive RNGs will tend to make
* the optimal value higher. The most reliable way to determine this value for a new RNG
* is to experiment. When tuning for a new RNG, I would expect a value of 0.5 to be close
@@ -107,21 +113,13 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals
override def setSeed(seed: Long): Unit = rng.setSeed(seed)
- override def sample(items: Iterator[T]): Iterator[T] = {
+ override def sample(): Int = {
if (ub - lb <= 0.0) {
- if (complement) items else Iterator.empty
+ if (complement) 1 else 0
} else {
- if (complement) {
- items.filter { item => {
- val x = rng.nextDouble()
- (x < lb) || (x >= ub)
- }}
- } else {
- items.filter { item => {
- val x = rng.nextDouble()
- (x >= lb) && (x < ub)
- }}
- }
+ val x = rng.nextDouble()
+ val n = if ((x >= lb) && (x < ub)) 1 else 0
+ if (complement) 1 - n else n
}
}
@@ -155,15 +153,22 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
override def setSeed(seed: Long): Unit = rng.setSeed(seed)
- override def sample(items: Iterator[T]): Iterator[T] = {
+ private lazy val gapSampling: GapSampling =
+ new GapSampling(fraction, rng, RandomSampler.rngEpsilon)
+
+ override def sample(): Int = {
if (fraction <= 0.0) {
- Iterator.empty
+ 0
} else if (fraction >= 1.0) {
- items
+ 1
} else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
- new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon)
+ gapSampling.sample()
} else {
- items.filter { _ => rng.nextDouble() <= fraction }
+ if (rng.nextDouble() <= fraction) {
+ 1
+ } else {
+ 0
+ }
}
}
@@ -180,7 +185,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
* @tparam T item type
*/
@DeveloperApi
-class PoissonSampler[T: ClassTag](
+class PoissonSampler[T](
fraction: Double,
useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {
@@ -201,15 +206,29 @@ class PoissonSampler[T: ClassTag](
rngGap.setSeed(seed)
}
- override def sample(items: Iterator[T]): Iterator[T] = {
+ private lazy val gapSamplingReplacement =
+ new GapSamplingReplacement(fraction, rngGap, RandomSampler.rngEpsilon)
+
+ override def sample(): Int = {
if (fraction <= 0.0) {
- Iterator.empty
+ 0
} else if (useGapSamplingIfPossible &&
fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
- new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
+ gapSamplingReplacement.sample()
+ } else {
+ rng.sample()
+ }
+ }
+
+ override def sample(items: Iterator[T]): Iterator[T] = {
+ if (fraction <= 0.0) {
+ Iterator.empty
} else {
+ val useGapSampling = useGapSamplingIfPossible &&
+ fraction <= RandomSampler.defaultMaxGapSamplingFraction
+
items.flatMap { item =>
- val count = rng.sample()
+ val count = if (useGapSampling) gapSamplingReplacement.sample() else rng.sample()
if (count == 0) Iterator.empty else Iterator.fill(count)(item)
}
}
@@ -220,50 +239,36 @@ class PoissonSampler[T: ClassTag](
private[spark]
-class GapSamplingIterator[T: ClassTag](
- var data: Iterator[T],
+class GapSampling(
f: Double,
rng: Random = RandomSampler.newDefaultRNG,
- epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
+ epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {
require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)")
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
- /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
- private val iterDrop: Int => Unit = {
- val arrayClass = Array.empty[T].iterator.getClass
- val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
- data.getClass match {
- case `arrayClass` =>
- (n: Int) => { data = data.drop(n) }
- case `arrayBufferClass` =>
- (n: Int) => { data = data.drop(n) }
- case _ =>
- (n: Int) => {
- var j = 0
- while (j < n && data.hasNext) {
- data.next()
- j += 1
- }
- }
- }
- }
-
- override def hasNext: Boolean = data.hasNext
+ private val lnq = math.log1p(-f)
- override def next(): T = {
- val r = data.next()
- advance()
- r
+ /** Return 1 if the next item should be sampled. Otherwise, return 0. */
+ def sample(): Int = {
+ if (countForDropping > 0) {
+ countForDropping -= 1
+ 0
+ } else {
+ advance()
+ 1
+ }
}
- private val lnq = math.log1p(-f)
+ private var countForDropping: Int = 0
- /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */
+ /**
+ * Decide the number of elements that won't be sampled,
+ * according to geometric dist P(k) = (f)(1-f)^k.
+ */
private def advance(): Unit = {
val u = math.max(rng.nextDouble(), epsilon)
- val k = (math.log(u) / lnq).toInt
- iterDrop(k)
+ countForDropping = (math.log(u) / lnq).toInt
}
/** advance to first sample as part of object construction. */
@@ -273,73 +278,24 @@ class GapSamplingIterator[T: ClassTag](
// work reliably.
}
+
private[spark]
-class GapSamplingReplacementIterator[T: ClassTag](
- var data: Iterator[T],
- f: Double,
- rng: Random = RandomSampler.newDefaultRNG,
- epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
+class GapSamplingReplacement(
+ val f: Double,
+ val rng: Random = RandomSampler.newDefaultRNG,
+ epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {
require(f > 0.0, s"Sampling fraction ($f) must be > 0")
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
- /** implement efficient linear-sequence drop until scala includes fix for jira SI-8835. */
- private val iterDrop: Int => Unit = {
- val arrayClass = Array.empty[T].iterator.getClass
- val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
- data.getClass match {
- case `arrayClass` =>
- (n: Int) => { data = data.drop(n) }
- case `arrayBufferClass` =>
- (n: Int) => { data = data.drop(n) }
- case _ =>
- (n: Int) => {
- var j = 0
- while (j < n && data.hasNext) {
- data.next()
- j += 1
- }
- }
- }
- }
-
- /** current sampling value, and its replication factor, as we are sampling with replacement. */
- private var v: T = _
- private var rep: Int = 0
-
- override def hasNext: Boolean = data.hasNext || rep > 0
-
- override def next(): T = {
- val r = v
- rep -= 1
- if (rep <= 0) advance()
- r
- }
-
- /**
- * Skip elements with replication factor zero (i.e. elements that won't be sampled).
- * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
- * q is the probabililty of Poisson(0; f)
- */
- private def advance(): Unit = {
- val u = math.max(rng.nextDouble(), epsilon)
- val k = (math.log(u) / (-f)).toInt
- iterDrop(k)
- // set the value and replication factor for the next value
- if (data.hasNext) {
- v = data.next()
- rep = poissonGE1
- }
- }
-
- private val q = math.exp(-f)
+ protected val q = math.exp(-f)
/**
* Sample from Poisson distribution, conditioned such that the sampled value is >= 1.
* This is an adaptation from the algorithm for Generating Poisson distributed random variables:
* http://en.wikipedia.org/wiki/Poisson_distribution
*/
- private def poissonGE1: Int = {
+ protected def poissonGE1: Int = {
// simulate that the standard poisson sampling
// gave us at least one iteration, for a sample of >= 1
var pp = q + ((1.0 - q) * rng.nextDouble())
@@ -353,6 +309,28 @@ class GapSamplingReplacementIterator[T: ClassTag](
}
r
}
+ private var countForDropping: Int = 0
+
+ def sample(): Int = {
+ if (countForDropping > 0) {
+ countForDropping -= 1
+ 0
+ } else {
+ val r = poissonGE1
+ advance()
+ r
+ }
+ }
+
+ /**
+ * Skip elements with replication factor zero (i.e. elements that won't be sampled).
+ * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
+ * q is the probability of Poisson(0; f)
+ */
+ private def advance(): Unit = {
+ val u = math.max(rng.nextDouble(), epsilon)
+ countForDropping = (math.log(u) / (-f)).toInt
+ }
/** advance to first sample as part of object construction. */
advance()
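
To make the refactoring above concrete: GapSampling and the new per-element sample(): Int contract both rest on the same trick of drawing the length of the next run of skipped elements from a geometric distribution, instead of drawing a uniform value per element. A minimal sketch of that idea (helper names are illustrative, not the patch's API):

import java.util.Random

// Number of elements to drop before the next accepted one, for sampling
// fraction f: k ~ Geometric, P(k) = f * (1 - f)^k. The epsilon floor guards
// against log(0) when the RNG returns exactly 0.0.
def geometricGap(f: Double, rng: Random, epsilon: Double = 5e-11): Int = {
  val lnq = math.log1p(-f)
  val u = math.max(rng.nextDouble(), epsilon)
  (math.log(u) / lnq).toInt
}

// Bernoulli sampling of an iterator using the gap trick: keep an element when
// the current gap is exhausted, then draw the next gap.
def gapSample[T](items: Iterator[T], f: Double, rng: Random): Iterator[T] = {
  require(f > 0.0 && f < 1.0)
  var toDrop = geometricGap(f, rng)   // advance to the first sample up front
  items.filter { _ =>
    if (toDrop > 0) { toDrop -= 1; false }
    else { toDrop = geometricGap(f, rng); true }
  }
}
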
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index c9a864ae6277..a7e0075debed 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -34,7 +34,7 @@ private[spark] object SamplingUtils {
input: Iterator[T],
k: Int,
seed: Long = Random.nextLong())
- : (Array[T], Int) = {
+ : (Array[T], Long) = {
val reservoir = new Array[T](k)
// Put the first k elements in the reservoir.
var i = 0
@@ -52,31 +52,37 @@ private[spark] object SamplingUtils {
(trimReservoir, i)
} else {
// If input size > k, continue the sampling process.
+ var l = i.toLong
val rand = new XORShiftRandom(seed)
while (input.hasNext) {
val item = input.next()
- val replacementIndex = rand.nextInt(i)
+ l += 1
+ // There are k elements in the reservoir, and the l-th element has been
+ // consumed. It should be chosen with probability k/l. The expression
+ // below is a random long chosen uniformly from [0,l)
+ val replacementIndex = (rand.nextDouble() * l).toLong
if (replacementIndex < k) {
- reservoir(replacementIndex) = item
+ reservoir(replacementIndex.toInt) = item
}
- i += 1
}
- (reservoir, i)
+ (reservoir, l)
}
}
/**
- * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
- * the time.
+ * Returns a sampling rate that guarantees a sample of size greater than or equal to
+ * sampleSizeLowerBound 99.99% of the time.
*
* How the sampling rate is determined:
+ *
* Let p = num / total, where num is the sample size and total is the total number of
- * datapoints in the RDD. We're trying to compute q > p such that
+ * datapoints in the RDD. We're trying to compute q {@literal >} p such that
* - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
- * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total),
- * i.e. the failure rate of not having a sufficiently large sample < 0.0001.
+ * where we want to guarantee
+ * Pr[s {@literal <} num] {@literal <} 0.0001 for s = sum(prob_i for i from 0 to total),
+ * i.e. the failure rate of not having a sufficiently large sample {@literal <} 0.0001.
* Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
- * num > 12, but we need a slightly larger q (9 empirically determined).
+ * num {@literal >} 12, but we need a slightly larger q (9 empirically determined).
* - when sampling without replacement, we're drawing each datapoint with prob_i
* ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
* rate, where success rate is defined the same as in sampling with replacement.
@@ -107,14 +113,14 @@ private[spark] object SamplingUtils {
private[spark] object PoissonBounds {
/**
- * Returns a lambda such that Pr[X > s] is very small, where X ~ Pois(lambda).
+ * Returns a lambda such that Pr[X {@literal >} s] is very small, where X ~ Pois(lambda).
*/
def getLowerBound(s: Double): Double = {
math.max(s - numStd(s) * math.sqrt(s), 1e-15)
}
/**
- * Returns a lambda such that Pr[X < s] is very small, where X ~ Pois(lambda).
+ * Returns a lambda such that Pr[X {@literal <} s] is very small, where X ~ Pois(lambda).
*
* @param s sample size
*/
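
The change above widens the element counter from Int to Long; the probability-k/l replacement step now draws a uniform index in [0, l) via rand.nextDouble(). A hedged sketch of the whole reservoir loop under that scheme (not the patch's exact code):

import scala.reflect.ClassTag
import scala.util.Random

// Keep a uniform sample of k items from an arbitrarily long iterator and
// return the total number of items seen (which may exceed Int.MaxValue).
def reservoirSample[T: ClassTag](input: Iterator[T], k: Int, rng: Random): (Array[T], Long) = {
  val reservoir = new Array[T](k)
  var l = 0L
  while (input.hasNext) {
    val item = input.next()
    l += 1
    if (l <= k) {
      reservoir((l - 1).toInt) = item                       // fill the reservoir first
    } else {
      val replacementIndex = (rng.nextDouble() * l).toLong  // uniform in [0, l)
      if (replacementIndex < k) {
        reservoir(replacementIndex.toInt) = item            // replace with probability k / l
      }
    }
  }
  (reservoir, l)  // if l < k the caller should trim the unfilled tail
}
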
diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
index effe6fa2adcf..ce46fc8f201b 100644
--- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
@@ -24,7 +24,7 @@ import scala.reflect.ClassTag
import org.apache.commons.math3.distribution.PoissonDistribution
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
/**
@@ -35,13 +35,14 @@ import org.apache.spark.rdd.RDD
* high probability. This is achieved by maintaining a waitlist of size O(log(s)), where s is the
* desired sample size for each stratum.
*
- * Like in simple random sampling, we generate a random value for each item from the
- * uniform distribution [0.0, 1.0]. All items with values <= min(values of items in the waitlist)
- * are accepted into the sample instantly. The threshold for instant accept is designed so that
- * s - numAccepted = O(sqrt(s)), where s is again the desired sample size. Thus, by maintaining a
- * waitlist size = O(sqrt(s)), we will be able to create a sample of the exact size s by adding
- * a portion of the waitlist to the set of items that are instantly accepted. The exact threshold
- * is computed by sorting the values in the waitlist and picking the value at (s - numAccepted).
+ * Like in simple random sampling, we generate a random value for each item from the uniform
+ * distribution [0.0, 1.0]. All items with values less than or equal to min(values of items in the
+ * waitlist) are accepted into the sample instantly. The threshold for instant accept is designed
+ * so that s - numAccepted = O(sqrt(s)), where s is again the desired sample size. Thus, by
+ * maintaining a waitlist size = O(sqrt(s)), we will be able to create a sample of the exact size
+ * s by adding a portion of the waitlist to the set of items that are instantly accepted. The exact
+ * threshold is computed by sorting the values in the waitlist and picking the value at
+ * (s - numAccepted).
*
* Note that since we use the same seed for the RNG when computing the thresholds and the actual
* sample, our computed thresholds are guaranteed to produce the desired sample size.
@@ -160,12 +161,20 @@ private[spark] object StratifiedSamplingUtils extends Logging {
*
* To do so, we compute sampleSize = math.ceil(size * samplingRate) for each stratum and compare
* it to the number of items that were accepted instantly and the number of items in the waitlist
- * for that stratum. Most of the time, numAccepted <= sampleSize <= (numAccepted + numWaitlisted),
+ * for that stratum.
+ *
+ * Most of the time,
+ * {{{
+ * numAccepted <= sampleSize <= (numAccepted + numWaitlisted)
+ * }}}
* which means we need to sort the elements in the waitlist by their associated values in order
- * to find the value T s.t. |{elements in the stratum whose associated values <= T}| = sampleSize.
- * Note that all elements in the waitlist have values >= bound for instant accept, so a T value
- * in the waitlist range would allow all elements that were instantly accepted on the first pass
- * to be included in the sample.
+ * to find the value T s.t.
+ * {{{
+ * |{elements in the stratum whose associated values <= T}| = sampleSize
+ * }}}.
+ * Note that all elements in the waitlist have values greater than or equal to bound for instant
+ * accept, so a T value in the waitlist range would allow all elements that were instantly
+ * accepted on the first pass to be included in the sample.
*/
def computeThresholdByKey[K](finalResult: Map[K, AcceptanceResult],
fractions: Map[K, Double]): Map[K, Double] = {
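
As a concrete reading of the reworked comment above: once a stratum's instantly-accepted count and its waitlist are known, the exact-size threshold is the (sampleSize - numAccepted)-th smallest waitlist value. A simplified sketch under that assumption (names are illustrative; the real implementation handles additional edge cases):

// acceptBound: the instant-accept threshold used on the first pass.
// waitlist: associated values of the waitlisted items (all > acceptBound).
def exactThreshold(acceptBound: Double, waitlist: Seq[Double], sampleSize: Int, numAccepted: Int): Double = {
  if (numAccepted >= sampleSize) {
    acceptBound                        // instant accepts alone already reach the target size
  } else {
    val needed = sampleSize - numAccepted
    waitlist.sorted.apply(needed - 1)  // T such that |{values <= T}| == sampleSize
  }
}
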
diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index 85fb923cd9bc..e8cdb6e98bf3 100644
--- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -60,9 +60,11 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
private[spark] object XORShiftRandom {
/** Hash seeds to have 0/1 bits throughout. */
- private def hashSeed(seed: Long): Long = {
+ private[random] def hashSeed(seed: Long): Long = {
val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array()
- MurmurHash3.bytesHash(bytes)
+ val lowBits = MurmurHash3.bytesHash(bytes)
+ val highBits = MurmurHash3.bytesHash(bytes, lowBits)
+ (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL)
}
/**
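
For clarity, the widened hashSeed above packs two 32-bit MurmurHash3 results (the second seeded with the first) into a single 64-bit value, so XORShiftRandom's internal state can differ in all 64 bits rather than only the low 32. An equivalent standalone sketch:

import java.nio.ByteBuffer
import scala.util.hashing.MurmurHash3

def hashSeed64(seed: Long): Long = {
  val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array()
  val lowBits = MurmurHash3.bytesHash(bytes)            // first 32-bit hash
  val highBits = MurmurHash3.bytesHash(bytes, lowBits)  // re-hash, seeded with lowBits
  (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL)
}
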
diff --git a/core/src/main/scala/org/apache/spark/util/random/package-info.java b/core/src/main/scala/org/apache/spark/util/random/package-info.java
index 62c3762dd11b..e4f0c0febbbb 100644
--- a/core/src/main/scala/org/apache/spark/util/random/package-info.java
+++ b/core/src/main/scala/org/apache/spark/util/random/package-info.java
@@ -18,4 +18,4 @@
/**
* Utilities for random number generation.
*/
-package org.apache.spark.util.random;
\ No newline at end of file
+package org.apache.spark.util.random;
diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
new file mode 100644
index 000000000000..51feccfb8342
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.EventListener
+
+import org.apache.spark.TaskContext
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Listener providing a callback function to invoke when a task's execution completes.
+ */
+@DeveloperApi
+trait TaskCompletionListener extends EventListener {
+ def onTaskCompletion(context: TaskContext): Unit
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Listener providing a callback function to invoke when a task's execution encounters an error.
+ * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+ */
+@DeveloperApi
+trait TaskFailureListener extends EventListener {
+ def onTaskFailure(context: TaskContext, error: Throwable): Unit
+}
+
+
+/**
+ * Exception thrown when there is an exception in executing the callback in TaskCompletionListener.
+ */
+private[spark]
+class TaskCompletionListenerException(
+ errorMessages: Seq[String],
+ val previousError: Option[Throwable] = None)
+ extends RuntimeException {
+
+ override def getMessage: String = {
+ val listenerErrorMessage =
+ if (errorMessages.size == 1) {
+ errorMessages.head
+ } else {
+ errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
+ }
+ val previousErrorMessage = previousError.map { e =>
+ "\n\nPrevious exception in task: " + e.getMessage + "\n" +
+ e.getStackTrace.mkString("\t", "\n\t", "")
+ }.getOrElse("")
+ listenerErrorMessage + previousErrorMessage
+ }
+}
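
A short usage sketch for the two listener traits introduced above, assuming the usual TaskContext registration hooks (addTaskCompletionListener / addTaskFailureListener); the resource and logging details are illustrative only:

import org.apache.spark.TaskContext
import org.apache.spark.util.{TaskCompletionListener, TaskFailureListener}

// Close a per-task resource when the task finishes, whatever the outcome.
class CloseOnCompletion(resource: java.io.Closeable) extends TaskCompletionListener {
  override def onTaskCompletion(context: TaskContext): Unit = resource.close()
}

// Failure callbacks must be idempotent: onTaskFailure may fire more than once.
class LogOnFailure extends TaskFailureListener {
  override def onTaskFailure(context: TaskContext, error: Throwable): Unit =
    System.err.println(s"Task ${context.taskAttemptId()} failed: ${error.getMessage}")
}

// Inside a task body (illustrative):
//   val ctx = TaskContext.get()
//   ctx.addTaskCompletionListener(new CloseOnCompletion(resource))
//   ctx.addTaskFailureListener(new LogOnFailure)
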
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
deleted file mode 100644
index fd8f7f39b7cc..000000000000
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ /dev/null
@@ -1,1812 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark;
-
-import java.io.*;
-import java.nio.channels.FileChannel;
-import java.nio.ByteBuffer;
-import java.net.URI;
-import java.util.*;
-import java.util.concurrent.*;
-
-import scala.Tuple2;
-import scala.Tuple3;
-import scala.Tuple4;
-import scala.collection.JavaConverters;
-
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Iterables;
-import com.google.common.collect.Iterators;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
-import com.google.common.base.Throwables;
-import com.google.common.base.Optional;
-import com.google.common.base.Charsets;
-import com.google.common.io.Files;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.compress.DefaultCodec;
-import org.apache.hadoop.mapred.SequenceFileInputFormat;
-import org.apache.hadoop.mapred.SequenceFileOutputFormat;
-import org.apache.hadoop.mapreduce.Job;
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-import org.apache.spark.api.java.*;
-import org.apache.spark.api.java.function.*;
-import org.apache.spark.input.PortableDataStream;
-import org.apache.spark.partial.BoundedDouble;
-import org.apache.spark.partial.PartialResult;
-import org.apache.spark.rdd.RDD;
-import org.apache.spark.serializer.KryoSerializer;
-import org.apache.spark.storage.StorageLevel;
-import org.apache.spark.util.StatCounter;
-
-// The test suite itself is Serializable so that anonymous Function implementations can be
-// serialized, as an alternative to converting these anonymous classes to static inner classes;
-// see http://stackoverflow.com/questions/758570/.
-public class JavaAPISuite implements Serializable {
- private transient JavaSparkContext sc;
- private transient File tempDir;
-
- @Before
- public void setUp() {
- sc = new JavaSparkContext("local", "JavaAPISuite");
- tempDir = Files.createTempDir();
- tempDir.deleteOnExit();
- }
-
- @After
- public void tearDown() {
- sc.stop();
- sc = null;
- }
-
- @SuppressWarnings("unchecked")
- @Test
- public void sparkContextUnion() {
- // Union of non-specialized JavaRDDs
- List<String> strings = Arrays.asList("Hello", "World");
- JavaRDD<String> s1 = sc.parallelize(strings);
- JavaRDD<String> s2 = sc.parallelize(strings);
- // Varargs
- JavaRDD<String> sUnion = sc.union(s1, s2);
- Assert.assertEquals(4, sUnion.count());
- // List
- List<JavaRDD<String>> list = new ArrayList<>();
- list.add(s2);
- sUnion = sc.union(s1, list);
- Assert.assertEquals(4, sUnion.count());
-
- // Union of JavaDoubleRDDs
- List<Double> doubles = Arrays.asList(1.0, 2.0);
- JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles);
- JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles);
- JavaDoubleRDD dUnion = sc.union(d1, d2);
- Assert.assertEquals(4, dUnion.count());
-
- // Union of JavaPairRDDs
- List<Tuple2<Integer, Integer>> pairs = new ArrayList<>();
- pairs.add(new Tuple2<>(1, 2));
- pairs.add(new Tuple2<>(3, 4));
- JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs);
- JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs);
- JavaPairRDD<Integer, Integer> pUnion = sc.union(p1, p2);
- Assert.assertEquals(4, pUnion.count());
- }
-
- @SuppressWarnings("unchecked")
- @Test
- public void intersection() {
- List<Integer> ints1 = Arrays.asList(1, 10, 2, 3, 4, 5);
- List<Integer> ints2 = Arrays.asList(1, 6, 2, 3, 7, 8);
- JavaRDD<Integer> s1 = sc.parallelize(ints1);
- JavaRDD<Integer> s2 = sc.parallelize(ints2);
-
- JavaRDD<Integer> intersections = s1.intersection(s2);
- Assert.assertEquals(3, intersections.count());
-
- JavaRDD<Integer> empty = sc.emptyRDD();
- JavaRDD<Integer> emptyIntersection = empty.intersection(s2);
- Assert.assertEquals(0, emptyIntersection.count());
-
- List<Double> doubles = Arrays.asList(1.0, 2.0);
- JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles);
- JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles);
- JavaDoubleRDD dIntersection = d1.intersection(d2);
- Assert.assertEquals(2, dIntersection.count());
-
- List<Tuple2<Integer, Integer>> pairs = new ArrayList<>();
- pairs.add(new Tuple2<>(1, 2));
- pairs.add(new Tuple2<>(3, 4));
- JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs);
- JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs);
- JavaPairRDD<Integer, Integer> pIntersection = p1.intersection(p2);
- Assert.assertEquals(2, pIntersection.count());
- }
-
- @Test
- public void sample() {
- List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
- JavaRDD<Integer> rdd = sc.parallelize(ints);
- JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 3);
- Assert.assertEquals(2, sample20.count());
- JavaRDD<Integer> sample20WithoutReplacement = rdd.sample(false, 0.2, 5);
- Assert.assertEquals(2, sample20WithoutReplacement.count());
- }
-
- @Test
- public void randomSplit() {
- List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
- JavaRDD<Integer> rdd = sc.parallelize(ints);
- JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31);
- Assert.assertEquals(3, splits.length);
- Assert.assertEquals(1, splits[0].count());
- Assert.assertEquals(2, splits[1].count());
- Assert.assertEquals(7, splits[2].count());
- }
-
- @Test
- public void sortByKey() {
- List<Tuple2<Integer, Integer>> pairs = new ArrayList<>();
- pairs.add(new Tuple2<>(0, 4));
- pairs.add(new Tuple2<>(3, 2));
- pairs.add(new Tuple2<>(-1, 1));
-
- JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
-
- // Default comparator
- JavaPairRDD<Integer, Integer> sortedRDD = rdd.sortByKey();
- Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first());
- List<Tuple2<Integer, Integer>> sortedPairs = sortedRDD.collect();
- Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1));
- Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2));
-
- // Custom comparator
- sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false);
- Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first());
- sortedPairs = sortedRDD.collect();
- Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1));
- Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2));
- }
-
- @SuppressWarnings("unchecked")
- @Test
- public void repartitionAndSortWithinPartitions() {
- List<Tuple2<Integer, Integer>> pairs = new ArrayList<>();
- pairs.add(new Tuple2<>(0, 5));
- pairs.add(new Tuple2<>(3, 8));
- pairs.add(new Tuple2<>(2, 6));
- pairs.add(new Tuple2<>(0, 8));
- pairs.add(new Tuple2<>(3, 8));
- pairs.add(new Tuple2<>(1, 3));
-
- JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
-
- Partitioner partitioner = new Partitioner() {
- @Override
- public int numPartitions() {
- return 2;
- }
- @Override
- public int getPartition(Object key) {
- return (Integer) key % 2;
- }
- };
-
- JavaPairRDD<Integer, Integer> repartitioned =
- rdd.repartitionAndSortWithinPartitions(partitioner);
- Assert.assertTrue(repartitioned.partitioner().isPresent());
- Assert.assertEquals(repartitioned.partitioner().get(), partitioner);
- List<List<Tuple2<Integer, Integer>>> partitions = repartitioned.glom().collect();
- Assert.assertEquals(partitions.get(0),
- Arrays.asList(new Tuple2<>(0, 5), new Tuple2<>(0, 8), new Tuple2<>(2, 6)));
- Assert.assertEquals(partitions.get(1),
- Arrays.asList(new Tuple2<>(1, 3), new Tuple2<>(3, 8), new Tuple2<>(3, 8)));
- }
-
- @Test
- public void emptyRDD() {
- JavaRDD