diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index f13c24ae5e017..1321b83181150 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -46,18 +46,22 @@ public final class Platform {
   private static final boolean unaligned;
   static {
     boolean _unaligned;
-    // use reflection to access unaligned field
-    try {
-      Class bitsClass =
-        Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader());
-      Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned");
-      unalignedMethod.setAccessible(true);
-      _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null));
-    } catch (Throwable t) {
-      // We at least know x86 and x64 support unaligned access.
-      String arch = System.getProperty("os.arch", "");
-      //noinspection DynamicRegexReplaceableByCompiledPattern
-      _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$");
+    String arch = System.getProperty("os.arch", "");
+    if (arch.equals("ppc64le") || arch.equals("ppc64")) {
+      // java.nio.Bits.unaligned() doesn't return true on ppc (see JDK-8165231), even though ppc64 and ppc64le do support unaligned access
+      _unaligned = true;
+    } else {
+      try {
+        Class bitsClass =
+          Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader());
+        Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned");
+        unalignedMethod.setAccessible(true);
+        _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null));
+      } catch (Throwable t) {
+        // We at least know x86 and x64 support unaligned access.
+        //noinspection DynamicRegexReplaceableByCompiledPattern
+        _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$");
+      }
     }
     unaligned = _unaligned;
   }
diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
index fc1f3a80239ba..48cf4b9455e4d 100644
--- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -60,8 +60,6 @@ protected long getUsed() {
 
   /**
    * Force spill during building.
-   *
-   * For testing.
   */
  public void spill() throws IOException {
    spill(Long.MAX_VALUE, this);
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 39fb3b249d731..aa0b373231327 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -155,11 +155,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
         for (MemoryConsumer c: consumers) {
           if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) {
             long key = c.getUsed();
-            List<MemoryConsumer> list = sortedConsumers.get(key);
-            if (list == null) {
-              list = new ArrayList<>(1);
-              sortedConsumers.put(key, list);
-            }
+            List<MemoryConsumer> list = sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1));
             list.add(c);
           }
         }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 4a15559e55cbd..323a5d3c52831 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -52,8 +52,7 @@
 * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
 * writes incoming records to separate files, one file per reduce partition, then concatenates these
 * per-partition files to form a single output file, regions of which are served to reducers.
- * Records are not buffered in memory. This is essentially identical to
- * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format
+ * Records are not buffered in memory. It writes output in a format
 * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}.
 *

* This write path is inefficient for shuffles with large numbers of reduce partitions because it @@ -61,7 +60,7 @@ * {@link SortShuffleManager} only selects this write path when *

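As an aside on the Platform.java hunk above: the change reorders the probe so the architecture string is trusted first for ppc64/ppc64le (working around JDK-8165231), and the java.nio.Bits reflection check, with a known-good architecture regex as last resort, only runs for other platforms. A standalone sketch of that control flow, written in Scala with illustrative names (UnalignedProbe is not Spark's class):

```scala
// Sketch only: mirrors the detection order in the Platform.java hunk above.
object UnalignedProbe {
  def isUnalignedSupported(): Boolean = {
    val arch = System.getProperty("os.arch", "")
    if (arch == "ppc64le" || arch == "ppc64") {
      true // JDK-8165231: java.nio.Bits.unaligned() misreports these platforms
    } else {
      try {
        val bits = Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader())
        val m = bits.getDeclaredMethod("unaligned")
        m.setAccessible(true)
        java.lang.Boolean.TRUE.equals(m.invoke(null))
      } catch {
        case _: Throwable =>
          // Last resort: architectures known to support unaligned access.
          arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$")
      }
    }
  }
}
```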
diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java
index 9307eb93a5b20..b38639e854815 100644
--- a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java
+++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java
@@ -19,6 +19,7 @@
 
 import org.apache.spark.util.EnumUtil;
 
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;
 
@@ -30,9 +31,7 @@ public enum TaskSorting {
   private final Set<String> alternateNames;
   TaskSorting(String... names) {
     alternateNames = new HashSet<>();
-    for (String n: names) {
-      alternateNames.add(n);
-    }
+    Collections.addAll(alternateNames, names);
   }
 
   public static TaskSorting fromString(String str) {
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
index e4b9f8111efca..9112d93a86b2a 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -71,13 +71,12 @@ private[spark] trait ExecutorAllocationClient {
 
   /**
    * Request that the cluster manager kill every executor on the specified host.
-   * Results in a call to killExecutors for each executor on the host, with the replace
-   * and force arguments set to true.
+   *
    * @return whether the request is acknowledged by the cluster manager.
   */
   def killExecutorsOnHost(host: String): Boolean
 
-   /**
+  /**
   * Request that the cluster manager kill the specified executor.
   * @return whether the request is acknowledged by the cluster manager.
   */
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 0614d80b60e1c..0144fd1056bac 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -190,6 +190,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
       .orNull
     numExecutors = Option(numExecutors)
       .getOrElse(sparkProperties.get("spark.executor.instances").orNull)
+    queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull
     keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull
     principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull
 
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
index 2394cf361c33a..7efa9416362a0 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -121,6 +121,13 @@ abstract class FileCommitProtocol {
   def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = {
     fs.delete(path, recursive)
   }
+
+  /**
+   * Called on the driver after a task commits. This can be used to access task commit messages
+   * before the job has finished. These same task commit messages will be passed to commitJob()
+   * if the entire job succeeds.
+ */ + def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 2e991ce394c42..c216fe477fd15 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -71,8 +71,7 @@ private[spark] object CompressionCodec { val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { - case e: ClassNotFoundException => None - case e: IllegalArgumentException => None + case _: ClassNotFoundException | _: IllegalArgumentException => None } codec.getOrElse(throw new IllegalArgumentException(s"Codec [$codecName] is not available. " + s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC")) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 46ef23f316a61..7fd2918960cd0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -149,7 +149,7 @@ private[spark] abstract class Task[T]( def preferredLocations: Seq[TaskLocation] = Nil - // Map output tracker epoch. Will be set by TaskScheduler. + // Map output tracker epoch. Will be set by TaskSetManager. var epoch: Long = -1 // Task context, to be initialized in run(). diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 008b0387899f6..01bbda0b5e6b3 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -77,7 +77,7 @@ abstract class Serializer { * position = 0 * serOut.write(obj1) * serOut.flush() - * position = # of bytes writen to stream so far + * position = # of bytes written to stream so far * obj1Bytes = output[0:position-1] * serOut.write(obj2) * serOut.flush() diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 8b2e26cdd94fb..ba3e0e395e958 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -95,8 +95,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Sort the output if there is a sort ordering defined. dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => - // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, - // the ExternalSorter won't spill to disk. + // Create an ExternalSorter to sort the data. val sorter = new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) sorter.insertAll(aggregatedIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 91858f0912b65..15540485170d0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -61,7 +61,7 @@ private[spark] class IndexShuffleBlockResolver( /** * Remove data file and index file that contain the output data from one map. 
- * */ + */ def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { var file = getDataFile(shuffleId, mapId) if (file.exists()) { @@ -132,7 +132,7 @@ private[spark] class IndexShuffleBlockResolver( * replace them with new ones. * * Note: the `lengths` will be updated to match the existing index file if use the existing ones. - * */ + */ def writeIndexFileAndCommit( shuffleId: Int, mapId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 5e977a16febe1..bfb4dc698e325 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -82,13 +82,13 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + * Obtains a [[ShuffleHandle]] to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) { + if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't // need map-side aggregation, then write numPartitions files directly and just concatenate // them at the end. This avoids doing serialization and deserialization twice to merge diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fcda9fa65303a..46a078b2f9f93 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -49,7 +49,6 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer - /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( val data: Iterator[Any], @@ -1258,7 +1257,6 @@ private[spark] class BlockManager( replication = 1) val numPeersToReplicateTo = level.replication - 1 - val startTime = System.nanoTime var peersReplicatedTo = mutable.HashSet.empty ++ existingReplicas @@ -1313,7 +1311,6 @@ private[spark] class BlockManager( numPeersToReplicateTo - peersReplicatedTo.size) } } - logDebug(s"Replicating $blockId of ${data.size} bytes to " + s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms") if (peersReplicatedTo.size < numPeersToReplicateTo) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index bb8a684b4c7a8..353eac60df171 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -53,6 +53,46 @@ trait BlockReplicationPolicy { numReplicas: Int): List[BlockManagerId] } +object BlockReplicationUtils { + // scalastyle:off line.size.limit + /** + * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while + * minimizing space usage. Please see + * here. 
+   *
+   * @param n total number of indices
+   * @param m number of samples needed
+   * @param r random number generator
+   * @return list of m random unique indices
+   */
+  // scalastyle:on line.size.limit
+  private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = {
+    val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) { case (set, i) =>
+      val t = r.nextInt(i) + 1
+      if (set.contains(t)) set + i else set + t
+    }
+    indices.map(_ - 1).toList
+  }
+
+  /**
+   * Get a random sample of size m from `elems`.
+   *
+   * @param elems the collection to sample from
+   * @param m number of samples needed
+   * @param r random number generator
+   * @tparam T element type
+   * @return a random list of size m. If there are fewer than m elements in elems, we just
+   *         randomly shuffle elems
+   */
+  def getRandomSample[T](elems: Seq[T], m: Int, r: Random): List[T] = {
+    if (elems.size > m) {
+      getSampleIds(elems.size, m, r).map(elems(_))
+    } else {
+      r.shuffle(elems).toList
+    }
+  }
+}
+
 @DeveloperApi
 class RandomBlockReplicationPolicy
   extends BlockReplicationPolicy
@@ -67,6 +107,7 @@ class RandomBlockReplicationPolicy
    * @param peersReplicatedTo Set of peers already replicated to
    * @param blockId BlockId of the block being replicated. This can be used as a source of
    *                randomness if needed.
+   * @param numReplicas Number of peers we need to replicate to
    * @return A prioritized list of peers. Lower the index of a peer, higher its priority
    */
   override def prioritize(
@@ -78,7 +119,7 @@ class RandomBlockReplicationPolicy
     val random = new Random(blockId.hashCode)
     logDebug(s"Input peers : ${peers.mkString(", ")}")
     val prioritizedPeers = if (peers.size > numReplicas) {
-      getSampleIds(peers.size, numReplicas, random).map(peers(_))
+      BlockReplicationUtils.getRandomSample(peers, numReplicas, random)
     } else {
       if (peers.size < numReplicas) {
         logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.")
@@ -88,26 +129,96 @@ class RandomBlockReplicationPolicy
     logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}")
     prioritizedPeers
   }
+}
+
+@DeveloperApi
+class BasicBlockReplicationPolicy
+  extends BlockReplicationPolicy
+  with Logging {
 
-  // scalastyle:off line.size.limit
   /**
-   * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while
-   * minimizing space usage. Please see
-   * here.
+   * Method to prioritize a bunch of candidate peers of a block manager. This implementation
+   * replicates the behavior of block replication in HDFS. For a given number of replicas needed,
+   * we choose a peer within the rack, then one outside the rack, and pick the remaining block
+   * managers at random, in that order, until we meet the number of replicas needed.
+   * This works best with a total replication factor of 3, like HDFS.
   *
-   * @param n total number of indices
-   * @param m number of samples needed
-   * @param r random number generator
-   * @return list of m random unique indices
+   * @param blockManagerId Id of the current BlockManager for self identification
+   * @param peers A list of peers of a BlockManager
+   * @param peersReplicatedTo Set of peers already replicated to
+   * @param blockId BlockId of the block being replicated. This can be used as a source of
+   *                randomness if needed.
+   * @param numReplicas Number of peers we need to replicate to
+   * @return A prioritized list of peers. Lower the index of a peer, higher its priority
   */
-  // scalastyle:on line.size.limit
-  private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = {
-    val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) =>
-      val t = r.nextInt(i) + 1
-      if (set.contains(t)) set + i else set + t
+  override def prioritize(
+      blockManagerId: BlockManagerId,
+      peers: Seq[BlockManagerId],
+      peersReplicatedTo: mutable.HashSet[BlockManagerId],
+      blockId: BlockId,
+      numReplicas: Int): List[BlockManagerId] = {
+
+    logDebug(s"Input peers : $peers")
+    logDebug(s"BlockManagerId : $blockManagerId")
+
+    val random = new Random(blockId.hashCode)
+
+    // If the block doesn't have topology info, we can't do much, so we choose peers at random.
+    // If it does, we see what is still needed from peersReplicatedTo and, based on numReplicas,
+    // choose what's needed.
+    if (blockManagerId.topologyInfo.isEmpty || numReplicas == 0) {
+      // no topology info for the block. The best we can do is randomly choose peers
+      BlockReplicationUtils.getRandomSample(peers, numReplicas, random)
+    } else {
+      // we have topology information; see what is left to be done from peersReplicatedTo
+      val doneWithinRack = peersReplicatedTo.exists(_.topologyInfo == blockManagerId.topologyInfo)
+      val doneOutsideRack = peersReplicatedTo.exists { p =>
+        p.topologyInfo.isDefined && p.topologyInfo != blockManagerId.topologyInfo
+      }
+
+      if (doneOutsideRack && doneWithinRack) {
+        // we are done, so just return a random sample
+        BlockReplicationUtils.getRandomSample(peers, numReplicas, random)
+      } else {
+        // separate peers within and outside the rack
+        val (inRackPeers, outOfRackPeers) = peers
+          .filter(_.host != blockManagerId.host)
+          .partition(_.topologyInfo == blockManagerId.topologyInfo)
+
+        val peerWithinRack = if (doneWithinRack) {
+          // we are done with in-rack replication, so we don't need any more peers
+          Seq.empty
+        } else {
+          if (inRackPeers.isEmpty) {
+            Seq.empty
+          } else {
+            Seq(inRackPeers(random.nextInt(inRackPeers.size)))
+          }
+        }
+
+        val peerOutsideRack = if (doneOutsideRack || numReplicas - peerWithinRack.size <= 0) {
+          Seq.empty
+        } else {
+          if (outOfRackPeers.isEmpty) {
+            Seq.empty
+          } else {
+            Seq(outOfRackPeers(random.nextInt(outOfRackPeers.size)))
+          }
+        }
+
+        val priorityPeers = peerWithinRack ++ peerOutsideRack
+        val numRemainingPeers = numReplicas - priorityPeers.size
+        val remainingPeers = if (numRemainingPeers > 0) {
+          val rPeers = peers.filter(p => !priorityPeers.contains(p))
+          BlockReplicationUtils.getRandomSample(rPeers, numRemainingPeers, random)
+        } else {
+          Seq.empty
+        }
+
+        (priorityPeers ++ remainingPeers).toList
+      }
+    }
-    // we shuffle the result to ensure a random arrangement within the sample
-    // to avoid any bias from set implementations
-    r.shuffle(indices.map(_ - 1).toList)
   }
+
 }
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index a9480cc220c8d..8b75f5d8fe1a8 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -124,7 +124,7 @@ private[spark] abstract class WebUI(
   /** Bind to the HTTP server behind this web interface.
*/ def bind(): Unit = { - assert(!serverInfo.isDefined, s"Attempted to bind $className more than once!") + assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") serverInfo = Some(startJettyServer(host, port, sslOptions, handlers, conf, name)) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index c6a07445f2a35..dbcc6402bc309 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -49,7 +49,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage }.map { thread => val threadId = thread.threadId val blockedBy = thread.blockedByThreadId match { - case Some(blockedByThreadId) => + case Some(_) =>
Blocked by Thread {thread.blockedByThreadId} {thread.blockedByLock} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 2d1691e55c428..d849ce76a9e3c 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -48,7 +48,6 @@ private[ui] class ExecutorsPage( parent: ExecutorsTab, threadDumpEnabled: Boolean) extends WebUIPage("") { - private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val content = @@ -59,7 +58,7 @@ private[ui] class ExecutorsPage( ++ } -
; + UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 8ae712f8ed323..03851293eb2f1 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -64,7 +64,7 @@ private[ui] case class ExecutorTaskSummary( @DeveloperApi class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { - var executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() + val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() var executorEvents = new ListBuffer[SparkListenerEvent]() private val maxTimelineExecutors = conf.getInt("spark.ui.timeline.executors.maximum", 1000) @@ -137,7 +137,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar // could have failed half-way through. The correct fix would be to keep track of the // metrics added by each attempt, but this is much more complicated. return - case e: ExceptionFailure => + case _: ExceptionFailure => taskSummary.tasksFailed += 1 case _ => taskSummary.tasksComplete += 1 diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index fe6ca1099e6b0..2b0816e35747d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -34,9 +34,9 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { listener.synchronized { val activeStages = listener.activeStages.values.toSeq val pendingStages = listener.pendingStages.values.toSeq - val completedStages = listener.completedStages.reverse.toSeq + val completedStages = listener.completedStages.reverse val numCompletedStages = listener.numCompletedStages - val failedStages = listener.failedStages.reverse.toSeq + val failedStages = listener.failedStages.reverse val numFailedStages = listener.numFailedStages val subPath = "stages" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 52f41298a1729..382a6f979f2e6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -133,9 +133,9 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.map(_._2).sum} + {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.values.sum} {v.failedTasks} - {v.reasonToNumKilled.map(_._2).sum} + {v.reasonToNumKilled.values.sum} {v.succeededTasks} {if (stageData.hasInput) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 1cf03e1541d14..f78db5ab80d15 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -226,7 +226,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { trimJobsIfNecessary(completedJobs) jobData.status = JobExecutionStatus.SUCCEEDED numCompletedJobs += 1 - 
case JobFailed(exception) => + case JobFailed(_) => failedJobs += jobData trimJobsIfNecessary(failedJobs) jobData.status = JobExecutionStatus.FAILED @@ -284,7 +284,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { ) { jobData.numActiveStages -= 1 if (stage.failureReason.isEmpty) { - if (!stage.submissionTime.isEmpty) { + if (stage.submissionTime.isDefined) { jobData.completedStageIndices.add(stage.stageId) } } else { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ff17775008acc..19325a2dc9169 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -142,7 +142,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } - val hasAccumulators = externalAccumulables.size > 0 + val hasAccumulators = externalAccumulables.nonEmpty val summary =
@@ -339,7 +339,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.metrics.isDefined) val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { + if (validTasks.isEmpty) { None } else { @@ -786,8 +786,8 @@ private[ui] object StagePage { info: TaskInfo, metrics: TaskMetricsUIData, currentTime: Long): Long = { if (info.finished) { val totalExecutionTime = info.finishTime - info.launchTime - val executorOverhead = (metrics.executorDeserializeTime + - metrics.resultSerializationTime) + val executorOverhead = metrics.executorDeserializeTime + + metrics.resultSerializationTime math.max( 0, totalExecutionTime - metrics.executorRunTime - executorOverhead - @@ -872,7 +872,7 @@ private[ui] class TaskDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) - private var _slicedTaskIds: Set[Long] = null + private var _slicedTaskIds: Set[Long] = _ override def dataSize: Int = data.size diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index f4caad0f58715..256b726fa7eea 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -412,7 +412,7 @@ private[ui] class StageDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc)) - private var _slicedStageIds: Set[Int] = null + private var _slicedStageIds: Set[Int] = _ override def dataSize: Int = data.size diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 76d7c6d414bcf..aa84788f1df88 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -151,7 +151,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { /** Render a stream block */ private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { val replications = block._2 - assert(replications.size > 0) // This must be true because it's the result of "groupBy" + assert(replications.nonEmpty) // This must be true because it's the result of "groupBy" if (replications.size == 1) { streamBlockTableSubrow(block._1, replications.head, replications.size, true) } else { diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 00e0cf257cd4a..7479de55140ea 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -279,7 +279,7 @@ private[spark] object AccumulatorContext { /** - * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers. + * An [[AccumulatorV2 accumulator]] for computing sum, count, and average of 64-bit integers. 
* * @since 2.0.0 */ diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a591b98bca488..7c2ec01a03d04 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -148,6 +148,17 @@ class SparkSubmitSuite appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) } + test("print the right queue name") { + val clArgs = Seq( + "--name", "myApp", + "--class", "Foo", + "--conf", "spark.yarn.queue=thequeue", + "userjar.jar") + val appArgs = new SparkSubmitArguments(clArgs) + appArgs.queue should be ("thequeue") + appArgs.toString should include ("thequeue") + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index d5715f8469f71..13020acdd3dbe 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.internal.Logging import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -36,6 +37,7 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ +import org.apache.spark.util.Utils trait BlockManagerReplicationBehavior extends SparkFunSuite with Matchers @@ -43,6 +45,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite with LocalSparkContext { val conf: SparkConf + protected var rpcEnv: RpcEnv = null protected var master: BlockManagerMaster = null protected lazy val securityMgr = new SecurityManager(conf) @@ -55,7 +58,6 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite protected val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - protected lazy val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. 
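As background for the replication tests below: BlockReplicationUtils.getSampleIds, moved into the utility object earlier in this diff, implements Robert Floyd's sampling algorithm. Iterating i from n-m+1 to n and keeping either the random draw t or i itself adds exactly one new element per step, so the result always has m distinct values, and the standard argument shows every m-subset is equally likely. A self-contained, 0-indexed restatement (the utility itself works 1-based and shifts at the end; names here are illustrative):

```scala
import scala.collection.mutable
import scala.util.Random

// Floyd's sampling: m distinct values from 0 until n, using O(m) extra space.
def floydSample(n: Int, m: Int, r: Random): List[Int] = {
  val picked = mutable.LinkedHashSet.empty[Int]
  for (i <- n - m until n) {
    val t = r.nextInt(i + 1)                      // uniform draw from [0, i]
    picked += (if (picked.contains(t)) i else t)  // collision => take i, which is always new
  }
  picked.toList
}
```

Note that the diff also switches the accumulator from Set to LinkedHashSet: the old code had to shuffle the result to undo the unspecified ordering of a plain Set, while LinkedHashSet preserves insertion order, so the shuffle is dropped.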
@@ -471,7 +473,7 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav conf.set("spark.storage.replication.proactive", "true") conf.set("spark.storage.exceptionOnPinLeak", "true") - (2 to 5).foreach{ i => + (2 to 5).foreach { i => test(s"proactive block replication - $i replicas - ${i - 1} block manager deletions") { testProactiveReplication(i) } @@ -524,3 +526,30 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav } } } + +class DummyTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging { + // number of racks to test with + val numRacks = 3 + + /** + * Gets the topology information given the host name + * + * @param hostname Hostname + * @return random topology + */ + override def getTopologyForHost(hostname: String): Option[String] = { + Some(s"/Rack-${Utils.random.nextInt(numRacks)}") + } +} + +class BlockManagerBasicStrategyReplicationSuite extends BlockManagerReplicationBehavior { + val conf: SparkConf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set( + "spark.storage.replication.policy", + classOf[BasicBlockReplicationPolicy].getName) + conf.set( + "spark.storage.replication.topologyMapper", + classOf[DummyTopologyMapper].getName) +} + diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index 800c3899f1a72..ecad0f5352e59 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -18,34 +18,34 @@ package org.apache.spark.storage import scala.collection.mutable +import scala.util.Random import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{LocalSparkContext, SparkFunSuite} -class BlockReplicationPolicySuite extends SparkFunSuite +class RandomBlockReplicationPolicyBehavior extends SparkFunSuite with Matchers with BeforeAndAfter with LocalSparkContext { // Implicitly convert strings to BlockIds for test clarity. 
- private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + protected implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + val replicationPolicy: BlockReplicationPolicy = new RandomBlockReplicationPolicy + + val blockId = "test-block" /** * Test if we get the required number of peers when using random sampling from - * RandomBlockReplicationPolicy + * BlockReplicationPolicy */ - test(s"block replication - random block replication policy") { + test("block replication - random block replication policy") { val numBlockManagers = 10 val storeSize = 1000 - val blockManagers = (1 to numBlockManagers).map { i => - BlockManagerId(s"store-$i", "localhost", 1000 + i, None) - } + val blockManagers = generateBlockManagerIds(numBlockManagers, Seq("/Rack-1")) val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None) - val replicationPolicy = new RandomBlockReplicationPolicy - val blockId = "test-block" - (1 to 10).foreach {numReplicas => + (1 to 10).foreach { numReplicas => logDebug(s"Num replicas : $numReplicas") val randomPeers = replicationPolicy.prioritize( candidateBlockManager, @@ -68,7 +68,60 @@ class BlockReplicationPolicySuite extends SparkFunSuite logDebug(s"Random peers : ${secondPass.mkString(", ")}") assert(secondPass.toSet.size === numReplicas) } + } + + protected def generateBlockManagerIds(count: Int, racks: Seq[String]): Seq[BlockManagerId] = { + (1 to count).map{i => + BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(racks(Random.nextInt(racks.size)))) + } + } +} + +class TopologyAwareBlockReplicationPolicyBehavior extends RandomBlockReplicationPolicyBehavior { + override val replicationPolicy = new BasicBlockReplicationPolicy + + test("All peers in the same rack") { + val racks = Seq("/default-rack") + val numBlockManager = 10 + (1 to 10).foreach {numReplicas => + val peers = generateBlockManagerIds(numBlockManager, racks) + val blockManager = BlockManagerId("Driver", "Host-driver", 10001, Some(racks.head)) + + val prioritizedPeers = replicationPolicy.prioritize( + blockManager, + peers, + mutable.HashSet.empty, + blockId, + numReplicas + ) + assert(prioritizedPeers.toSet.size == numReplicas) + assert(prioritizedPeers.forall(p => p.host != blockManager.host)) + } } + test("Peers in 2 racks") { + val racks = Seq("/Rack-1", "/Rack-2") + (1 to 10).foreach {numReplicas => + val peers = generateBlockManagerIds(10, racks) + val blockManager = BlockManagerId("Driver", "Host-driver", 9001, Some(racks.head)) + + val prioritizedPeers = replicationPolicy.prioritize( + blockManager, + peers, + mutable.HashSet.empty, + blockId, + numReplicas + ) + + assert(prioritizedPeers.toSet.size == numReplicas) + val priorityPeers = prioritizedPeers.take(2) + assert(priorityPeers.forall(p => p.host != blockManager.host)) + if(numReplicas > 1) { + // both these conditions should be satisfied when numReplicas > 1 + assert(priorityPeers.exists(p => p.topologyInfo == blockManager.topologyInfo)) + assert(priorityPeers.exists(p => p.topologyInfo != blockManager.topologyInfo)) + } + } + } } diff --git a/docs/configuration.md b/docs/configuration.md index 4729f1b0404c1..a9753925407d7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1137,6 +1137,15 @@ Apart from these, the following properties are also available, and may be useful mapping has high overhead for blocks close to or below the page size of the operating system. 
+<tr>
+  <td><code>spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version</code></td>
+  <td>1</td>
+  <td>
+    The file output committer algorithm version. Valid values are 1 and 2.
+    Version 2 may have better performance, but version 1 may handle failures better in certain
+    situations, as per <a href="https://issues.apache.org/jira/browse/MAPREDUCE-4815">MAPREDUCE-4815</a>.
+  </td>
+</tr>
 </table>
 
 ### Networking
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index ff07ad11943bd..b5cf9f1644986 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -717,11 +717,11 @@ However, to run this query for days, it's necessary for the system to bound the
 intermediate in-memory state it accumulates. This means the system needs to know when an old
 aggregate can be dropped from the in-memory state because the application is not going to receive
 late data for that aggregate any more. To enable this, in Spark 2.1, we have introduced
-**watermarking**, which let's the engine automatically track the current event time in the data and
+**watermarking**, which lets the engine automatically track the current event time in the data and
 attempt to clean up old state accordingly. You can define the watermark of a query by
-specifying the event time column and the threshold on how late the data is expected be in terms of
+specifying the event time column and the threshold on how late the data is expected to be in terms of
 event time. For a specific window starting at time `T`, the engine will maintain state and allow late
-data to be update the state until `(max event time seen by the engine - late threshold > T)`.
+data to update the state until `(max event time seen by the engine - late threshold > T)`.
 In other words, late data within the threshold will be aggregated,
 but data later than the threshold will be dropped. Let's understand this with an example. We can
 easily define watermarking on the previous example using `withWatermark()` as shown below.
@@ -792,7 +792,7 @@ This watermark lets the engine maintain intermediate state for additional 10 min
 data to be counted. For example, the data `(12:09, cat)` is out of order and late, and it falls in
 windows `12:05 - 12:15` and `12:10 - 12:20`. Since, it is still ahead of the watermark `12:04` in
 the trigger, the engine still maintains the intermediate counts as state and correctly updates the
-counts of the related windows. However, when the watermark is updated to 12:11, the intermediate
+counts of the related windows. However, when the watermark is updated to `12:11`, the intermediate
 state for window `(12:00 - 12:10)` is cleared, and all subsequent data (e.g. `(12:04, donkey)`)
 is considered "too late" and therefore ignored. Note that after every trigger,
 the updated counts (i.e. purple rows) are written to sink as the trigger output, as dictated by
@@ -825,7 +825,7 @@ section for detailed explanation of the semantics of each output mode.
 same column as the timestamp column used in the aggregate. For example,
 `df.withWatermark("time", "1 min").groupBy("time2").count()` is invalid
 in Append output mode, as watermark is defined on a different column
-as the aggregation column.
+from the aggregation column.
 
 - `withWatermark` must be called before the aggregation for the watermark details to be used.
   For example, `df.groupBy("time").count().withWatermark("time", "1 min")` is invalid in Append
@@ -909,7 +909,7 @@ track of all the data received in the stream. This is therefore fundamentally ha
 efficiently.
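The watermark semantics the hunks above clean up are easiest to see in the guide's running windowed word-count query. A sketch, assuming an active SparkSession named `spark` (the rate source and the cast to `word` here are placeholders for whatever streaming input the application actually reads):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.window

val spark = SparkSession.builder().appName("watermark-sketch").getOrCreate()
import spark.implicits._

// Assumed stand-in for a real source: gives columns (timestamp, value).
val words = spark.readStream.format("rate").load()
  .selectExpr("timestamp", "CAST(value AS STRING) AS word")

// 10-minute windows sliding every 5 minutes, tolerating records up to
// 10 minutes late: state for a window starting at time T is kept until
// (max event time seen by the engine - 10 minutes) > T.
val windowedCounts = words
  .withWatermark("timestamp", "10 minutes")
  .groupBy(window($"timestamp", "10 minutes", "5 minutes"), $"word")
  .count()
```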
## Starting Streaming Queries -Once you have defined the final result DataFrame/Dataset, all that is left is for you start the streaming computation. To do that, you have to use the `DataStreamWriter` +Once you have defined the final result DataFrame/Dataset, all that is left is for you to start the streaming computation. To do that, you have to use the `DataStreamWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface. @@ -1396,15 +1396,15 @@ You can directly get the current status and metrics of an active query using `lastProgress()` returns a `StreamingQueryProgress` object in [Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryProgress) and [Java](api/java/org/apache/spark/sql/streaming/StreamingQueryProgress.html) -and an dictionary with the same fields in Python. It has all the information about +and a dictionary with the same fields in Python. It has all the information about the progress made in the last trigger of the stream - what data was processed, what were the processing rates, latencies, etc. There is also `streamingQuery.recentProgress` which returns an array of last few progresses. -In addition, `streamingQuery.status()` returns `StreamingQueryStatus` object +In addition, `streamingQuery.status()` returns a `StreamingQueryStatus` object in [Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryStatus) and [Java](api/java/org/apache/spark/sql/streaming/StreamingQueryStatus.html) -and an dictionary with the same fields in Python. It gives information about +and a dictionary with the same fields in Python. It gives information about what the query is immediately doing - is a trigger active, is data being processed, etc. Here are a few examples. diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index e647f0e1e9f17..3734568d872d0 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -15,6 +15,9 @@ # limitations under the License. # +# To run this example use +# ./bin/spark-submit examples/src/main/r/RSparkSQLExample.R + library(SparkR) # $example on:init_session$ diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R index 82b85f2f590f6..311350497f873 100644 --- a/examples/src/main/r/dataframe.R +++ b/examples/src/main/r/dataframe.R @@ -15,6 +15,9 @@ # limitations under the License. # +# To run this example use +# ./bin/spark-submit examples/src/main/r/dataframe.R + library(SparkR) # Initialize SparkSession diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index e07c9a4717c3a..0658bddf16961 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.util.Utils /** - * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with + * An example of how to use [[DataFrame]] for ML. 
Run with
 * {{{
 * ./bin/run-example ml.DataFrameExample [options]
 * }}}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
index a7243ccbf28cc..d3c84b77d26ac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
 /**
 * API for correlation functions in MLlib, compatible with Dataframes and Datasets.
 *
- * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]]
+ * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset#stat]]
 * to spark.ml's Vector types.
 */
 @Since("2.2.0")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 70438eb5912b8..920033a9a8480 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 /**
 * Collection of rules related to hints. The only hint currently available is broadcast join hint.
 *
- * Note that this is separatedly into two rules because in the future we might introduce new hint
+ * Note that this is separated into two rules because in the future we might introduce new hint
 * rules that have different ordering requirements from broadcast.
 */
 object ResolveHints {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 93fc565a53419..ec003cdc17b89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -229,9 +229,9 @@ case class ExpressionEncoder[T](
   // serializer expressions are used to encode an object to a row, while the object is usually an
   // intermediate value produced inside an operator, not from the output of the child operator. This
   // is quite different from normal expressions, and `AttributeReference` doesn't work here
-  // (intermediate value is not an attribute). We assume that all serializer expressions use a same
-  // `BoundReference` to refer to the object, and throw exception if they don't.
+  // (intermediate value is not an attribute). We assume that all serializer expressions use the
+  // same `BoundReference` to refer to the object, and throw an exception if they don't.
+ assert(serializer.forall(_.references.isEmpty), "serializer cannot reference any attributes.") assert(serializer.flatMap { ser => val boundRefs = ser.collect { case b: BoundReference => b } assert(boundRefs.nonEmpty, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b93a5d0b7a0e5..1db26d9c415a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -491,7 +491,7 @@ abstract class BinaryExpression extends Expression { * A [[BinaryExpression]] that is an operator, with two properties: * * 1. The string representation is "x symbol y", rather than "funcName(x, y)". - * 2. Two inputs are expected to the be same type. If the two inputs have different types, + * 2. Two inputs are expected to be of the same type. If the two inputs have different types, * the analyzer will find the tightest common type and do the proper type casting. */ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4870093e9250f..f2b252259b89d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -113,7 +113,7 @@ case class Abs(child: Expression) protected override def nullSafeEval(input: Any): Any = numeric.abs(input) } -abstract class BinaryArithmetic extends BinaryOperator { +abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType @@ -146,7 +146,7 @@ object BinaryArithmetic { > SELECT 1 _FUNC_ 2; 3 """) -case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { +case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -182,8 +182,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit > SELECT 2 _FUNC_ 1; 1 """) -case class Subtract(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -219,8 +218,7 @@ case class Subtract(left: Expression, right: Expression) > SELECT 2 _FUNC_ 3; 6 """) -case class Multiply(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = NumericType @@ -243,8 +241,7 @@ case class Multiply(left: Expression, right: Expression) 1.0 """) // scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) @@ -324,8 +321,7 @@ case class Divide(left: Expression, right: Expression) > SELECT 2 _FUNC_ 1.8; 0.2 """) -case class Remainder(left: 
Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = NumericType @@ -412,7 +408,7 @@ case class Remainder(left: Expression, right: Expression) > SELECT _FUNC_(-10, 3); 2 """) -case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 0c256c3d890f1..de1594d119e17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -104,7 +104,7 @@ trait ExtractValue extends Expression * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) - extends UnaryExpression with ExtractValue { + extends UnaryExpression with ExtractValue with NullIntolerant { lazy val childSchema = child.dataType.asInstanceOf[StructType] @@ -152,7 +152,7 @@ case class GetArrayStructFields( field: StructField, ordinal: Int, numFields: Int, - containsNull: Boolean) extends UnaryExpression with ExtractValue { + containsNull: Boolean) extends UnaryExpression with ExtractValue with NullIntolerant { override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" @@ -213,7 +213,7 @@ case class GetArrayStructFields( * We need to do type checking here as `ordinal` expression maybe unresolved. */ case class GetArrayItem(child: Expression, ordinal: Expression) - extends BinaryExpression with ExpectsInputTypes with ExtractValue { + extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant { // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) * We need to do type checking here as `key` expression maybe unresolved. */ case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue { + extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { private def keyType = child.dataType.asInstanceOf[MapType].keyType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1b00c9e79da22..4c8b177237d23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -138,5 +138,5 @@ package object expressions { * input will result in null output). We will use this information during constructing IsNotNull * constraints. 
*/ - trait NullIntolerant + trait NullIntolerant extends Expression } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 4896a6225aa80..b23da537be721 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ImplicitCastInputTypes { - self: BinaryExpression => +abstract class StringRegexExpression extends BinaryExpression + with ImplicitCastInputTypes with NullIntolerant { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean @@ -69,8 +69,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes { */ @ExpressionDescription( usage = "str _FUNC_ pattern - Returns true if `str` matches `pattern`, or false otherwise.") -case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { +case class Like(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) @@ -122,8 +121,7 @@ case class Like(left: Expression, right: Expression) @ExpressionDescription( usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.") -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { +case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 908aa44f81c97..5598a146997ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -297,8 +297,8 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringPredicate extends Predicate with ImplicitCastInputTypes { - self: BinaryExpression => +abstract class StringPredicate extends BinaryExpression + with Predicate with ImplicitCastInputTypes with NullIntolerant { def compare(l: UTF8String, r: UTF8String): Boolean @@ -313,8 +313,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes { /** * A function that returns true if the string `left` contains the string `right`. */ -case class Contains(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class Contains(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") @@ -324,8 +323,7 @@ case class Contains(left: Expression, right: Expression) /** * A function that returns true if the string `left` starts with the string `right`. 
*/ -case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class StartsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") @@ -335,8 +333,7 @@ case class StartsWith(left: Expression, right: Expression) /** * A function that returns true if the string `left` ends with the string `right`. */ -case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class EndsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") @@ -1122,7 +1119,7 @@ case class StringSpace(child: Expression) """) // scalastyle:on line.size.limit case class Substring(str: Expression, pos: Expression, len: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 07d294b108548..b2a3888ff7b08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -695,7 +695,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { * * This documentation has been based upon similar documentation for the Hive and Presto projects. * - * @param children to base the rank on; a change in the value of one the children will trigger a + * @param children to base the rank on; a change in the value of one of the children will trigger a * change in rank. This is an internal parameter and will be assigned by the * Analyser. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 21d1cd5932620..33039127f16ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -347,35 +347,30 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. 
*/ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { - private def nonNullLiteral(e: Expression): Boolean = e match { - case Literal(null, _) => false - case _ => true + private def isNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => true + case _ => false } def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) => Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) - case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => + case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) => Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) - case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) - case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) - case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType) - case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType) - case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) => - Literal.create(null, e.dataType) - case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) - case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. ae.copy(aggregateFunction = Count(Literal(1))) + case IsNull(c) if !c.nullable => Literal.create(false, BooleanType) + case IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) + + case EqualNullSafe(Literal(null, _), r) => IsNull(r) + case EqualNullSafe(l, Literal(null, _)) => IsNull(l) + // For Coalesce, remove null literals. 
case e @ Coalesce(children) => - val newChildren = children.filter(nonNullLiteral) + val newChildren = children.filterNot(isNullLiteral) if (newChildren.isEmpty) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { @@ -384,33 +379,13 @@ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { Coalesce(newChildren) } - case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) - - // Put exceptional cases above if any - case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType) - - case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType) - - case e: StringRegexExpression => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - - case e: StringPredicate => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - - // If the value expression is NULL then transform the In expression to - // Literal(null) - case In(Literal(null, _), list) => Literal.create(null, BooleanType) + // If the value expression is NULL then transform the In expression to null literal. + case In(Literal(null, _), _) => Literal.create(null, BooleanType) + // Non-leaf NullIntolerant expressions will return null, if at least one of its children is + // a null literal. + case e: NullIntolerant if e.children.exists(isNullLiteral) => + Literal.create(null, e.dataType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 174d546e22809..257dbfac8c3e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -65,7 +65,7 @@ object EliminateSerialization extends Rule[LogicalPlan] { /** * Combines two adjacent [[TypedFilter]]s, which operate on same type object in condition, into one, - * mering the filter functions into one conjunctive function. + * merging the filter functions into one conjunctive function. */ object CombineTypedFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cd238e05d4102..162051a8c0e4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -492,7 +492,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Add an [[Aggregate]] to a logical plan. + * Add an [[Aggregate]] or [[GroupingSets]] to a logical plan. 
*/ private def withAggregation( ctx: AggregationContext, @@ -519,7 +519,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Add a Hint to a logical plan. + * Add a [[Hint]] to a logical plan. */ private def withHints( ctx: HintContext, @@ -545,7 +545,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a single relation referenced in a FROM claused. This method is used when a part of the + * Create a single relation referenced in a FROM clause. This method is used when a part of the * join condition is nested, for example: * {{{ * select * from t1 join (t2 cross join t3) on col1 = col2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 9fd95a4b368ce..2d8ec2053a4cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -230,14 +230,15 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT def producedAttributes: AttributeSet = AttributeSet.empty /** - * Attributes that are referenced by expressions but not provided by this nodes children. + * Attributes that are referenced by expressions but not provided by this node's children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. */ def missingInput: AttributeSet = references -- inputSet -- producedAttributes /** - * Runs [[transform]] with `rule` on all expressions present in this query operator. + * Runs [[transformExpressionsDown]] with `rule` on all expressions present + * in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, * transformExpressionsDown or transformExpressionsUp should be used. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e22b429aec68b..f71a976bd7a24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -32,7 +32,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { private var _analyzed: Boolean = false /** - * Marks this plan as already analyzed. This should only be called by CheckAnalysis. + * Marks this plan as already analyzed. This should only be called by [[CheckAnalysis]]. */ private[catalyst] def setAnalyzed(): Unit = { _analyzed = true } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index c2e62e739776f..d1c6b50536cd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval /** - * Test basic expression parsing. If a type of expression is supported it should be tested here. + * Test basic expression parsing. 
+ * If the type of an expression is supported it should be tested here. * * Please note that some of the expressions test don't have to be sound expressions, only their * structure needs to be valid. Unsound expressions should be caught by the Analyzer or diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ae0703513cf42..43de2de7e7094 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -84,8 +84,8 @@ class TypedColumn[-T, U]( } /** - * Gives the TypedColumn a name (alias). - * If the current TypedColumn has metadata associated with it, this metadata will be propagated + * Gives the [[TypedColumn]] a name (alias). + * If the current `TypedColumn` has metadata associated with it, this metadata will be propagated * to the new column. * * @group expr_ops @@ -99,16 +99,14 @@ class TypedColumn[-T, U]( /** * A column that will be computed based on the data in a `DataFrame`. * - * A new column is constructed based on the input columns present in a dataframe: + * A new column can be constructed based on the input columns present in a DataFrame: * * {{{ - * df("columnName") // On a specific DataFrame. + * df("columnName") // On a specific `df` DataFrame. * col("columnName") // A generic column no yet associated with a DataFrame. * col("columnName.field") // Extracting a struct field * col("`a.column.with.dots`") // Escape `.` in column names. * $"columnName" // Scala short hand for a named column. - * expr("a + 1") // A column that is constructed from a parsed SQL Expression. - * lit("abc") // A column that produces a literal (constant) value. * }}} * * [[Column]] objects can be composed to form complex expressions: @@ -118,7 +116,7 @@ class TypedColumn[-T, U]( * $"a" === $"b" * }}} * - * @note The internal Catalyst expression can be accessed via "expr", but this method is for + * @note The internal Catalyst expression can be accessed via [[expr]], but this method is for * debugging purposes only and can change in any future Spark releases. * * @groupname java_expr_ops Java-specific expression operators @@ -1100,7 +1098,7 @@ class Column(val expr: Expression) extends Logging { def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) } /** - * Prints the expression to the console for debugging purpose. + * Prints the expression to the console for debugging purposes. * * @group df_ops * @since 1.3.0 @@ -1154,8 +1152,8 @@ class Column(val expr: Expression) extends Logging { * {{{ * val w = Window.partitionBy("name").orderBy("id") * df.select( - * sum("price").over(w.rangeBetween(Long.MinValue, 2)), - * avg("price").over(w.rowsBetween(0, 4)) + * sum("price").over(w.rangeBetween(Window.unboundedPreceding, 2)), + * avg("price").over(w.rowsBetween(Window.currentRow, 4)) * ) * }}} * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 18bccee98f610..582d4a3670b8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.InterfaceStability * * To use this, import implicit conversions in SQL: * {{{ - * import sqlContext.implicits._ + * val spark: SparkSession = ... 
+ * import spark.implicits._ * }}} * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index a97297892b5e0..b60499253c42f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -60,7 +60,7 @@ import org.apache.spark.util.Utils * The builder can also be used to create a new session: * * {{{ - * SparkSession.builder() + * SparkSession.builder * .master("local") * .appName("Word Count") * .config("spark.some.config.option", "some-value") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 28156b277f597..239151495f4bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -171,8 +171,20 @@ case class FileSourceScanExec( false } - @transient private lazy val selectedPartitions = - relation.location.listFiles(partitionFilters, dataFilters) + @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { + val startTime = System.nanoTime() + val ret = relation.location.listFiles(partitionFilters, dataFilters) + val timeTaken = (System.nanoTime() - startTime) / 1000 / 1000 + + metrics("numFiles").add(ret.map(_.files.size.toLong).sum) + metrics("metadataTime").add(timeTaken) + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, + metrics("numFiles") :: metrics("metadataTime") :: Nil) + + ret + } override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) { @@ -293,6 +305,8 @@ case class FileSourceScanExec( override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files"), + "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time (ms)"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala index e5a6a5f60b8a6..470c736da98b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types.StringType /** * A command for users to list the databases/schemas. - * If a databasePattern is supplied then the databases that only matches the + * If a databasePattern is supplied then the databases that only match the * pattern would be listed. 
* The syntax of using this command in SQL is: * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 7957224ce48b5..bda64d4b91bbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -80,6 +80,9 @@ object FileFormatWriter extends Logging { """.stripMargin) } + /** The result of a successful write task. */ + private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String]) + /** * Basic work flow of this command is: * 1. Driver side setup, including output committer initialization and data source specific @@ -172,8 +175,9 @@ object FileFormatWriter extends Logging { global = false, child = queryExecution.executedPlan).execute() } - - val ret = sparkSession.sparkContext.runJob(rdd, + val ret = new Array[WriteTaskResult](rdd.partitions.length) + sparkSession.sparkContext.runJob( + rdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -182,10 +186,16 @@ object FileFormatWriter extends Logging { sparkAttemptNumber = taskContext.attemptNumber(), committer, iterator = iter) + }, + 0 until rdd.partitions.length, + (index, res: WriteTaskResult) => { + committer.onTaskCommit(res.commitMsg) + ret(index) = res }) - val commitMsgs = ret.map(_._1) - val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment) + val commitMsgs = ret.map(_.commitMsg) + val updatedPartitions = ret.flatMap(_.updatedPartitions) + .distinct.map(PartitioningUtils.parsePathFragment) committer.commitJob(job, commitMsgs) logInfo(s"Job ${job.getJobID} committed.") @@ -205,7 +215,7 @@ object FileFormatWriter extends Logging { sparkPartitionId: Int, sparkAttemptNumber: Int, committer: FileCommitProtocol, - iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = { + iterator: Iterator[InternalRow]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) @@ -238,7 +248,7 @@ object FileFormatWriter extends Logging { // Execute the task to write rows out and commit the task. val outputPartitions = writeTask.execute(iterator) writeTask.releaseResources() - (committer.commitTask(taskAttemptContext), outputPartitions) + WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions) })(catchBlock = { // If there is an error, release resource and then abort the task try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 75ffe90f2bb70..311942f6dbd84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructType * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark * will regularly query each [[Source]] to see if any more data is available. 
*/ -trait Source { +trait Source { /** Returns the schema of the data from this source */ def schema: StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index f3cf3052ea3ea..00053485e614c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -113,7 +113,7 @@ object Window { * Creates a [[WindowSpec]] with the frame boundaries defined, * from `start` (inclusive) to `end` (inclusive). * - * Both `start` and `end` are relative positions from the current row. For example, "0" means + * Both `start` and `end` are positions relative to the current row. For example, "0" means * "current row", while "-1" means the row before the current row, and "5" means the fifth row * after the current row. * @@ -131,9 +131,9 @@ object Window { * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -150,7 +150,7 @@ object Window { * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rowsBetween. @@ -162,7 +162,7 @@ * Creates a [[WindowSpec]] with the frame boundaries defined, * from `start` (inclusive) to `end` (inclusive). * - * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * Both `start` and `end` are relative to the current row. For example, "0" means "current row", * while "-1" means one off before the current row, and "5" means the five off after the * current row. * @@ -183,9 +183,9 @@ object Window { * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rangeBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -202,7 +202,7 @@ object Window { * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rangeBetween.
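The rewritten examples above replace raw sentinel Longs with the named boundary constants on `Window`. The two spellings are interchangeable because, as the `@param` notes state, the constants are defined as exactly those sentinel values. A short self-contained check of that equivalence (a sketch, assuming a Spark 2.1+ build on the classpath):

    import org.apache.spark.sql.expressions.Window

    // The named constants are aliases for the sentinel values the old examples hard-coded.
    assert(Window.unboundedPreceding == Long.MinValue)
    assert(Window.currentRow == 0L)
    assert(Window.unboundedFollowing == Long.MaxValue)

    // Hence these two frame specifications describe the same window:
    val explicit = Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0)
    val named = Window.partitionBy("country").orderBy("date")
      .rowsBetween(Window.unboundedPreceding, Window.currentRow)

The named form is used in the docs above purely for readability; it compiles to the same frame.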
@@ -221,7 +221,8 @@ object Window { * * {{{ * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * Window.partitionBy("country").orderBy("date") + * .rowsBetween(Window.unboundedPreceding, Window.currentRow) * * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index de7d7a1772753..6279d48c94de5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -86,7 +86,7 @@ class WindowSpec private[sql]( * after the current row. * * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `[Window.currentRow` to specify special boundary values, rather than using integral + * and `Window.currentRow` to specify special boundary values, rather than using integral * values directly. * * A row based boundary is based on the position of the row within the partition. @@ -99,9 +99,9 @@ class WindowSpec private[sql]( * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -118,7 +118,7 @@ class WindowSpec private[sql]( * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rowsBetween. @@ -134,7 +134,7 @@ class WindowSpec private[sql]( * current row. * * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `[Window.currentRow` to specify special boundary values, rather than using integral + * and `Window.currentRow` to specify special boundary values, rather than using integral * values directly. * * A range based boundary is based on the actual value of the ORDER BY @@ -150,9 +150,9 @@ class WindowSpec private[sql]( * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rangeBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -169,7 +169,7 @@ class WindowSpec private[sql]( * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. 
The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rangeBetween. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0f9203065ef05..f07e04368389f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2968,7 +2968,7 @@ object functions { * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string - * @param options options to control how the json is parsed. accepts the same options and the + * @param options options to control how the json is parsed. Accepts the same options as the * json data source. * * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8287776f8f558..7c71e7280c6d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -18,9 +18,12 @@ package org.apache.spark.sql.test import java.io.File +import java.util.concurrent.ConcurrentLinkedQueue import org.scalatest.BeforeAndAfter +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.sources._ @@ -41,7 +44,6 @@ object LastOptions { } } - /** Dummy provider. 
*/ class DefaultSource extends RelationProvider @@ -107,6 +109,20 @@ class DefaultSourceWithoutUserSpecifiedSchema } } +object MessageCapturingCommitProtocol { + val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]() +} + +class MessageCapturingCommitProtocol(jobId: String, path: String) + extends HadoopMapReduceCommitProtocol(jobId, path) { + + // captures commit messages for testing + override def onTaskCommit(msg: TaskCommitMessage): Unit = { + MessageCapturingCommitProtocol.commitMessages.offer(msg) + } +} + + class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ @@ -291,6 +307,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be Option(dir).map(spark.read.format("org.apache.spark.sql.test").load) } + test("write path implements onTaskCommit API correctly") { + withSQLConf( + "spark.sql.sources.commitProtocolClass" -> + classOf[MessageCapturingCommitProtocol].getCanonicalName) { + withTempDir { dir => + val path = dir.getCanonicalPath + MessageCapturingCommitProtocol.commitMessages.clear() + spark.range(10).repartition(10).write.mode("overwrite").parquet(path) + assert(MessageCapturingCommitProtocol.commitMessages.size() == 10) + } + } + } + test("read a data source that does not extend SchemaRelationProvider") { val dfReader = spark.read .option("from", "1") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 8048c2ba2c2e4..2f3dfa05e9ef7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} /** - * Builder that produces a Hive aware [[SessionState]]. + * Builder that produces a Hive-aware `SessionState`. */ @Experimental @InterfaceStability.Unstable diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 8e1a090618433..639ac6de4f5d3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -66,7 +66,7 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { - throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch " + s"$batchTime is already added into InputInfoTracker, this is an illegal state") } inputInfos += ((inputInfo.inputStreamId, inputInfo))
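The FileFormatWriter rework earlier in this change, and the MessageCapturingCommitProtocol test that exercises it, lean on the `SparkContext.runJob` overload that takes a per-task result handler: the handler fires on the driver as each task finishes, which is what lets `committer.onTaskCommit(...)` observe commit messages eagerly instead of only after the whole job has returned an array. A minimal sketch of that overload in isolation; `runWithHandler` and `work` are illustrative names, not Spark API:

    import scala.reflect.ClassTag

    import org.apache.spark.{SparkContext, TaskContext}
    import org.apache.spark.rdd.RDD

    // Runs `work` over every partition of `rdd`; `handle(partitionIndex, result)`
    // is invoked on the driver as soon as the corresponding task completes.
    def runWithHandler[T, U: ClassTag](
        sc: SparkContext,
        rdd: RDD[T])(
        work: Iterator[T] => U)(
        handle: (Int, U) => Unit): Unit = {
      sc.runJob(
        rdd,
        (_: TaskContext, iter: Iterator[T]) => work(iter),
        0 until rdd.partitions.length,
        handle)
    }

FileFormatWriter follows this shape: its handler both forwards each `WriteTaskResult.commitMsg` to `committer.onTaskCommit` and stores the result in a pre-sized array, replacing the old collect-then-iterate pattern.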