From e6f9191a4eb2fc444ea6a908ffc49959583db9e4 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 16 Jul 2015 16:21:05 -0700 Subject: [PATCH 001/340] [WIP] External Group By initial impl --- .../apache/spark/rdd/PairRDDFunctions.scala | 13 +- .../spark/serializer/KryoSerializer.scala | 3 +- .../spark/storage/BlockObjectWriter.scala | 2 + .../spark/util/collection/CompactBuffer.scala | 1 + .../spark/util/collection/ExternalList.scala | 311 ++++++++++++++++++ .../SizeTrackingCompactBuffer.scala | 31 ++ .../spark/util/collection/Spillable.scala | 4 +- .../util/collection/ExternalListSuite.scala | 55 ++++ 8 files changed, 412 insertions(+), 8 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala create mode 100644 core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala create mode 100644 core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 91a6a2d039852..ef74e85203a7a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -45,7 +45,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.{ExternalList, SizeTrackingCompactBuffer, CompactBuffer} import org.apache.spark.util.random.StratifiedSamplingUtils /** @@ -463,10 +463,13 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // groupByKey shouldn't use map side combine because map side combine does not // reduce the amount of data shuffled and requires all map side data be inserted // into a hash table, leading to more objects in the old gen. - val createCombiner = (v: V) => CompactBuffer(v) - val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v - val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 - val bufs = combineByKey[CompactBuffer[V]]( + val createCombiner = (v: V) => ExternalList(v) + val mergeValue = (buf: ExternalList[V], v: V) => buf += v + val mergeCombiners = (c1: ExternalList[V], c2: ExternalList[V]) => { + c2.foreach(c => c1 += c) + c1 + } + val bufs = combineByKey[ExternalList[V]]( createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index ed35cffe968f8..3b878f47ce5e0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -37,7 +37,7 @@ import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.{ExternalList, ExternalListSerializer, CompactBuffer} /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. 
@@ -100,6 +100,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[ExternalList[Any]], new ExternalListSerializer[Any]()) try { // Use the default classloader when calling the user registrator. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 7eeabd1e0489c..6d6199716f406 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -58,6 +58,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends Ou */ def write(key: Any, value: Any) + /** * Notify the writer that a record worth of bytes has been written with OutputStream#write. */ @@ -253,4 +254,5 @@ private[spark] class DiskBlockObjectWriter( objOut.flush() bs.flush() } + } diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala index 4d43d8d5cc8d8..5bed400cf96ef 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala @@ -141,6 +141,7 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable newArrayLen = Int.MaxValue - 2 } } + require(newArrayLen != null) val newArray = new Array[T](newArrayLen) if (otherElements != null) { System.arraycopy(otherElements, 0, newArray, 0, otherElements.length) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala new file mode 100644 index 0000000000000..eceaa604d0cdb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.util.collection + +import java.io._ + +import com.esotericsoftware.kryo.io.{Output, Input} +import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} +import com.google.common.io.ByteStreams +import org.apache.spark.SparkEnv +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.{DeserializationStream, Serializer} +import org.apache.spark.storage.{BlockId, BlockManager} + +import scala.collection.generic.Growable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +/** + * List that can spill some of its contents to disk if its contents cannot be held in memory. + * Implementation is based heavily on `org.apache.spark.util.collection.ExternalAppendOnlyMap}` + */ +@SerialVersionUID(1L) +private[spark] class ExternalList[T: ClassTag] + extends Growable[T] + with Iterable[T] + with Spillable[SizeTrackingCompactBuffer[T]] + with Serializable { + + private val sparkConf = SparkEnv.get.conf + private val blockManager: BlockManager = SparkEnv.get.blockManager + private val diskBlockManager = blockManager.diskBlockManager + private val fileBufferSize = + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 + private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) + private val serializer: Serializer = SparkEnv.get.serializer + private val ser = serializer.newInstance() + private val spilledLists = new ArrayBuffer[Iterable[T]] + + private var curWriteMetrics: ShuffleWriteMetrics = _ + // Number of bytes spilled in total + private var _diskBytesSpilled = 0L + // Write metrics for current spill + private var list = new SizeTrackingCompactBuffer[T]() + private var numItems = 0 + + override def size() = numItems + + override def +=(value: T) = { + list += value + if (maybeSpill(list, list.estimateSize())) { + list = new SizeTrackingCompactBuffer + } + numItems += 1 + this + } + + override def clear(): Unit = { + spilledLists.clear() + list = new SizeTrackingCompactBuffer[T]() + } + + /** + * Spills the current in-memory collection to disk, and releases the memory. + * Logic is very similar to `ExternalAppendOnlyMap` - with the difference that + * we must hold iterables, not iterators, as this lists' iterator may be requested + * multiple times. 
+ * + * @param collection collection to spill to disk + */ + override protected def spill(collection: SizeTrackingCompactBuffer[T]): Unit = { + val (blockId, file) = diskBlockManager.createTempLocalBlock() + curWriteMetrics = new ShuffleWriteMetrics() + var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + var objectsWritten = 0 + + // List of batch sizes (bytes) in the order they are written to disk + val batchSizes = new ArrayBuffer[Long] + + // Flush the disk writer's contents to disk, and update relevant variables + def flush(): Unit = { + val w = writer + writer = null + w.commitAndClose() + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + batchSizes.append(curWriteMetrics.shuffleBytesWritten) + objectsWritten = 0 + } + + var success = false + try { + val it = list.iterator + while (it.hasNext) { + val kv = it.next() + writer.write(0, kv) + objectsWritten += 1 + + if (objectsWritten == serializerBatchSize) { + flush() + curWriteMetrics = new ShuffleWriteMetrics() + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + } + } + if (objectsWritten > 0) { + flush() + } else if (writer != null) { + val w = writer + writer = null + w.revertPartialWritesAndClose() + } + success = true + } finally { + if (!success) { + // This code path only happens if an exception was thrown above before we set success; + // close our stuff and let the exception be thrown further + if (writer != null) { + writer.revertPartialWritesAndClose() + } + if (file.exists()) { + file.delete() + } + } + } + + spilledLists += new DiskListIterable(file, blockId, batchSizes) + } + + override def iterator: Iterator[T] = { + val myIt = list.iterator + val allIts = spilledLists.map(_.iterator) ++ Seq(myIt) + allIts.foldLeft(Iterator[T]())(_ ++ _) + } + + private class DiskListIterable(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) + extends Iterable[T] { + override def iterator: Iterator[T] = { + new DiskListIterator(file, blockId, batchSizes) + } + } + + private class DiskListIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) + extends Iterator[T] { + private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1 + assert(file.length() == batchOffsets.last, + "File length is not equal to the last batch offset:\n" + + s" file length = ${file.length}\n" + + s" last batch offset = ${batchOffsets.last}\n" + + s" all batch offsets = ${batchOffsets.mkString(",")}" + ) + + private var batchIndex = 0 // Which batch we're in + private var fileStream: FileInputStream = null + + // An intermediate stream that reads from exactly one batch + // This guards against pre-fetching and other arbitrary behavior of higher level streams + private var deserializeStream = nextBatchStream() + private var nextItem: Option[T] = None + private var objectsRead = 0 + + /** + * Construct a stream that reads only from the next batch. + */ + private def nextBatchStream(): DeserializationStream = { + // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether + // we're still in a valid batch. 
+ if (batchIndex < batchOffsets.length - 1) { + if (deserializeStream != null) { + deserializeStream.close() + fileStream.close() + deserializeStream = null + fileStream = null + } + + val start = batchOffsets(batchIndex) + fileStream = new FileInputStream(file) + fileStream.getChannel.position(start) + batchIndex += 1 + + val end = batchOffsets(batchIndex) + + assert(end >= start, "start = " + start + ", end = " + end + + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) + + val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream) + ser.deserializeStream(compressedStream) + } else { + // No more batches left + cleanup() + null + } + } + + private def readNextItem(): Option[T] = { + try { + // Ignore the key because we only wrote 0S + deserializeStream.readKey() + val t = deserializeStream.readValue() + objectsRead += 1 + if (objectsRead == serializerBatchSize) { + objectsRead = 0 + deserializeStream = nextBatchStream() + } + Some(t) + } catch { + case e: EOFException => + cleanup() + None + } + } + + override def hasNext: Boolean = { + if (!nextItem.isDefined) { + if (deserializeStream == null) { + return false + } + nextItem = readNextItem() + } + nextItem.isDefined + } + + override def next(): T = { + val item = nextItem match { + case None => readNextItem() + case Some(theItem) => nextItem + } + if (!item.isDefined) { + throw new NoSuchElementException + } + nextItem = None + item match { + case Some(value) => value + case None => null.asInstanceOf[T] + } + } + + private def cleanup() { + batchIndex = batchOffsets.length // Prevent reading any other batch + val ds = deserializeStream + deserializeStream = null + fileStream = null + ds.close() + file.delete() + } + } + + @throws(classOf[IOException]) + private def writeObject(stream: ObjectOutputStream): Unit = { + stream.writeInt(this.size) + val it = this.iterator + while (it.hasNext) { + stream.writeObject(it.next) + } + } + + @throws(classOf[IOException]) + private def readObject(stream: ObjectInputStream): Unit = { + val listSize = stream.readInt() + list = new SizeTrackingCompactBuffer[T] + for(i <- 0L until listSize) { + val newItem = stream.readObject().asInstanceOf[T] + require(newItem != null) + this.+=(newItem) + } + } +} + +private[spark] object ExternalList { + def apply[T: ClassTag](): ExternalList[T] = new ExternalList[T] + + def apply[T: ClassTag](value: T): ExternalList[T] = { + val buf = new ExternalList[T] + buf += value + buf + } +} + +private[spark] class ExternalListSerializer[T: ClassTag] extends KSerializer[ExternalList[T]] { + override def write(kryo: Kryo, output: Output, list: ExternalList[T]): Unit = { + output.writeInt(list.size) + val it = list.iterator + while (it.hasNext) { + kryo.writeClassAndObject(output, it.next()) + } + } + + override def read(kryo: Kryo, input: Input, clazz: Class[ExternalList[T]]): ExternalList[T] = { + val listToRead = new ExternalList[T] + val listSize = input.readInt() + for (i <- 0L until listSize) { + val newItem = kryo.readClassAndObject(input).asInstanceOf[T] + listToRead += newItem + } + listToRead + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala new file mode 100644 index 0000000000000..00de7913f1491 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala @@ 
-0,0 +1,31 @@ +package org.apache.spark.util.collection + +import scala.reflect.ClassTag + +/** + * CompactBuffer that keeps track of its size via SizeTracker. + */ +private[spark] class SizeTrackingCompactBuffer[T: ClassTag] extends CompactBuffer[T] + with SizeTracker { + + override def +=(t: T) = { + super.+=(t) + super.afterUpdate() + this + } + + override def ++=(t: TraversableOnce[T]) = { + super.++=(t) + super.afterUpdate() + this + } +} + +private[spark] object SizeTrackingCompactBuffer { + def apply[T: ClassTag](): SizeTrackingCompactBuffer[T] = new SizeTrackingCompactBuffer[T] + + def apply[T: ClassTag](value: T): SizeTrackingCompactBuffer[T] = { + val buf = new SizeTrackingCompactBuffer[T] + buf += value + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 747ecf075a397..3b044c9b87de0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -40,11 +40,11 @@ private[spark] trait Spillable[C] extends Logging { protected def addElementsRead(): Unit = { _elementsRead += 1 } // Memory manager that can be used to acquire/release memory - private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + private[this] def shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager // Initial threshold for the size of a collection before we start tracking its memory usage // Exposed for testing - private[this] val initialMemoryThreshold: Long = + private[this] def initialMemoryThreshold: Long = SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) // Threshold for this collection's size in bytes before we start tracking its memory usage diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala new file mode 100644 index 0000000000000..c5080fb97c489 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.util.collection + +import org.apache.spark.serializer.{KryoSerializer, JavaSerializer, SerializerInstance} +import org.apache.spark.{SparkContext, SparkConf, SparkFunSuite} + +import scala.reflect.ClassTag + +class ExternalListSuite extends SparkFunSuite { + + val conf = new SparkConf(false) + conf.set("spark.kryoserializer.buffer.max", "2048m") + conf.set("spark.shuffle.spill.initialMemoryThreshold", "1") + conf.set("spark.shuffle.memoryFraction", "0.1") + conf.setMaster("local[8]") + conf.setAppName("test") + val sparkContext = new SparkContext(conf) + test("Serializing and deserializing a spilled list should produce the same values") { + var serializer = new KryoSerializer(new SparkConf()).newInstance() + val list = new ExternalList[Int] + for (i <- 0 to 10000000) { + list += i + } +// testSerialization(serializer, list) + serializer = new JavaSerializer(conf).newInstance() + testSerialization(serializer, list) + } + + private def testSerialization[T: ClassTag](serializer: SerializerInstance, list: ExternalList[T]): Unit = { + val bytes = serializer.serialize(list) + val readList = serializer.deserialize(bytes).asInstanceOf[ExternalList[Int]] + val originalIt = list.iterator + val readIt = readList.iterator + while (originalIt.hasNext) { + assert (originalIt.next == readIt.next) + } + assert (!readIt.hasNext) + } + +} From 3a1624e1eeba45df0c1a807b0dbeea946b021f9c Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 16 Jul 2015 18:14:50 -0700 Subject: [PATCH 002/340] Fix java serialization for ExternalList. --- .../spark/util/collection/ExternalList.scala | 42 ++++++++++++------- .../util/collection/ExternalListSuite.scala | 21 +++++++--- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala index eceaa604d0cdb..be9153cf802aa 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala @@ -21,6 +21,7 @@ import java.io._ import com.esotericsoftware.kryo.io.{Output, Input} import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} import com.google.common.io.ByteStreams +import org.apache.spark.util.collection.ExternalList._ import org.apache.spark.SparkEnv import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.{DeserializationStream, Serializer} @@ -35,27 +36,21 @@ import scala.reflect.ClassTag * Implementation is based heavily on `org.apache.spark.util.collection.ExternalAppendOnlyMap}` */ @SerialVersionUID(1L) -private[spark] class ExternalList[T: ClassTag] +private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) extends Growable[T] with Iterable[T] with Spillable[SizeTrackingCompactBuffer[T]] with Serializable { - private val sparkConf = SparkEnv.get.conf - private val blockManager: BlockManager = SparkEnv.get.blockManager - private val diskBlockManager = blockManager.diskBlockManager - private val fileBufferSize = - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 - private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) - private val serializer: Serializer = SparkEnv.get.serializer - private val ser = serializer.newInstance() - private val spilledLists = new ArrayBuffer[Iterable[T]] + // Lazy 
vals so that this isn't created multiple times but still can be re-instantiated properly + // after serialization + private lazy val ser = serializer.newInstance() + private lazy val spilledLists = new ArrayBuffer[Iterable[T]] + // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ // Number of bytes spilled in total private var _diskBytesSpilled = 0L - // Write metrics for current spill private var list = new SizeTrackingCompactBuffer[T]() private var numItems = 0 @@ -86,7 +81,8 @@ private[spark] class ExternalList[T: ClassTag] override protected def spill(collection: SizeTrackingCompactBuffer[T]): Unit = { val (blockId, file) = diskBlockManager.createTempLocalBlock() curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + var writer = blockManager.getDiskWriter(blockId, file, ser, + fileBufferSize, curWriteMetrics) var objectsWritten = 0 // List of batch sizes (bytes) in the order they are written to disk @@ -113,7 +109,8 @@ private[spark] class ExternalList[T: ClassTag] if (objectsWritten == serializerBatchSize) { flush() curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + writer = blockManager.getDiskWriter(blockId, file, ser, + fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { @@ -261,6 +258,7 @@ private[spark] class ExternalList[T: ClassTag] @throws(classOf[IOException]) private def writeObject(stream: ObjectOutputStream): Unit = { + stream.writeObject(tag) stream.writeInt(this.size) val it = this.iterator while (it.hasNext) { @@ -270,17 +268,31 @@ private[spark] class ExternalList[T: ClassTag] @throws(classOf[IOException]) private def readObject(stream: ObjectInputStream): Unit = { + tag = stream.readObject().asInstanceOf[ClassTag[T]] val listSize = stream.readInt() list = new SizeTrackingCompactBuffer[T] for(i <- 0L until listSize) { val newItem = stream.readObject().asInstanceOf[T] - require(newItem != null) this.+=(newItem) } } } +/** + * Companion object for constants and singleton-references that we don't want to lose when + * Java-serializing + */ private[spark] object ExternalList { + // Defs so that they're not simply erased upon Java serialization + private val sparkConf = SparkEnv.get.conf + private val blockManager: BlockManager = SparkEnv.get.blockManager + private val diskBlockManager = blockManager.diskBlockManager + private val fileBufferSize = + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 + private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) + private val serializer: Serializer = SparkEnv.get.serializer + def apply[T: ClassTag](): ExternalList[T] = new ExternalList[T] def apply[T: ClassTag](value: T): ExternalList[T] = { diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala index c5080fb97c489..e04a2368fe3c9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala @@ -26,22 +26,33 @@ class ExternalListSuite extends SparkFunSuite { val conf = new SparkConf(false) conf.set("spark.kryoserializer.buffer.max", "2048m") conf.set("spark.shuffle.spill.initialMemoryThreshold", 
"1") - conf.set("spark.shuffle.memoryFraction", "0.1") + conf.set("spark.shuffle.memoryFraction", "0.035") conf.setMaster("local[8]") conf.setAppName("test") val sparkContext = new SparkContext(conf) + test("Serializing and deserializing a spilled list should produce the same values") { var serializer = new KryoSerializer(new SparkConf()).newInstance() - val list = new ExternalList[Int] - for (i <- 0 to 10000000) { + var list = new ExternalList[Int] + // Test big list for Kryo because it's fast enough to handle it + // and we want to test the case where the list would spill to disk + for (i <- 0 to 8000000) { list += i } -// testSerialization(serializer, list) + testSerialization(serializer, list) serializer = new JavaSerializer(conf).newInstance() + list = new ExternalList[Int] + // Test smaller list for Java serialization since serializing with Java is + // really slow, and we already test serialization causing spilling in the Kryo case + for (i <- 0 to 100) { + list += i + } testSerialization(serializer, list) } - private def testSerialization[T: ClassTag](serializer: SerializerInstance, list: ExternalList[T]): Unit = { + private def testSerialization[T: ClassTag]( + serializer: SerializerInstance, + list: ExternalList[T]): Unit = { val bytes = serializer.serialize(list) val readList = serializer.deserialize(bytes).asInstanceOf[ExternalList[Int]] val originalIt = list.iterator From 8b3a51d432c70eafb290a67cb5c52e1126322229 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 16 Jul 2015 18:17:45 -0700 Subject: [PATCH 003/340] Remove confusing comment --- .../scala/org/apache/spark/util/collection/ExternalList.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala index be9153cf802aa..dad0044bef93f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala @@ -283,7 +283,6 @@ private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) * Java-serializing */ private[spark] object ExternalList { - // Defs so that they're not simply erased upon Java serialization private val sparkConf = SparkEnv.get.conf private val blockManager: BlockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager From a93445e5f9bd81ba1c04590e5ecf5307a4b22dfc Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 30 Jul 2015 09:39:13 -0700 Subject: [PATCH 004/340] Refactor logic common to both ExternalAppendOnlyMap and ExternalList --- .../collection/ExternalAppendOnlyMap.scala | 205 ++-------------- .../spark/util/collection/ExternalList.scala | 205 ++-------------- .../spark/util/collection/Spillable.scala | 20 +- .../util/collection/SpillableCollection.scala | 230 ++++++++++++++++++ .../util/collection/ExternalListSuite.scala | 6 +- 5 files changed, 276 insertions(+), 390 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 1e4531ef395ae..f96d431687eb5 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -29,7 +29,7 @@ import com.google.common.io.ByteStreams 
import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} -import org.apache.spark.storage.{BlockId, BlockManager} +import org.apache.spark.storage.{BlockObjectWriter, BlockId, BlockManager} import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator import org.apache.spark.executor.ShuffleWriteMetrics @@ -69,36 +69,11 @@ class ExternalAppendOnlyMap[K, V, C]( extends Iterable[(K, C)] with Serializable with Logging - with Spillable[SizeTracker] { + with SpillableCollection[(K, C), SizeTrackingAppendOnlyMap[K, C]] { private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] - private val sparkConf = SparkEnv.get.conf - private val diskBlockManager = blockManager.diskBlockManager - - /** - * Size of object batches when reading/writing from serializers. - * - * Objects are written in batches, with each batch using its own serialization stream. This - * cuts down on the size of reference-tracking maps constructed when deserializing a stream. - * - * NOTE: Setting this too low can cause excessive copying when serializing, since some serializers - * grow internal data structures by growing + copying every time the number of objects doubles. - */ - private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) - - // Number of bytes spilled in total - private var _diskBytesSpilled = 0L - - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - private val fileBufferSize = - sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 - - // Write metrics for current spill - private var curWriteMetrics: ShuffleWriteMetrics = _ - private val keyComparator = new HashComparator[K] - private val ser = serializer.newInstance() /** * Insert the given key and value into the map. @@ -147,66 +122,6 @@ class ExternalAppendOnlyMap[K, V, C]( insertAll(entries.iterator) } - /** - * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. 
- */ - override protected[this] def spill(collection: SizeTracker): Unit = { - val (blockId, file) = diskBlockManager.createTempLocalBlock() - curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) - var objectsWritten = 0 - - // List of batch sizes (bytes) in the order they are written to disk - val batchSizes = new ArrayBuffer[Long] - - // Flush the disk writer's contents to disk, and update relevant variables - def flush(): Unit = { - val w = writer - writer = null - w.commitAndClose() - _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten - batchSizes.append(curWriteMetrics.shuffleBytesWritten) - objectsWritten = 0 - } - - var success = false - try { - val it = currentMap.destructiveSortedIterator(keyComparator) - while (it.hasNext) { - val kv = it.next() - writer.write(kv._1, kv._2) - objectsWritten += 1 - - if (objectsWritten == serializerBatchSize) { - flush() - curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) - } - } - if (objectsWritten > 0) { - flush() - } else if (writer != null) { - val w = writer - writer = null - w.revertPartialWritesAndClose() - } - success = true - } finally { - if (!success) { - // This code path only happens if an exception was thrown above before we set success; - // close our stuff and let the exception be thrown further - if (writer != null) { - writer.revertPartialWritesAndClose() - } - if (file.exists()) { - file.delete() - } - } - } - - spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - } - def diskBytesSpilled: Long = _diskBytesSpilled /** @@ -374,115 +289,27 @@ class ExternalAppendOnlyMap[K, V, C]( * An iterator that returns (K, C) pairs in sorted order from an on-disk map */ private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) - extends Iterator[(K, C)] + extends DiskIterator(file, blockId, batchSizes) { - private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1 - assert(file.length() == batchOffsets.last, - "File length is not equal to the last batch offset:\n" + - s" file length = ${file.length}\n" + - s" last batch offset = ${batchOffsets.last}\n" + - s" all batch offsets = ${batchOffsets.mkString(",")}" - ) - - private var batchIndex = 0 // Which batch we're in - private var fileStream: FileInputStream = null - - // An intermediate stream that reads from exactly one batch - // This guards against pre-fetching and other arbitrary behavior of higher level streams - private var deserializeStream = nextBatchStream() - private var nextItem: (K, C) = null - private var objectsRead = 0 - - /** - * Construct a stream that reads only from the next batch. - */ - private def nextBatchStream(): DeserializationStream = { - // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether - // we're still in a valid batch. 
- if (batchIndex < batchOffsets.length - 1) { - if (deserializeStream != null) { - deserializeStream.close() - fileStream.close() - deserializeStream = null - fileStream = null - } - - val start = batchOffsets(batchIndex) - fileStream = new FileInputStream(file) - fileStream.getChannel.position(start) - batchIndex += 1 - - val end = batchOffsets(batchIndex) - - assert(end >= start, "start = " + start + ", end = " + end + - ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) - - val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream) - ser.deserializeStream(compressedStream) - } else { - // No more batches left - cleanup() - null - } - } - - /** - * Return the next (K, C) pair from the deserialization stream. - * - * If the current batch is drained, construct a stream for the next batch and read from it. - * If no more pairs are left, return null. - */ - private def readNextItem(): (K, C) = { - try { - val k = deserializeStream.readKey().asInstanceOf[K] - val c = deserializeStream.readValue().asInstanceOf[C] - val item = (k, c) - objectsRead += 1 - if (objectsRead == serializerBatchSize) { - objectsRead = 0 - deserializeStream = nextBatchStream() - } - item - } catch { - case e: EOFException => - cleanup() - null - } + override protected def readNextItemFromStream(deserializeStream: DeserializationStream): (K, C) = { + val k = deserializeStream.readKey().asInstanceOf[K] + val v = deserializeStream.readValue().asInstanceOf[C] + (k, v) } + } - override def hasNext: Boolean = { - if (nextItem == null) { - if (deserializeStream == null) { - return false - } - nextItem = readNextItem() - } - nextItem != null - } + /** Convenience function to hash the given (K, C) pair by the key. */ + private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) - override def next(): (K, C) = { - val item = if (nextItem == null) readNextItem() else nextItem - if (item == null) { - throw new NoSuchElementException - } - nextItem = null - item - } + override protected def getIteratorForCurrentSpillable(): Iterator[(K, C)] = currentMap.destructiveSortedIterator(keyComparator) - // TODO: Ensure this gets called even if the iterator isn't drained. - private def cleanup() { - batchIndex = batchOffsets.length // Prevent reading any other batch - val ds = deserializeStream - deserializeStream = null - fileStream = null - ds.close() - file.delete() - } + override protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]): Unit = { + spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) } - /** Convenience function to hash the given (K, C) pair by the key. 
*/ - private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) + override protected def writeNextObject(c: (K, C), writer: BlockObjectWriter): Unit = { + writer.write(c._1, c._2) + } } private[spark] object ExternalAppendOnlyMap { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala index dad0044bef93f..1846e8ea93e29 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala @@ -20,12 +20,8 @@ import java.io._ import com.esotericsoftware.kryo.io.{Output, Input} import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} -import com.google.common.io.ByteStreams -import org.apache.spark.util.collection.ExternalList._ -import org.apache.spark.SparkEnv -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.serializer.{DeserializationStream, Serializer} -import org.apache.spark.storage.{BlockId, BlockManager} +import org.apache.spark.serializer.DeserializationStream +import org.apache.spark.storage.{BlockObjectWriter, BlockId} import scala.collection.generic.Growable import scala.collection.mutable.ArrayBuffer @@ -39,18 +35,13 @@ import scala.reflect.ClassTag private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) extends Growable[T] with Iterable[T] - with Spillable[SizeTrackingCompactBuffer[T]] + with SpillableCollection[T, SizeTrackingCompactBuffer[T]] with Serializable { // Lazy vals so that this isn't created multiple times but still can be re-instantiated properly // after serialization - private lazy val ser = serializer.newInstance() private lazy val spilledLists = new ArrayBuffer[Iterable[T]] - // Write metrics for current spill - private var curWriteMetrics: ShuffleWriteMetrics = _ - // Number of bytes spilled in total - private var _diskBytesSpilled = 0L private var list = new SizeTrackingCompactBuffer[T]() private var numItems = 0 @@ -70,73 +61,6 @@ private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) list = new SizeTrackingCompactBuffer[T]() } - /** - * Spills the current in-memory collection to disk, and releases the memory. - * Logic is very similar to `ExternalAppendOnlyMap` - with the difference that - * we must hold iterables, not iterators, as this lists' iterator may be requested - * multiple times. 
- * - * @param collection collection to spill to disk - */ - override protected def spill(collection: SizeTrackingCompactBuffer[T]): Unit = { - val (blockId, file) = diskBlockManager.createTempLocalBlock() - curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, ser, - fileBufferSize, curWriteMetrics) - var objectsWritten = 0 - - // List of batch sizes (bytes) in the order they are written to disk - val batchSizes = new ArrayBuffer[Long] - - // Flush the disk writer's contents to disk, and update relevant variables - def flush(): Unit = { - val w = writer - writer = null - w.commitAndClose() - _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten - batchSizes.append(curWriteMetrics.shuffleBytesWritten) - objectsWritten = 0 - } - - var success = false - try { - val it = list.iterator - while (it.hasNext) { - val kv = it.next() - writer.write(0, kv) - objectsWritten += 1 - - if (objectsWritten == serializerBatchSize) { - flush() - curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, ser, - fileBufferSize, curWriteMetrics) - } - } - if (objectsWritten > 0) { - flush() - } else if (writer != null) { - val w = writer - writer = null - w.revertPartialWritesAndClose() - } - success = true - } finally { - if (!success) { - // This code path only happens if an exception was thrown above before we set success; - // close our stuff and let the exception be thrown further - if (writer != null) { - writer.revertPartialWritesAndClose() - } - if (file.exists()) { - file.delete() - } - } - } - - spilledLists += new DiskListIterable(file, blockId, batchSizes) - } - override def iterator: Iterator[T] = { val myIt = list.iterator val allIts = spilledLists.map(_.iterator) ++ Seq(myIt) @@ -151,108 +75,10 @@ private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) } private class DiskListIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) - extends Iterator[T] { - private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1 - assert(file.length() == batchOffsets.last, - "File length is not equal to the last batch offset:\n" + - s" file length = ${file.length}\n" + - s" last batch offset = ${batchOffsets.last}\n" + - s" all batch offsets = ${batchOffsets.mkString(",")}" - ) - - private var batchIndex = 0 // Which batch we're in - private var fileStream: FileInputStream = null - - // An intermediate stream that reads from exactly one batch - // This guards against pre-fetching and other arbitrary behavior of higher level streams - private var deserializeStream = nextBatchStream() - private var nextItem: Option[T] = None - private var objectsRead = 0 - - /** - * Construct a stream that reads only from the next batch. - */ - private def nextBatchStream(): DeserializationStream = { - // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether - // we're still in a valid batch. 
- if (batchIndex < batchOffsets.length - 1) { - if (deserializeStream != null) { - deserializeStream.close() - fileStream.close() - deserializeStream = null - fileStream = null - } - - val start = batchOffsets(batchIndex) - fileStream = new FileInputStream(file) - fileStream.getChannel.position(start) - batchIndex += 1 - - val end = batchOffsets(batchIndex) - - assert(end >= start, "start = " + start + ", end = " + end + - ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) - - val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream) - ser.deserializeStream(compressedStream) - } else { - // No more batches left - cleanup() - null - } - } - - private def readNextItem(): Option[T] = { - try { - // Ignore the key because we only wrote 0S - deserializeStream.readKey() - val t = deserializeStream.readValue() - objectsRead += 1 - if (objectsRead == serializerBatchSize) { - objectsRead = 0 - deserializeStream = nextBatchStream() - } - Some(t) - } catch { - case e: EOFException => - cleanup() - None - } - } - - override def hasNext: Boolean = { - if (!nextItem.isDefined) { - if (deserializeStream == null) { - return false - } - nextItem = readNextItem() - } - nextItem.isDefined - } - - override def next(): T = { - val item = nextItem match { - case None => readNextItem() - case Some(theItem) => nextItem - } - if (!item.isDefined) { - throw new NoSuchElementException - } - nextItem = None - item match { - case Some(value) => value - case None => null.asInstanceOf[T] - } - } - - private def cleanup() { - batchIndex = batchOffsets.length // Prevent reading any other batch - val ds = deserializeStream - deserializeStream = null - fileStream = null - ds.close() - file.delete() + extends DiskIterator(file, blockId, batchSizes) { + override protected def readNextItemFromStream(deserializeStream: DeserializationStream): T = { + deserializeStream.readKey[Int]() + deserializeStream.readValue[T]() } } @@ -276,6 +102,14 @@ private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) this.+=(newItem) } } + + override protected def getIteratorForCurrentSpillable(): Iterator[T] = list.iterator + override protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]): Unit = { + spilledLists += new DiskListIterable(file, blockId, batchSizes) + } + override protected def writeNextObject(c: T, writer: BlockObjectWriter): Unit = { + writer.write(0, c) + } } /** @@ -283,15 +117,6 @@ private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) * Java-serializing */ private[spark] object ExternalList { - private val sparkConf = SparkEnv.get.conf - private val blockManager: BlockManager = SparkEnv.get.blockManager - private val diskBlockManager = blockManager.diskBlockManager - private val fileBufferSize = - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 - private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) - private val serializer: Serializer = SparkEnv.get.serializer - def apply[T: ClassTag](): ExternalList[T] = new ExternalList[T] def apply[T: ClassTag](value: T): ExternalList[T] = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 3b044c9b87de0..a710d618f3d23 100644 --- 
a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection import org.apache.spark.Logging import org.apache.spark.SparkEnv +import org.apache.spark.util.collection.Spillable._ /** * Spills contents of an in-memory collection to disk when the memory threshold @@ -39,14 +40,6 @@ private[spark] trait Spillable[C] extends Logging { // It's used for checking spilling frequency protected def addElementsRead(): Unit = { _elementsRead += 1 } - // Memory manager that can be used to acquire/release memory - private[this] def shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager - - // Initial threshold for the size of a collection before we start tracking its memory usage - // Exposed for testing - private[this] def initialMemoryThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) - // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 private[this] var myMemoryThreshold = initialMemoryThreshold @@ -117,4 +110,15 @@ private[spark] trait Spillable[C] extends Logging { .format(threadId, org.apache.spark.util.Utils.bytesToString(size), _spillCount, if (_spillCount > 1) "s" else "")) } + +} + +private object Spillable { + // Memory manager that can be used to acquire/release memory + protected val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + + // Initial threshold for the size of a collection before we start tracking its memory usage + // Exposed for testing + protected val initialMemoryThreshold: Long = + SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala new file mode 100644 index 0000000000000..fd6d422836053 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection + +import com.google.common.io.ByteStreams +import org.apache.spark.util.collection.SpillableCollection._ +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.{DeserializationStream, Serializer} +import org.apache.spark.storage.{DiskBlockManager, BlockId, BlockObjectWriter, BlockManager} + +import java.io.{EOFException, BufferedInputStream, FileInputStream, File} +import scala.collection.mutable.ArrayBuffer + +/** + * Collection that can spill to disk. Takes type parameters T, the iterable type, and + * C, the type of the elements returned by T's iterator. + */ +private[spark] trait SpillableCollection[C, T <: Iterable[C]] extends Spillable[T] { + // Write metrics for current spill + private var curWriteMetrics: ShuffleWriteMetrics = _ + // Number of bytes spilled in total + protected var _diskBytesSpilled = 0L + private lazy val ser = serializer.newInstance() + + override protected def spill(collection: T): Unit = { + val (blockId, file) = diskBlockManager.createTempLocalBlock() + curWriteMetrics = new ShuffleWriteMetrics() + var writer = blockManager.getDiskWriter(blockId, file, ser, + fileBufferSize, curWriteMetrics) + var objectsWritten = 0 + + // List of batch sizes (bytes) in the order they are written to disk + val batchSizes = new ArrayBuffer[Long] + + // Flush the disk writer's contents to disk, and update relevant variables + def flush(): Unit = { + val w = writer + writer = null + w.commitAndClose() + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + batchSizes.append(curWriteMetrics.shuffleBytesWritten) + objectsWritten = 0 + } + + var success = false + try { + val it = getIteratorForCurrentSpillable() + while (it.hasNext) { + val kv = it.next() + writeNextObject(kv, writer) + objectsWritten += 1 + + if (objectsWritten == serializerBatchSize) { + flush() + curWriteMetrics = new ShuffleWriteMetrics() + writer = blockManager.getDiskWriter(blockId, file, ser, + fileBufferSize, curWriteMetrics) + } + } + if (objectsWritten > 0) { + flush() + } else if (writer != null) { + val w = writer + writer = null + w.revertPartialWritesAndClose() + } + success = true + } finally { + if (!success) { + // This code path only happens if an exception was thrown above before we set success; + // close our stuff and let the exception be thrown further + if (writer != null) { + writer.revertPartialWritesAndClose() + } + if (file.exists()) { + file.delete() + } + } + } + + recordNextSpilledPart(file, blockId, batchSizes) + } + + protected def getIteratorForCurrentSpillable(): Iterator[C] + protected def writeNextObject(c: C, writer: BlockObjectWriter): Unit + protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) + + /** + * Iterator backed by elements from batches on disk. 
+ */ + protected abstract class DiskIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) + extends Iterator[C] { + private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1 + assert(file.length() == batchOffsets.last, + "File length is not equal to the last batch offset:\n" + + s" file length = ${file.length}\n" + + s" last batch offset = ${batchOffsets.last}\n" + + s" all batch offsets = ${batchOffsets.mkString(",")}" + ) + + private var batchIndex = 0 // Which batch we're in + private var fileStream: FileInputStream = null + + // An intermediate stream that reads from exactly one batch + // This guards against pre-fetching and other arbitrary behavior of higher level streams + private var deserializeStream = nextBatchStream() + private var nextItem: Option[C] = None + private var objectsRead = 0 + + /** + * Construct a stream that reads only from the next batch. + */ + protected def nextBatchStream(): DeserializationStream = { + // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether + // we're still in a valid batch. + if (batchIndex < batchOffsets.length - 1) { + if (deserializeStream != null) { + deserializeStream.close() + fileStream.close() + deserializeStream = null + fileStream = null + } + + val start = batchOffsets(batchIndex) + fileStream = new FileInputStream(file) + fileStream.getChannel.position(start) + batchIndex += 1 + + val end = batchOffsets(batchIndex) + + assert(end >= start, "start = " + start + ", end = " + end + + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) + + val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream) + ser.deserializeStream(compressedStream) + } else { + // No more batches left + cleanup() + null + } + } + + /** + * Return the next item from the deserialization stream. + * + * If the current batch is drained, construct a stream for the next batch and read from it. + * If no more items are left, return null. 
+ */ + protected def readNextItem(): Option[C] = { + try { + val item = readNextItemFromStream(deserializeStream) + objectsRead += 1 + if (objectsRead == serializerBatchSize) { + objectsRead = 0 + deserializeStream = nextBatchStream() + } + Some(item) + } catch { + case e: EOFException => + cleanup() + None + } + } + + private def cleanup() { + batchIndex = batchOffsets.length // Prevent reading any other batch + val ds = deserializeStream + deserializeStream = null + fileStream = null + ds.close() + file.delete() + } + + override def hasNext: Boolean = { + if (!nextItem.isDefined) { + if (deserializeStream == null) { + return false + } + nextItem = readNextItem() + } + nextItem.isDefined + } + + override def next(): C = { + val item = nextItem match { + case None => readNextItem() + case Some(theItem) => nextItem + } + if (!item.isDefined) { + throw new NoSuchElementException + } + nextItem = None + item match { + case Some(value) => value + case None => null.asInstanceOf[C] + } + } + + protected def readNextItemFromStream(deserializeStream: DeserializationStream): C + } +} + +// Visible and modifiable only for testing +private[collection] object SpillableCollection { + private def sparkConf(): SparkConf = SparkEnv.get.conf + private def blockManager(): BlockManager = SparkEnv.get.blockManager + private def diskBlockManager(): DiskBlockManager = blockManager.diskBlockManager + private def fileBufferSize(): Int = + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 + private def serializerBatchSize(): Long = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) + private def serializer(): Serializer = SparkEnv.get.serializer +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala index e04a2368fe3c9..d53164c550968 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag class ExternalListSuite extends SparkFunSuite { val conf = new SparkConf(false) - conf.set("spark.kryoserializer.buffer.max", "2048m") + conf.set("spark.kryoserializer.buffer.max", "2046m") conf.set("spark.shuffle.spill.initialMemoryThreshold", "1") conf.set("spark.shuffle.memoryFraction", "0.035") conf.setMaster("local[8]") @@ -32,7 +32,7 @@ class ExternalListSuite extends SparkFunSuite { val sparkContext = new SparkContext(conf) test("Serializing and deserializing a spilled list should produce the same values") { - var serializer = new KryoSerializer(new SparkConf()).newInstance() + var serializer = new KryoSerializer(conf).newInstance() var list = new ExternalList[Int] // Test big list for Kryo because it's fast enough to handle it // and we want to test the case where the list would spill to disk @@ -44,7 +44,7 @@ class ExternalListSuite extends SparkFunSuite { list = new ExternalList[Int] // Test smaller list for Java serialization since serializing with Java is // really slow, and we already test serialization causing spilling in the Kryo case - for (i <- 0 to 100) { + for (i <- 0 to 1000) { list += i } testSerialization(serializer, list) From 86d42c4cdff36396131e09905a6d784385ea064c Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 30 Jul 2015 13:52:59 -0700 Subject: [PATCH 005/340] Added a formal group by unit test for 
spilling. Also organized imports, and fixed a bug where the external list cannot be iterated through twice because the file is cleaned up. --- .../apache/spark/rdd/PairRDDFunctions.scala | 21 +++++++++--- .../collection/ExternalAppendOnlyMap.scala | 5 ++- .../spark/util/collection/ExternalList.scala | 11 +++++-- .../util/collection/SpillableCollection.scala | 26 +++++++++++---- .../util/collection/ExternalListSuite.scala | 32 ++++++++++++++++--- 5 files changed, 73 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index ef74e85203a7a..3533a38fcd278 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -45,7 +45,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.{ExternalList, SizeTrackingCompactBuffer, CompactBuffer} +import org.apache.spark.util.collection.{ExternalSorter, ExternalList, SizeTrackingCompactBuffer, CompactBuffer} import org.apache.spark.util.random.StratifiedSamplingUtils /** @@ -469,9 +469,22 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) c2.foreach(c => c1 += c) c1 } - val bufs = combineByKey[ExternalList[V]]( - createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) - bufs.asInstanceOf[RDD[(K, Iterable[V])]] + val aggregator = new Aggregator[K, V, ExternalList[V]](createCombiner, mergeValue, mergeCombiners) + val shuffledRdd = if (self.partitioner != partitioner) { + self.partitionBy(partitioner) + } else { + self + } + def groupOnPartition(iterator: Iterator[(K, V)]): Iterator[(K, Iterable[V])] = { + val sorter = new ExternalSorter[K, V, ExternalList[V]](aggregator = Some(aggregator)) + sorter.insertAll(iterator) + sorter.iterator.map(keyAndGroup => (keyAndGroup._1, keyAndGroup._2.asInstanceOf[Iterable[V]])) + } + + shuffledRdd.mapPartitions(groupOnPartition(_), preservesPartitioning = true) + //val bufs = combineByKey[ExternalList[V]]( + // createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) + //bufs.asInstanceOf[RDD[(K, Iterable[V])]] } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index f96d431687eb5..4b919b1871fcd 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -24,14 +24,11 @@ import scala.collection.BufferedIterator import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.io.ByteStreams - import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockObjectWriter, BlockId, BlockManager} import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator -import org.apache.spark.executor.ShuffleWriteMetrics /** * :: DeveloperApi :: @@ -296,6 +293,8 @@ class ExternalAppendOnlyMap[K, V, C]( val v = deserializeStream.readValue().asInstanceOf[C] (k, v) } + + override protected def shouldCleanupFileAfterOneIteration(): 
Boolean = true } /** Convenience function to hash the given (K, C) pair by the key. */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala index 1846e8ea93e29..cf4d737b1d400 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala @@ -18,14 +18,16 @@ package org.apache.spark.util.collection import java.io._ +import scala.reflect.ClassTag +import scala.collection.generic.Growable +import scala.collection.mutable.ArrayBuffer + import com.esotericsoftware.kryo.io.{Output, Input} import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} + import org.apache.spark.serializer.DeserializationStream import org.apache.spark.storage.{BlockObjectWriter, BlockId} -import scala.collection.generic.Growable -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag /** * List that can spill some of its contents to disk if its contents cannot be held in memory. @@ -80,6 +82,9 @@ private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) deserializeStream.readKey[Int]() deserializeStream.readValue[T]() } + + // Need to be able to iterate multiple times, so don't clean up the file every time + override protected def shouldCleanupFileAfterOneIteration(): Boolean = false } @throws(classOf[IOException]) diff --git a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala index fd6d422836053..2aa1e006e24c6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala @@ -17,17 +17,20 @@ package org.apache.spark.util.collection +import java.io.{EOFException, BufferedInputStream, FileInputStream, File} + +import scala.collection.mutable.ArrayBuffer + import com.google.common.io.ByteStreams -import org.apache.spark.util.collection.SpillableCollection._ + import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{DiskBlockManager, BlockId, BlockObjectWriter, BlockManager} - -import java.io.{EOFException, BufferedInputStream, FileInputStream, File} -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.util.collection.SpillableCollection._ /** + * * Collection that can spill to disk. Takes type parameters T, the iterable type, and * C, the type of the elements returned by T's iterator. 
*/ @@ -180,12 +183,20 @@ private[spark] trait SpillableCollection[C, T <: Iterable[C]] extends Spillable[ } private def cleanup() { - batchIndex = batchOffsets.length // Prevent reading any other batch + batchIndex = batchOffsets.length // Prevent reading any other batch val ds = deserializeStream deserializeStream = null + if (ds != null) { + ds.close() + } + val fs = fileStream fileStream = null - ds.close() - file.delete() + if (fs != null) { + fs.close() + } + if (shouldCleanupFileAfterOneIteration()) { + file.delete() + } } override def hasNext: Boolean = { @@ -214,6 +225,7 @@ private[spark] trait SpillableCollection[C, T <: Iterable[C]] extends Spillable[ } protected def readNextItemFromStream(deserializeStream: DeserializationStream): C + protected def shouldCleanupFileAfterOneIteration(): Boolean } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala index d53164c550968..025f1bbd9c45a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala @@ -16,19 +16,24 @@ */ package org.apache.spark.util.collection -import org.apache.spark.serializer.{KryoSerializer, JavaSerializer, SerializerInstance} -import org.apache.spark.{SparkContext, SparkConf, SparkFunSuite} - import scala.reflect.ClassTag -class ExternalListSuite extends SparkFunSuite { +import org.apache.spark.{SparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{KryoSerializer, JavaSerializer, SerializerInstance} +import org.junit.Assert.assertEquals + +class ExternalListSuite extends SparkFunSuite with Serializable { val conf = new SparkConf(false) conf.set("spark.kryoserializer.buffer.max", "2046m") conf.set("spark.shuffle.spill.initialMemoryThreshold", "1") + conf.set("spark.shuffle.spill.batchSize", "10") conf.set("spark.shuffle.memoryFraction", "0.035") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + conf.set("spark.task.maxFailures", "1") conf.setMaster("local[8]") conf.setAppName("test") + val sparkContext = new SparkContext(conf) test("Serializing and deserializing a spilled list should produce the same values") { @@ -36,7 +41,7 @@ class ExternalListSuite extends SparkFunSuite { var list = new ExternalList[Int] // Test big list for Kryo because it's fast enough to handle it // and we want to test the case where the list would spill to disk - for (i <- 0 to 8000000) { + for (i <- 0 to 5000000) { list += i } testSerialization(serializer, list) @@ -50,6 +55,23 @@ class ExternalListSuite extends SparkFunSuite { testSerialization(serializer, list) } + test("Group by key with spilling list") { + val totalRddSize = 7200000 + val numBuckets = 5 + val rawLargeRdd = sparkContext.parallelize(1 to totalRddSize) + val groupedRdd = rawLargeRdd.map(x => (x % numBuckets, x)).groupByKey + def validateList(kv: (Int, Iterable[Int])): Unit = { + var numItems = 0 + for (valsInBucket <- kv._2) { + numItems += 1 + // Can't use scala assertions because including assert statements makes closures not serializable. 
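        // Illustrative aside (assumed mechanism, inferred rather than taken from the patch): in a
        // ScalaTest suite, assert() resolves to a member of the suite (via org.scalatest.Assertions),
        // so calling it inside this closure would capture the whole suite instance and break closure
        // serialization, whereas the statically imported org.junit.Assert.assertEquals used on the
        // next line captures nothing extra.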
+ assertEquals(s"Value $valsInBucket should not be in bucket ${kv._1}", kv._1, valsInBucket % numBuckets) + } + assertEquals(s"Number of items in bucket ${kv._1} is incorrect.", totalRddSize / numBuckets, numItems) + } + groupedRdd.foreach(validateList(_)) + } + private def testSerialization[T: ClassTag]( serializer: SerializerInstance, list: ExternalList[T]): Unit = { From d891b75d25e152b2304548944c5e642329a148cb Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 30 Jul 2015 14:57:16 -0700 Subject: [PATCH 006/340] Fix merge conflict compiler errors --- .../collection/ExternalAppendOnlyMap.scala | 31 ++++--------------- .../spark/util/collection/ExternalList.scala | 4 +-- .../util/collection/SpillableCollection.scala | 4 +-- 3 files changed, 10 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 85c5fe693dc38..b33f12eee9bdd 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -27,7 +27,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} -import org.apache.spark.storage.{BlockObjectWriter, BlockId, BlockManager} +import org.apache.spark.storage.{DiskBlockObjectWriter, BlockId, BlockManager} import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator /** @@ -297,37 +297,18 @@ class ExternalAppendOnlyMap[K, V, C]( override protected def shouldCleanupFileAfterOneIteration(): Boolean = true } + /** Convenience function to hash the given (K, C) pair by the key. */ private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) override protected def getIteratorForCurrentSpillable(): Iterator[(K, C)] = currentMap.destructiveSortedIterator(keyComparator) - private def cleanup() { - batchIndex = batchOffsets.length // Prevent reading any other batch - val ds = deserializeStream - if (ds != null) { - ds.close() - deserializeStream = null - } - if (fileStream != null) { - fileStream.close() - fileStream = null - } - if (file.exists()) { - file.delete() - } - } - - val context = TaskContext.get() - // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in - // a TaskContext. 
- if (context != null) { - context.addTaskCompletionListener(context => cleanup()) - } + override protected def writeNextObject(c: (K, C), writer: DiskBlockObjectWriter): Unit = { + writer.write(c._1, c._2) } - override protected def writeNextObject(c: (K, C), writer: BlockObjectWriter): Unit = { - writer.write(c._1, c._2) + override protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]): Unit = { + spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala index cf4d737b1d400..110b304cc0f4d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala @@ -26,7 +26,7 @@ import com.esotericsoftware.kryo.io.{Output, Input} import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} import org.apache.spark.serializer.DeserializationStream -import org.apache.spark.storage.{BlockObjectWriter, BlockId} +import org.apache.spark.storage.{DiskBlockObjectWriter, BlockId} /** @@ -112,7 +112,7 @@ private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) override protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]): Unit = { spilledLists += new DiskListIterable(file, blockId, batchSizes) } - override protected def writeNextObject(c: T, writer: BlockObjectWriter): Unit = { + override protected def writeNextObject(c: T, writer: DiskBlockObjectWriter): Unit = { writer.write(0, c) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala index 2aa1e006e24c6..75adef126e7cb 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala @@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.{DeserializationStream, Serializer} -import org.apache.spark.storage.{DiskBlockManager, BlockId, BlockObjectWriter, BlockManager} +import org.apache.spark.storage.{DiskBlockManager, BlockId, DiskBlockObjectWriter, BlockManager} import org.apache.spark.util.collection.SpillableCollection._ /** @@ -101,7 +101,7 @@ private[spark] trait SpillableCollection[C, T <: Iterable[C]] extends Spillable[ } protected def getIteratorForCurrentSpillable(): Iterator[C] - protected def writeNextObject(c: C, writer: BlockObjectWriter): Unit + protected def writeNextObject(c: C, writer: DiskBlockObjectWriter): Unit protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) /** From 0dbd6963d589a8f6ad344273f3da7df680ada515 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 30 Jul 2015 15:39:46 -0700 Subject: [PATCH 007/340] [SPARK-9479] [STREAMING] [TESTS] Fix ReceiverTrackerSuite failure for maven build and other potential test failures in Streaming See https://issues.apache.org/jira/browse/SPARK-9479 for the failure cause. The PR includes the following changes: 1. Make ReceiverTrackerSuite create StreamingContext in the test body. 2. Fix places that don't stop StreamingContext. I verified no SparkContext was stopped in the shutdown hook locally after this fix. 3. 
Fix an issue that `ReceiverTracker.endpoint` may be null. 4. Make sure stopping SparkContext in non-main thread won't fail other tests. Author: zsxwing Closes #7797 from zsxwing/fix-ReceiverTrackerSuite and squashes the following commits: 3a4bb98 [zsxwing] Fix another potential NPE d7497df [zsxwing] Fix ReceiverTrackerSuite; make sure StreamingContext in tests is closed --- .../StreamingLogisticRegressionSuite.scala | 21 +++++-- .../clustering/StreamingKMeansSuite.scala | 17 ++++-- .../StreamingLinearRegressionSuite.scala | 21 +++++-- .../streaming/scheduler/ReceiverTracker.scala | 12 +++- .../apache/spark/streaming/JavaAPISuite.java | 1 + .../streaming/BasicOperationsSuite.scala | 58 ++++++++++--------- .../spark/streaming/InputStreamsSuite.scala | 38 ++++++------ .../spark/streaming/MasterFailureTest.scala | 8 ++- .../streaming/StreamingContextSuite.scala | 22 +++++-- .../streaming/StreamingListenerSuite.scala | 13 ++++- .../scheduler/ReceiverTrackerSuite.scala | 56 +++++++++--------- .../StreamingJobProgressListenerSuite.scala | 19 ++++-- 12 files changed, 183 insertions(+), 103 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index fd653296c9d97..d7b291d5a6330 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -24,13 +24,22 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { @@ -50,7 +59,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -84,7 +93,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B))) inputDStream.count() @@ -118,7 +127,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -147,7 
+156,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // train and predict - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -167,7 +176,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase .setNumIterations(10) val numBatches = 10 val emptyInput = Seq.empty[Seq[LabeledPoint]] - val ssc = setupStreams(emptyInput, + ssc = setupStreams(emptyInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index ac01622b8a089..3645d29dccdb2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom @@ -28,6 +28,15 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + test("accuracy for single center and equivalence to grand average") { // set parameters val numBatches = 10 @@ -46,7 +55,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -82,7 +91,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) @@ -114,7 +123,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index a2a4c5f6b8b70..34c07ed170816 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -22,14 +22,23 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { def errorMessage = v1.toString + " did not equal " + v2.toString @@ -62,7 +71,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -98,7 +107,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0))) inputDStream.count() @@ -129,7 +138,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) // collect the output as (true, estimated) tuples @@ -156,7 +165,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // train and predict - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -177,7 +186,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { val numBatches = 10 val nPoints = 100 val emptyInput = Seq.empty[Seq[LabeledPoint]] - val ssc = setupStreams(emptyInput, + ssc = setupStreams(emptyInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 6270137951b5a..e076fb5ea174b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -223,7 +223,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data 
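    // Context, inferred from the commit message rather than stated in the diff itself: the hunk below
    // handles the case where this cleanup runs after the tracker has been stopped and `endpoint` is
    // already null, so the send is wrapped in a synchronized block and gated on isTrackerStarted.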
if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + synchronized { + if (isTrackerStarted) { + endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + } + } } } @@ -285,8 +289,10 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** Update a receiver's maximum ingestion rate */ - def sendRateUpdate(streamUID: Int, newRate: Long): Unit = { - endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) + def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized { + if (isTrackerStarted) { + endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) + } } /** Add new blocks for the given stream */ diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index a34f23475804a..e0718f73aa13f 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1735,6 +1735,7 @@ public Integer call(String s) throws Exception { @SuppressWarnings("unchecked") @Test public void testContextGetOrCreate() throws InterruptedException { + ssc.stop(); final SparkConf conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 08faeaa58f419..255376807c957 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -81,39 +81,41 @@ class BasicOperationsSuite extends TestSuiteBase { test("repartition (more partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(5) - val ssc = setupStreams(input, operation, 2) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 5) - assert(second.size === 5) - assert(third.size === 5) - - assert(first.flatten.toSet.equals((1 to 100).toSet) ) - assert(second.flatten.toSet.equals((101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 2)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 5) + assert(second.size === 5) + assert(third.size === 5) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("repartition (fewer partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(2) - val ssc = setupStreams(input, operation, 5) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 2) - assert(second.size === 2) - assert(third.size === 2) - - assert(first.flatten.toSet.equals((1 to 100).toSet)) - assert(second.flatten.toSet.equals( (101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 
5)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 2) + assert(second.size === 2) + assert(third.size === 2) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("groupByKey") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index b74d67c63a788..ec2852d9a0206 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -325,27 +325,31 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("test track the number of input stream") { - val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => - class TestInputDStream extends InputDStream[String](ssc) { - def start() { } - def stop() { } - def compute(validTime: Time): Option[RDD[String]] = None - } + class TestInputDStream extends InputDStream[String](ssc) { + def start() {} - class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { - def getReceiver: Receiver[String] = null - } + def stop() {} + + def compute(validTime: Time): Option[RDD[String]] = None + } + + class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { + def getReceiver: Receiver[String] = null + } - // Register input streams - val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) - val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) + // Register input streams + val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) + val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) - assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length) - assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) - assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) - assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) - assert(receiverInputStreams.map(_.id) === Array(0, 1)) + assert(ssc.graph.getInputStreams().length == + receiverInputStreams.length + inputStreams.length) + assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) + assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) + assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) + assert(receiverInputStreams.map(_.id) === Array(0, 1)) + } } def testFileStream(newFilesOnly: Boolean) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 6e9d4431090a2..0e64b57e0ffd8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -244,7 +244,13 @@ object MasterFailureTest extends Logging { } catch { case e: Exception => logError("Error running streaming context", e) } - if (killingThread.isAlive) killingThread.interrupt() + if (killingThread.isAlive) { + killingThread.interrupt() + // 
SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is + // stopped before running the next test. Otherwise, it's possible that we set SparkEnv.env + // to null after the next test creates the new SparkContext and fail the test. + killingThread.join() + } ssc.stop() logInfo("Has been killed = " + killed) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4bba9691f8aa5..84a5fbb3d95eb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -120,7 +120,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory) - val ssc = new StreamingContext(myConf, batchDuration) + ssc = new StreamingContext(myConf, batchDuration) assert(ssc.checkpointDir != null) } @@ -369,16 +369,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop") + var t: Thread = null // test whether wait exits if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() ssc.awaitTermination() } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. + t.join() } test("awaitTermination after stop") { @@ -430,16 +436,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.awaitTerminationOrTimeout(500) === false) } + var t: Thread = null // test whether awaitTerminationOrTimeout() return true if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() assert(ssc.awaitTerminationOrTimeout(10000) === true) } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. 
+ t.join() } test("getOrCreate") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 4bc1dd4a30fc4..d840c349bbbc4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -36,13 +36,22 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // To make sure that the processing start and end times in collected // information are different for successive batches override def batchDuration: Duration = Milliseconds(100) override def actuallyWait: Boolean = true test("batch info reporting") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) @@ -107,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } test("receiver info reporting") { - val ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) inputStream.foreachRDD(_.count) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index aff8b53f752fa..afad5f16dbc71 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -29,36 +29,40 @@ import org.apache.spark.storage.StorageLevel /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") - val ssc = new StreamingContext(sparkConf, Milliseconds(100)) - ignore("Receiver tracker - propagates rate limit") { - object ReceiverStartedWaiter extends StreamingListener { - @volatile - var started = false + test("Receiver tracker - propagates rate limit") { + withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc => + object ReceiverStartedWaiter extends StreamingListener { + @volatile + var started = false - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { - started = true + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + started = true + } } - } - - ssc.addStreamingListener(ReceiverStartedWaiter) - ssc.scheduler.listenerBus.start(ssc.sc) - SingletonTestRateReceiver.reset() - - val newRateLimit = 100L - val inputDStream = new RateLimitInputDStream(ssc) - val tracker = new ReceiverTracker(ssc) - tracker.start() - // we wait until the Receiver has registered with the tracker, - // otherwise our rate update is lost - eventually(timeout(5 seconds)) { - assert(ReceiverStartedWaiter.started) - } - tracker.sendRateUpdate(inputDStream.id, newRateLimit) - // this is an async message, we need to wait a bit for it to be processed - eventually(timeout(3 seconds)) { - assert(inputDStream.getCurrentRateLimit.get === 
newRateLimit) + ssc.addStreamingListener(ReceiverStartedWaiter) + ssc.scheduler.listenerBus.start(ssc.sc) + SingletonTestRateReceiver.reset() + + val newRateLimit = 100L + val inputDStream = new RateLimitInputDStream(ssc) + val tracker = new ReceiverTracker(ssc) + tracker.start() + try { + // we wait until the Receiver has registered with the tracker, + // otherwise our rate update is lost + eventually(timeout(5 seconds)) { + assert(ReceiverStartedWaiter.started) + } + tracker.sendRateUpdate(inputDStream.id, newRateLimit) + // this is an async message, we need to wait a bit for it to be processed + eventually(timeout(3 seconds)) { + assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + } + } finally { + tracker.stop(false) + } } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 0891309f956d2..995f1197ccdfd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -22,15 +22,24 @@ import java.util.Properties import org.scalatest.Matchers import org.apache.spark.scheduler.SparkListenerJobStart +import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + private def createJobStart( batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = { val properties = new Properties() @@ -46,7 +55,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + "onReceiverStarted, onReceiverError, onReceiverStopped") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val streamIdToInputInfo = Map( @@ -141,7 +150,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("Remove the old completed batches when exceeding the limit") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -158,7 +167,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("out-of-order onJobStart and onBatchXXX") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -209,7 +218,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("detect memory leak") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) From 7f7a319c4ce07f07a6bd68100cf0a4f1da66269e Mon Sep 17 00:00:00 2001 From: martinzapletal Date: Thu, 30 Jul 2015 
15:57:14 -0700 Subject: [PATCH 008/340] [SPARK-8671] [ML] Added isotonic regression to the pipeline API. Author: martinzapletal Closes #7517 from zapletal-martin/SPARK-8671-isotonic-regression-api and squashes the following commits: 8c435c1 [martinzapletal] Review https://github.com/apache/spark/pull/7517 feedback update. bebbb86 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-8671-isotonic-regression-api b68efc0 [martinzapletal] Added tests for param validation. 07c12bd [martinzapletal] Comments and refactoring. 834fcf7 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-8671-isotonic-regression-api b611fee [martinzapletal] SPARK-8671. Added first version of isotonic regression to pipeline API --- .../ml/regression/IsotonicRegression.scala | 144 +++++++++++++++++ .../regression/IsotonicRegressionSuite.scala | 148 ++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala new file mode 100644 index 0000000000000..4ece8cf8cf0b6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param.{Param, ParamMap, BooleanParam} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{DoubleType, DataType} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.storage.StorageLevel + +/** + * Params for isotonic regression. + */ +private[regression] trait IsotonicRegressionParams extends PredictorParams { + + /** + * Param for weight column name. + * TODO: Move weightCol to sharedParams. + * + * @group param + */ + final val weightCol: Param[String] = + new Param[String](this, "weightCol", "weight column name") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) + + /** + * Param for isotonic parameter. + * Isotonic (increasing) or antitonic (decreasing) sequence. 
+ * @group param + */ + final val isotonic: BooleanParam = + new BooleanParam(this, "isotonic", "isotonic (increasing) or antitonic (decreasing) sequence") + + /** @group getParam */ + final def getIsotonicParam: Boolean = $(isotonic) +} + +/** + * :: Experimental :: + * Isotonic regression. + * + * Currently implemented using parallelized pool adjacent violators algorithm. + * Only univariate (single feature) algorithm supported. + * + * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. + */ +@Experimental +class IsotonicRegression(override val uid: String) + extends Regressor[Double, IsotonicRegression, IsotonicRegressionModel] + with IsotonicRegressionParams { + + def this() = this(Identifiable.randomUID("isoReg")) + + /** + * Set the isotonic parameter. + * Default is true. + * @group setParam + */ + def setIsotonicParam(value: Boolean): this.type = set(isotonic, value) + setDefault(isotonic -> true) + + /** + * Set weight column param. + * Default is weight. + * @group setParam + */ + def setWeightParam(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "weight") + + override private[ml] def featuresDataType: DataType = DoubleType + + override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + + private[this] def extractWeightedLabeledPoints( + dataset: DataFrame): RDD[(Double, Double, Double)] = { + + dataset.select($(labelCol), $(featuresCol), $(weightCol)) + .map { case Row(label: Double, features: Double, weights: Double) => + (label, features, weights) + } + } + + override protected def train(dataset: DataFrame): IsotonicRegressionModel = { + SchemaUtils.checkColumnType(dataset.schema, $(weightCol), DoubleType) + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val instances = extractWeightedLabeledPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) + val parentModel = isotonicRegression.run(instances) + + new IsotonicRegressionModel(uid, parentModel) + } +} + +/** + * :: Experimental :: + * Model fitted by IsotonicRegression. + * Predicts using a piecewise linear function. + * + * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]]. + * + * @param parentModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] + * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. 
+ */ +class IsotonicRegressionModel private[ml] ( + override val uid: String, + private[ml] val parentModel: MLlibIsotonicRegressionModel) + extends RegressionModel[Double, IsotonicRegressionModel] + with IsotonicRegressionParams { + + override def featuresDataType: DataType = DoubleType + + override protected def predict(features: Double): Double = { + parentModel.predict(features) + } + + override def copy(extra: ParamMap): IsotonicRegressionModel = { + copyValues(new IsotonicRegressionModel(uid, parentModel), extra) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala new file mode 100644 index 0000000000000..66e4b170bae80 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.sql.{DataFrame, Row} + +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + private val schema = StructType( + Array( + StructField("label", DoubleType), + StructField("features", DoubleType), + StructField("weight", DoubleType))) + + private val predictionSchema = StructType(Array(StructField("features", DoubleType))) + + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { + val data = Seq.tabulate(labels.size)(i => Row(labels(i), i.toDouble, 1d)) + val parallelData = sc.parallelize(data) + + sqlContext.createDataFrame(parallelData, schema) + } + + private def generatePredictionInput(features: Seq[Double]): DataFrame = { + val data = Seq.tabulate(features.size)(i => Row(features(i))) + + val parallelData = sc.parallelize(data) + sqlContext.createDataFrame(parallelData, predictionSchema) + } + + test("isotonic regression predictions") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + val trainer = new IsotonicRegression().setIsotonicParam(true) + + val model = trainer.fit(dataset) + + val predictions = model + .transform(dataset) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) + + assert(model.parentModel.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) + assert(model.parentModel.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.parentModel.isotonic) + } + + test("antitonic regression predictions") { + val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1)) + 
val trainer = new IsotonicRegression().setIsotonicParam(false) + + val model = trainer.fit(dataset) + val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0)) + + val predictions = model + .transform(features) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1)) + } + + test("params validation") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression + ParamsSuite.checkParams(ir) + val model = ir.fit(dataset) + ParamsSuite.checkParams(model) + } + + test("default params") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression() + assert(ir.getLabelCol === "label") + assert(ir.getFeaturesCol === "features") + assert(ir.getWeightCol === "weight") + assert(ir.getPredictionCol === "prediction") + assert(ir.getIsotonicParam === true) + + val model = ir.fit(dataset) + model.transform(dataset) + .select("label", "features", "prediction", "weight") + .collect() + + assert(model.getLabelCol === "label") + assert(model.getFeaturesCol === "features") + assert(model.getWeightCol === "weight") + assert(model.getPredictionCol === "prediction") + assert(model.getIsotonicParam === true) + assert(model.hasParent) + } + + test("set parameters") { + val isotonicRegression = new IsotonicRegression() + .setIsotonicParam(false) + .setWeightParam("w") + .setFeaturesCol("f") + .setLabelCol("l") + .setPredictionCol("p") + + assert(isotonicRegression.getIsotonicParam === false) + assert(isotonicRegression.getWeightCol === "w") + assert(isotonicRegression.getFeaturesCol === "f") + assert(isotonicRegression.getLabelCol === "l") + assert(isotonicRegression.getPredictionCol === "p") + } + + test("missing column") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + + intercept[IllegalArgumentException] { + new IsotonicRegression().setWeightParam("w").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setFeaturesCol("f").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setLabelCol("l").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset) + } + } +} From be7be6d4c7d978c20e601d1f5f56ecb3479814cb Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 30 Jul 2015 16:04:23 -0700 Subject: [PATCH 009/340] [SPARK-6684] [MLLIB] [ML] Add checkpointing to GBTs Add checkpointing to GradientBoostedTrees, GBTClassifier, GBTRegressor CC: mengxr Author: Joseph K. Bradley Closes #7804 from jkbradley/gbt-checkpoint3 and squashes the following commits: 3fbd7ba [Joseph K. Bradley] tiny fix b3e160c [Joseph K. Bradley] unset checkpoint dir after test 9cc3a04 [Joseph K. 
Bradley] added checkpointing to GBTs --- .../spark/mllib/clustering/LDAOptimizer.scala | 1 + .../mllib/tree/GradientBoostedTrees.scala | 48 +++++------ .../tree/configuration/BoostingStrategy.scala | 3 +- .../classification/GBTClassifierSuite.scala | 20 +++++ .../ml/regression/GBTRegressorSuite.scala | 20 ++++- .../tree/GradientBoostedTreesSuite.scala | 79 +++++++++++-------- 6 files changed, 114 insertions(+), 57 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 9dbec41efeada..d6f8b29a43dfd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -144,6 +144,7 @@ final class EMLDAOptimizer extends LDAOptimizer { this.checkpointInterval = lda.getCheckpointInterval this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( checkpointInterval, graph.vertices.sparkContext) + this.graphCheckpointer.update(this.graph) this.globalTopicTotals = computeGlobalTopicTotals() this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index a835f96d5d0e3..9ce6faa137c41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.Algo._ @@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging { false } + // Prepare periodic checkpointers + val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + timer.stop("init") logDebug("##########") logDebug("Building tree 0") logDebug("##########") - var data = input // Initialize tree timer.start("building tree 0") - val firstTreeModel = new DecisionTree(treeStrategy).run(data) + val firstTreeModel = new DecisionTree(treeStrategy).run(input) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + predErrorCheckpointer.update(predError) logDebug("error of gbt = " + predError.values.mean()) // Note: A model of type regression is used since we require raw prediction @@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging { var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. 
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + if (validate) validatePredErrorCheckpointer.update(validatePredError) var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 var bestM = 1 - // pseudo-residual for second iteration - data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } - var m = 1 - while (m < numIterations) { + var doneLearning = false + while (m < numIterations && !doneLearning) { + // Update data with pseudo-residuals + val data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } + timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") val model = new DecisionTree(treeStrategy).run(data) timer.stop(s"building tree $m") - // Create partial model + // Update partial model baseLearners(m) = model // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. // Technically, the weight should be optimized for the particular loss. // However, the behavior should be reasonable, though not optimal. baseLearnerWeights(m) = learningRate - // Note: A model of type regression is used since we require raw prediction - val partialModel = new GradientBoostedTreesModel( - Regression, baseLearners.slice(0, m + 1), - baseLearnerWeights.slice(0, m + 1)) predError = GradientBoostedTreesModel.updatePredictionError( input, predError, baseLearnerWeights(m), baseLearners(m), loss) + predErrorCheckpointer.update(predError) logDebug("error of gbt = " + predError.values.mean()) if (validate) { @@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging { validatePredError = GradientBoostedTreesModel.updatePredictionError( validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + validatePredErrorCheckpointer.update(validatePredError) val currentValidateError = validatePredError.values.mean() if (bestValidateError - currentValidateError < validationTol) { - return new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, - baseLearners.slice(0, bestM), - baseLearnerWeights.slice(0, bestM)) + doneLearning = true } else if (currentValidateError < bestValidateError) { - bestValidateError = currentValidateError - bestM = m + 1 + bestValidateError = currentValidateError + bestM = m + 1 } } - // Update data with pseudo-residuals - data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } m += 1 } @@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + predErrorCheckpointer.deleteAllCheckpoints() + validatePredErrorCheckpointer.deleteAllCheckpoints() if (persistedInput) input.unpersist() if (validate) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 2d6b01524ff3d..9fd30c9b56319 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * learning rate should be 
between in the interval (0, 1] * @param validationTol Useful when runWithValidation is used. If the error rate on the * validation input between two iterations is less than the validationTol - * then stop. Ignored when [[run]] is used. + * then stop. Ignored when + * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. */ @Experimental case class BoostingStrategy( diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 82c345491bb3c..a7bc77965fefd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setLossType("logistic") + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 9682edcd9ba84..dbdce0c9dea54 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predictions.min() < -1) } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val df = sqlContext.createDataFrame(data) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 2521b3342181a..6fc9e8df621df 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) - (algos zip losses) map { - case (algo, loss) => { - val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - val boostingStrategy = - new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy) - .runWithValidation(trainRdd, validateRdd) - val numTrees = gbtValidate.numTrees - assert(numTrees !== numIterations) - - // Test that it performs better on the validation dataset. - val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) - val (errorWithoutValidation, errorWithValidation) = { - if (algo == Classification) { - val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) - } else { - (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) - } - } - assert(errorWithValidation <= errorWithoutValidation) - - // Test that results from evaluateEachIteration comply with runWithValidation. - // Note that convergenceTol is set to 0.0 - val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) - assert(evaluationArray.length === numIterations) - assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) - var i = 1 - while (i < numTrees) { - assert(evaluationArray(i) <= evaluationArray(i - 1)) - i += 1 + algos.zip(losses).foreach { case (algo, loss) => + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + val numTrees = gbtValidate.numTrees + assert(numTrees !== numIterations) + + // Test that it performs better on the validation dataset. + val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) } } + assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. 
+ // Note that convergenceTol is set to 0.0 + val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, checkpointInterval = 2) + val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + } private object GradientBoostedTreesSuite { From e7905a9395c1a002f50bab29e16a729e14d4ed6f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 30 Jul 2015 16:15:43 -0700 Subject: [PATCH 010/340] [SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula Preview: ``` > summary(m) features coefficients 1 (Intercept) 1.6765001 2 Sepal_Length 0.3498801 3 Species.versicolor -0.9833885 4 Species.virginica -1.0075104 ``` Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit cc mengxr Author: Eric Liang Closes #7771 from ericl/summary and squashes the following commits: ccd54c3 [Eric Liang] second pass a5ca93b [Eric Liang] comments 2772111 [Eric Liang] clean up 70483ef [Eric Liang] fix test 7c247d4 [Eric Liang] Merge branch 'master' into summary 3c55024 [Eric Liang] working 8c539aa [Eric Liang] first pass --- R/pkg/NAMESPACE | 3 ++- R/pkg/R/mllib.R | 26 ++++++++++++++++++ R/pkg/inst/tests/test_mllib.R | 11 ++++++++ .../spark/ml/feature/OneHotEncoder.scala | 12 ++++----- .../apache/spark/ml/feature/RFormula.scala | 12 ++++++++- .../apache/spark/ml/r/SparkRWrappers.scala | 27 +++++++++++++++++-- .../ml/regression/LinearRegression.scala | 8 ++++-- .../spark/ml/feature/OneHotEncoderSuite.scala | 8 +++--- .../spark/ml/feature/RFormulaSuite.scala | 18 +++++++++++++ 9 files changed, 108 insertions(+), 17 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f7a8a2e4de24..a329e14f25aeb 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -12,7 +12,8 @@ export("print.jobj") # MLlib integration exportMethods("glm", - "predict") + "predict", + "summary") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 6a8bacaa552c6..efddcc1d8d71c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"), function(object, newData) { return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) }) + +#' Get the summary of a model +#' +#' Returns the summary of a model produced by glm(), similarly to R's summary(). +#' +#' @param model A fitted MLlib model +#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See +#' summary.glm for more information. 
+#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#'} +setMethod("summary", signature(object = "PipelineModel"), + function(object) { + features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelFeatures", object@model) + weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelWeights", object@model) + coefficients <- as.matrix(unlist(weights)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 3bef69324770a..f272de78ad4a6 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", { rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) + +test_that("summary coefficients match with native glm", { + training <- createDataFrame(sqlContext, iris) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + coefs <- as.vector(stats$coefficients) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) + expect_true(all(abs(rCoefs - coefs) < 1e-6)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 3825942795645..9c60d4084ec46 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) @@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => if (nominal.values.isDefined) { - nominal.values.map(_.map(v => inputColName + is + v)) + nominal.values } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) } else { None } case binary: BinaryAttribute => if (binary.values.isDefined) { - binary.values.map(_.map(v => inputColName + is + v)) + binary.values } else { - Some(Array.tabulate(2)(i => inputColName + is + i)) + Some(Array.tabulate(2)(_.toString)) } case _: NumericAttribute => throw new RuntimeException( @@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer override def transform(dataset: DataFrame): DataFrame = { // schema transformation - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) val shouldDropLast = $(dropLast) @@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer math.max(m0, m1) } ).toInt + 1 - val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames val outputAttrs: 
Array[Attribute] = filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 0b428d278d908..d1726917e4517 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.RegexParsers @@ -91,11 +92,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R // TODO(ekl) add support for feature interactions val encoderStages = ArrayBuffer[PipelineStage]() val tempColumns = ArrayBuffer[String]() + val takenNames = mutable.Set(dataset.columns: _*) val encodedTerms = resolvedFormula.terms.map { term => dataset.schema(term) match { case column if column.dataType == StringType => val indexCol = term + "_idx_" + uid - val encodedCol = term + "_onehot_" + uid + val encodedCol = { + var tmp = term + while (takenNames.contains(tmp)) { + tmp += "_" + } + tmp + } + takenNames.add(indexCol) + takenNames.add(encodedCol) encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) tempColumns += indexCol diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 9f70592ccad7e..f5a022c31ed90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.api.r +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.feature.RFormula -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.DataFrame @@ -44,4 +45,26 @@ private[r] object SparkRWrappers { val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) } + + def getModelWeights(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => + Array(m.intercept) ++ m.weights.toArray + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No weights available for LogisticRegressionModel") // SPARK-9492 + } + } + + def getModelFeatures(model: PipelineModel): Array[String] = { + model.stages.last match { + case m: LinearRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No features names available for LogisticRegressionModel") // SPARK-9492 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 89718e0f3e15a..3b85ba001b128 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.StructField import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter @@ -146,9 +147,10 @@ class LinearRegression(override val uid: String) val model = new LinearRegressionModel(uid, weights, intercept) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) } @@ -221,9 +223,10 @@ class LinearRegression(override val uid: String) val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } @@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + val featuresCol: String, val objectiveHistory: Array[Double]) extends LinearRegressionSummary(predictions, predictionCol, labelCol) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 65846a846b7b4..321eeb843941c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) } test("input column without ML attribute") { @@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 8148c553e9051..6aed3243afce8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import 
org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) } + + test("attribute generation") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array( + new BinaryAttribute(Some("a__bar"), Some(1)), + new BinaryAttribute(Some("a__foo"), Some(2)), + new NumericAttribute(Some("b"), Some(3)))) + assert(attrs === expectedAttrs) + } } From 157840d1b14502a4f25cff53633c927998c6ada1 Mon Sep 17 00:00:00 2001 From: Hossein Date: Thu, 30 Jul 2015 16:16:17 -0700 Subject: [PATCH 011/340] [SPARK-8742] [SPARKR] Improve SparkR error messages for DataFrame API This patch improves SparkR error message reporting, especially with DataFrame API. When there is a user error (e.g., malformed SQL query), the message of the cause is sent back through the RPC and the R client reads it and returns it back to user. cc shivaram Author: Hossein Closes #7742 from falaki/SPARK-8742 and squashes the following commits: 4f643c9 [Hossein] Not logging exceptions in RBackendHandler 4a8005c [Hossein] Returning stack track of causing exception from RBackendHandler 5cf17f0 [Hossein] Adding unit test for error messages from SQLContext 2af75d5 [Hossein] Reading error message in case of failure and stoping with that message f479c99 [Hossein] Wrting exception cause message in JVM --- R/pkg/R/backend.R | 4 +++- R/pkg/inst/tests/test_sparkSQL.R | 5 +++++ .../scala/org/apache/spark/api/r/RBackendHandler.scala | 10 ++++++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 2fb6fae55f28c..49162838b8d1a 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) 
{ # TODO: check the status code to output error information returnStatus <- readInt(conn) - stopifnot(returnStatus == 0) + if (returnStatus != 0) { + stop(readString(conn)) + } readObject(conn) } diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d5db97248c770..61c8a7ec7d837 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1002,6 +1002,11 @@ test_that("crosstab() on a DataFrame", { expect_identical(expected, ordered) }) +test_that("SQL error message is returned from JVM", { + retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + expect_equal(grepl("Table Not Found: blah", retError), TRUE) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index a5de10fe89c42..14dac4ed28ce3 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -69,8 +69,11 @@ private[r] class RBackendHandler(server: RBackend) case e: Exception => logError(s"Removing $objId failed", e) writeInt(dos, -1) + writeString(dos, s"Removing $objId failed: ${e.getMessage}") } - case _ => dos.writeInt(-1) + case _ => + dos.writeInt(-1) + writeString(dos, s"Error: unknown method $methodName") } } else { handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) @@ -146,8 +149,11 @@ private[r] class RBackendHandler(server: RBackend) } } catch { case e: Exception => - logError(s"$methodName on $objId failed", e) + logError(s"$methodName on $objId failed") writeInt(dos, -1) + // Writing the error message of the cause for the exception. This will be returned + // to user in the R process. + writeString(dos, Utils.exceptionString(e.getCause)) } } From 04c8409107710fc9a625ee513d68c149745539f3 Mon Sep 17 00:00:00 2001 From: Calvin Jia Date: Thu, 30 Jul 2015 16:32:40 -0700 Subject: [PATCH 012/340] [SPARK-9199] [CORE] Update Tachyon dependency from 0.6.4 -> 0.7.0 No new dependencies are added. The exclusion changes are due to the change in tachyon-client 0.7.0's project structure. There is no client side API change in Tachyon 0.7.0 so no code changes are required. 
Author: Calvin Jia Closes #7577 from calvinjia/SPARK-9199 and squashes the following commits: 4e81e40 [Calvin Jia] Update Tachyon dependency from 0.6.4 -> 0.7.0 --- core/pom.xml | 34 +++++----------------------------- make-distribution.sh | 2 +- 2 files changed, 6 insertions(+), 30 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 6fa87ec6a24af..202678779150b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -286,7 +286,7 @@ org.tachyonproject tachyon-client - 0.6.4 + 0.7.0 org.apache.hadoop @@ -297,36 +297,12 @@ curator-recipes - org.eclipse.jetty - jetty-jsp + org.tachyonproject + tachyon-underfs-glusterfs - org.eclipse.jetty - jetty-webapp - - - org.eclipse.jetty - jetty-server - - - org.eclipse.jetty - jetty-servlet - - - junit - junit - - - org.powermock - powermock-module-junit4 - - - org.powermock - powermock-api-mockito - - - org.apache.curator - curator-test + org.tachyonproject + tachyon-underfs-s3 diff --git a/make-distribution.sh b/make-distribution.sh index cac7032bb2e87..4789b0e09cc8a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.6.4" +TACHYON_VERSION="0.7.0" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" From 1afdeb7b458f86e2641f062fb9ddc00e9c5c7531 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 30 Jul 2015 16:44:02 -0700 Subject: [PATCH 013/340] [STREAMING] [TEST] [HOTFIX] Fixed Kinesis test to not throw weird errors when Kinesis tests are enabled without AWS keys If Kinesis tests are enabled by env ENABLE_KINESIS_TESTS = 1 but no AWS credentials are found, the desired behavior is to fail the test with ``` Exception encountered when attempting to run a suite with class name: org.apache.spark.streaming.kinesis.KinesisBackedBlockRDDSuite *** ABORTED *** (3 seconds, 5 milliseconds) [info] java.lang.Exception: Kinesis tests enabled, but could get not AWS credentials ``` Instead KinesisStreamSuite fails with ``` [info] - basic operation *** FAILED *** (3 seconds, 35 milliseconds) [info] java.lang.IllegalArgumentException: requirement failed: Stream not yet created, call createStream() to create one [info] at scala.Predef$.require(Predef.scala:233) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils.streamName(KinesisTestUtils.scala:77) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils$$anonfun$deleteStream$1.apply(KinesisTestUtils.scala:150) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils$$anonfun$deleteStream$1.apply(KinesisTestUtils.scala:150) [info] at org.apache.spark.Logging$class.logWarning(Logging.scala:71) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils.logWarning(KinesisTestUtils.scala:39) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils.deleteStream(KinesisTestUtils.scala:150) [info] at org.apache.spark.streaming.kinesis.KinesisStreamSuite$$anonfun$3.apply$mcV$sp(KinesisStreamSuite.scala:111) [info] at org.apache.spark.streaming.kinesis.KinesisStreamSuite$$anonfun$3.apply(KinesisStreamSuite.scala:86) [info] at org.apache.spark.streaming.kinesis.KinesisStreamSuite$$anonfun$3.apply(KinesisStreamSuite.scala:86) ``` This is because attempting to delete a non-existent Kinesis stream throws an uncaught exception. This PR fixes it.
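In outline, the fix is to attempt deletion only when the test utility actually created a stream, and to keep any cleanup failure from aborting the suite. The following is a standalone sketch of that pattern only; the fake client below stands in for the AWS Kinesis SDK and is not the actual KinesisTestUtils code (see the diff for the real change):

```
object StreamCleanupSketch {
  // Stand-in for the Kinesis client; deleteStream here always fails, imitating
  // an attempt to delete a stream that does not exist.
  class FakeKinesisClient {
    def deleteStream(name: String): Unit =
      throw new IllegalArgumentException(s"stream $name does not exist")
  }

  private val client = new FakeKinesisClient
  @volatile private var streamCreated = false // set to true only after a successful create

  def deleteStreamIfCreated(name: String): Unit = {
    try {
      // Only attempt deletion when a stream was actually created ...
      if (streamCreated) client.deleteStream(name)
    } catch {
      // ... and never let a cleanup failure abort the suite.
      case e: Exception => println(s"Could not delete stream $name: ${e.getMessage}")
    }
  }

  def main(args: Array[String]): Unit = {
    deleteStreamIfCreated("test-stream") // no-op: nothing was created, nothing is thrown
  }
}
```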
Author: Tathagata Das Closes #7809 from tdas/kinesis-test-hotfix and squashes the following commits: 7c372e6 [Tathagata Das] Fixed test --- .../streaming/kinesis/KinesisTestUtils.scala | 27 ++++++++++--------- .../kinesis/KinesisStreamSuite.scala | 4 +-- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 0ff1b7ed0fd90..ca39358b75cb6 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -53,6 +53,8 @@ private class KinesisTestUtils( @volatile private var streamCreated = false + + @volatile private var _streamName: String = _ private lazy val kinesisClient = { @@ -115,21 +117,9 @@ private class KinesisTestUtils( shardIdToSeqNumbers.toMap } - def describeStream(streamNameToDescribe: String = streamName): Option[StreamDescription] = { - try { - val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) - val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() - Some(desc) - } catch { - case rnfe: ResourceNotFoundException => - None - } - } - def deleteStream(): Unit = { try { - if (describeStream().nonEmpty) { - val deleteStreamRequest = new DeleteStreamRequest() + if (streamCreated) { kinesisClient.deleteStream(streamName) } } catch { @@ -149,6 +139,17 @@ private class KinesisTestUtils( } } + private def describeStream(streamNameToDescribe: String): Option[StreamDescription] = { + try { + val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) + val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + Some(desc) + } catch { + case rnfe: ResourceNotFoundException => + None + } + } + private def findNonExistentStreamName(): String = { var testStreamName: String = null do { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index f9c952b9468bb..b88c9c6478d56 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -88,11 +88,11 @@ class KinesisStreamSuite extends KinesisFunSuite try { kinesisTestUtils.createStream() ssc = new StreamingContext(sc, Seconds(1)) - val aWSCredentials = KinesisTestUtils.getAWSCredentials() + val awsCredentials = KinesisTestUtils.getAWSCredentials() val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName, kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_ONLY, - aWSCredentials.getAWSAccessKeyId, aWSCredentials.getAWSSecretKey) + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => From f06b5def7294de39c8a76d910cc900f5ee0c1864 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 30 Jul 2015 16:57:13 -0700 Subject: [PATCH 014/340] Removing unnecessary check --- .../scala/org/apache/spark/util/collection/CompactBuffer.scala | 1 - 1 file 
changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala index 5bed400cf96ef..4d43d8d5cc8d8 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala @@ -141,7 +141,6 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable newArrayLen = Int.MaxValue - 2 } } - require(newArrayLen != null) val newArray = new Array[T](newArrayLen) if (otherElements != null) { System.arraycopy(otherElements, 0, newArray, 0, otherElements.length) From ca71cc8c8b2d64b7756ae697c06876cd18b536dc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 30 Jul 2015 16:57:38 -0700 Subject: [PATCH 015/340] [SPARK-9408] [PYSPARK] [MLLIB] Refactor linalg.py to /linalg This is based on MechCoder 's PR https://github.com/apache/spark/pull/7731. Hopefully it could pass tests. MechCoder I tried to make minimal changes. If this passes Jenkins, we can merge this one first and then try to move `__init__.py` to `local.py` in a separate PR. Closes #7731 Author: Xiangrui Meng Closes #7746 from mengxr/SPARK-9408 and squashes the following commits: 0e05a3b [Xiangrui Meng] merge master 1135551 [Xiangrui Meng] add a comment for str(...) c48cae0 [Xiangrui Meng] update tests 173a805 [Xiangrui Meng] move linalg.py to linalg/__init__.py --- dev/sparktestsupport/modules.py | 2 +- python/pyspark/mllib/{linalg.py => linalg/__init__.py} | 0 python/pyspark/sql/types.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename python/pyspark/mllib/{linalg.py => linalg/__init__.py} (100%) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 030d982e99106..44600cb9523c1 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -323,7 +323,7 @@ def contains_file(self, filename): "pyspark.mllib.evaluation", "pyspark.mllib.feature", "pyspark.mllib.fpm", - "pyspark.mllib.linalg", + "pyspark.mllib.linalg.__init__", "pyspark.mllib.random", "pyspark.mllib.recommendation", "pyspark.mllib.regression", diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg/__init__.py similarity index 100% rename from python/pyspark/mllib/linalg.py rename to python/pyspark/mllib/linalg/__init__.py diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0976aea72c034..6f74b7162f7cc 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -648,7 +648,7 @@ def jsonValue(self): @classmethod def fromJson(cls, json): - pyUDT = str(json["pyClass"]) + pyUDT = str(json["pyClass"]) # convert unicode to str split = pyUDT.rfind(".") pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] From df32669514afc0223ecdeca30fbfbe0b40baef3a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 30 Jul 2015 17:16:03 -0700 Subject: [PATCH 016/340] [SPARK-7157][SQL] add sampleBy to DataFrame This was previously committed but then reverted due to test failures (see #6769). 
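The new method draws a stratified sample without replacement, given a sampling fraction per stratum; strata missing from the map are treated as fraction zero. A usage sketch in Scala, mirroring the test added below (the app name and local master are illustrative assumptions):

```
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.col

object SampleByExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sampleBy-example").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Column "key" takes values 0, 1 and 2; keep ~10% of key 0, ~20% of key 1,
    // and drop key 2 entirely since it is absent from the fraction map.
    val df = sqlContext.range(0, 100).select((col("id") % 3).as("key"))
    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
    sampled.groupBy("key").count().orderBy("key").show()

    sc.stop()
  }
}
```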
Author: Xiangrui Meng Closes #7755 from rxin/SPARK-7157 and squashes the following commits: fbf9044 [Xiangrui Meng] fix python test 542bd37 [Xiangrui Meng] update test 604fe6d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 f051afd [Xiangrui Meng] use udf instead of building expression f4e9425 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 8fb990b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 103beb3 [Xiangrui Meng] add Java-friendly sampleBy 991f26f [Xiangrui Meng] fix seed 4a14834 [Xiangrui Meng] move sampleBy to stat 832f7cc [Xiangrui Meng] add sampleBy to DataFrame --- python/pyspark/sql/dataframe.py | 41 ++++++++++++++++++ .../spark/sql/DataFrameStatFunctions.scala | 42 +++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 9 ++++ .../apache/spark/sql/DataFrameStatSuite.scala | 12 +++++- 4 files changed, 102 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d76e051bd73a1..0f3480c239187 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -441,6 +441,42 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + @since(1.5) + def sampleBy(self, col, fractions, seed=None): + """ + Returns a stratified sample without replacement based on the + fraction given on each stratum. + + :param col: column that defines strata + :param fractions: + sampling fraction for each stratum. If a stratum is not + specified, we treat its fraction as zero. + :param seed: random seed + :return: a new DataFrame that represents the stratified sample + + >>> from pyspark.sql.functions import col + >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) + >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) + >>> sampled.groupBy("key").count().orderBy("key").show() + +---+-----+ + |key|count| + +---+-----+ + | 0| 3| + | 1| 8| + +---+-----+ + + """ + if not isinstance(col, str): + raise ValueError("col must be a string, but got %r" % type(col)) + if not isinstance(fractions, dict): + raise ValueError("fractions must be a dict but got %r" % type(fractions)) + for k, v in fractions.items(): + if not isinstance(k, (float, int, long, basestring)): + raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) + fractions[k] = float(v) + seed = seed if seed is not None else random.randint(0, sys.maxsize) + return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) + @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. 
@@ -1314,6 +1350,11 @@ def freqItems(self, cols, support=None): freqItems.__doc__ = DataFrame.freqItems.__doc__ + def sampleBy(self, col, fractions, seed=None): + return self.df.sampleBy(col, fractions, seed) + + sampleBy.__doc__ = DataFrame.sampleBy.__doc__ + def _test(): import doctest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 4ec58082e7aef..2e68e358f2f1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.{util => ju, lang => jl} + +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ @@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: Seq[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { + require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), + s"Fractions must be in [0, 1], but got $fractions.") + import org.apache.spark.sql.functions.{rand, udf} + val c = Column(col) + val r = rand(seed) + val f = udf { (stratum: Any, x: Double) => + x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0) + } + df.filter(f(c, r)) + } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. 
+ * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { + sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 9e61d06f4036e..2c669bb59a0b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -226,4 +226,13 @@ public void testCovariance() { Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1e-6); } + + @Test + public void testSampleBy() { + DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); + Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Assert.assertArrayEquals(expected, actual); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 7ba4ba73e0cc9..07a675e64f527 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -21,9 +21,9 @@ import java.util.Random import org.scalatest.Matchers._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions.col -class DataFrameStatSuite extends SparkFunSuite { +class DataFrameStatSuite extends QueryTest { private val sqlCtx = org.apache.spark.sql.test.TestSQLContext import sqlCtx.implicits._ @@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite { val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) } + + test("sampleBy") { + val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 5), Row(1, 8))) + } } From e7a0976e991f75a7bda99509e2b040daab965ae6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 Jul 2015 17:17:27 -0700 Subject: [PATCH 017/340] [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort. Author: Reynold Xin Closes #7803 from rxin/SPARK-9458 and squashes the following commits: 5b032dc [Reynold Xin] Fix string. b670dbb [Reynold Xin] [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort. 
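The idea behind the sort prefixes touched in this patch: each record is tagged with a 64-bit prefix (for a double column, its raw IEEE-754 bit pattern; for a string, bits derived from its leading bytes), the sorter orders by prefix first, and only compares full records when two prefixes tie. A minimal standalone sketch of that idea for a double key follows; it is not the Spark implementation, and it uses java.lang.Double.compare where Spark uses a NaN-safe helper:

```
object PrefixSortSketch {
  final case class Record(value: Double, payload: String)

  // As in the patch, the prefix of a double key is its raw IEEE-754 bit pattern.
  def computePrefix(d: Double): Long = java.lang.Double.doubleToLongBits(d)

  // The prefix comparator decodes the bits back to doubles before comparing.
  def comparePrefixes(a: Long, b: Long): Int =
    java.lang.Double.compare(java.lang.Double.longBitsToDouble(a), java.lang.Double.longBitsToDouble(b))

  def main(args: Array[String]): Unit = {
    val records = Seq(Record(3.0, "c"), Record(1.0, "a"), Record(2.0, "b"))
    // Tag every record with its prefix and sort by prefix only; a real sorter
    // would fall back to a full-record comparator when two prefixes are equal.
    val sorted = records
      .map(r => (computePrefix(r.value), r))
      .sortWith((x, y) => comparePrefixes(x._1, y._1) < 0)
      .map(_._2)
    println(sorted.map(_.value).mkString(", ")) // prints 1.0, 2.0, 3.0
  }
}
```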
--- .../unsafe/sort/PrefixComparators.java | 49 ++++++++------ .../unsafe/sort/PrefixComparatorsSuite.scala | 22 ++----- .../execution/UnsafeExternalRowSorter.java | 27 ++++---- .../sql/catalyst/expressions/SortOrder.scala | 44 ++++++++++++- .../spark/sql/execution/SortPrefixUtils.scala | 64 +++---------------- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../sql/execution/joins/HashedRelation.scala | 4 +- .../org/apache/spark/sql/execution/sort.scala | 64 ++++++++----------- .../execution/RowFormatConvertersSuite.scala | 11 ++-- ...ortSuite.scala => TungstenSortSuite.scala} | 10 +-- 10 files changed, 138 insertions(+), 161 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/execution/{UnsafeExternalSortSuite.scala => TungstenSortSuite.scala} (87%) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 600aff7d15d8a..4d7e5b3dfba6e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -28,9 +28,11 @@ public class PrefixComparators { private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); - public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); - public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); + public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); public static final class StringPrefixComparator extends PrefixComparator { @Override @@ -38,50 +40,55 @@ public int compare(long aPrefix, long bPrefix) { return UnsignedLongs.compare(aPrefix, bPrefix); } - public long computePrefix(UTF8String value) { + public static long computePrefix(UTF8String value) { return value == null ? 0L : value.getPrefix(); } } - /** - * Prefix comparator for all integral types (boolean, byte, short, int, long). - */ - public static final class IntegralPrefixComparator extends PrefixComparator { + public static final class StringPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + + public static final class LongPrefixComparator extends PrefixComparator { @Override public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } + } - public final long NULL_PREFIX = Long.MIN_VALUE; + public static final class LongPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 
1 : 0; + } } - public static final class FloatPrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - float a = Float.intBitsToFloat((int) aPrefix); - float b = Float.intBitsToFloat((int) bPrefix); - return Utils.nanSafeCompareFloats(a, b); + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(float value) { - return Float.floatToIntBits(value) & 0xffffffffL; + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); } - public static final class DoublePrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparatorDesc extends PrefixComparator { @Override - public int compare(long aPrefix, long bPrefix) { + public int compare(long bPrefix, long aPrefix) { double a = Double.longBitsToDouble(aPrefix); double b = Double.longBitsToDouble(bPrefix); return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(double value) { + public static long computePrefix(double value) { return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY); } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index cf53a8ad21c60..26a2e96edaaa2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -29,8 +29,8 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { def testPrefixComparison(s1: String, s2: String): Unit = { val utf8string1 = UTF8String.fromString(s1) val utf8string2 = UTF8String.fromString(s2) - val s1Prefix = PrefixComparators.STRING.computePrefix(utf8string1) - val s2Prefix = PrefixComparators.STRING.computePrefix(utf8string2) + val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1) + val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2) val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) val cmp = UnsignedBytes.lexicographicalComparator().compare( @@ -55,27 +55,15 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } - test("float prefix comparator handles NaN properly") { - val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) - val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) - assert(nan1.isNaN) - assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) - val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) - assert(nan1Prefix === nan2Prefix) - val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) - assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) - } - test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) assert(nan1.isNaN) assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1) - val 
nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2) + val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) assert(nan1Prefix === nan2Prefix) - val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue) + val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 4c3f2c6557140..68c49feae938e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -48,7 +48,6 @@ final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final UnsafeProjection unsafeProjection; private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; @@ -62,7 +61,6 @@ public UnsafeExternalRowSorter( PrefixComparator prefixComparator, PrefixComputer prefixComputer) throws IOException { this.schema = schema; - this.unsafeProjection = UnsafeProjection.create(schema); this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); @@ -88,13 +86,12 @@ void setTestSpillFrequency(int frequency) { } @VisibleForTesting - void insertRow(InternalRow row) throws IOException { - UnsafeRow unsafeRow = unsafeProjection.apply(row); + void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( - unsafeRow.getBaseObject(), - unsafeRow.getBaseOffset(), - unsafeRow.getSizeInBytes(), + row.getBaseObject(), + row.getBaseOffset(), + row.getSizeInBytes(), prefix ); numRowsInserted++; @@ -113,7 +110,7 @@ private void cleanupResources() { } @VisibleForTesting - Iterator sort() throws IOException { + Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -121,7 +118,7 @@ Iterator sort() throws IOException { // here in order to prevent memory leaks. 
cleanupResources(); } - return new AbstractScalaRowIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(); @@ -132,7 +129,7 @@ public boolean hasNext() { } @Override - public InternalRow next() { + public UnsafeRow next() { try { sortedIterator.loadNext(); row.pointTo( @@ -164,11 +161,11 @@ public InternalRow next() { } - public Iterator sort(Iterator inputIterator) throws IOException { - while (inputIterator.hasNext()) { - insertRow(inputIterator.next()); - } - return sort(); + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3f436c0eb893c..9fe877f10fa08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator abstract sealed class SortDirection case object Ascending extends SortDirection @@ -37,4 +40,43 @@ case class SortOrder(child: Expression, direction: SortDirection) override def nullable: Boolean = child.nullable override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + + def isAscending: Boolean = direction == Ascending +} + +/** + * An expression to generate a 64-bit long prefix used in sorting. + */ +case class SortPrefix(child: SortOrder) extends UnaryExpression { + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val childCode = child.child.gen(ctx) + val input = childCode.primitive + val DoublePrefixCmp = classOf[DoublePrefixComparator].getName + + val (nullValue: Long, prefixCode: String) = child.child.dataType match { + case BooleanType => + (Long.MinValue, s"$input ? 
1L : 0L") + case _: IntegralType => + (Long.MinValue, s"(long) $input") + case FloatType | DoubleType => + (DoublePrefixComparator.computePrefix(Double.NegativeInfinity), + s"$DoublePrefixCmp.computePrefix((double)$input)") + case StringType => (0L, s"$input.getPrefix()") + case _ => (0L, "0L") + } + + childCode.code + + s""" + |long ${ev.primitive} = ${nullValue}L; + |boolean ${ev.isNull} = false; + |if (!${childCode.isNull}) { + | ${ev.primitive} = $prefixCode; + |} + """.stripMargin + } + + override def dataType: DataType = LongType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 2dee3542d6101..a2145b185ce90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -37,61 +35,15 @@ object SortPrefixUtils { def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { - case StringType => PrefixComparators.STRING - case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL - case FloatType => PrefixComparators.FLOAT - case DoubleType => PrefixComparators.DOUBLE + case StringType if sortOrder.isAscending => PrefixComparators.STRING + case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC + case BooleanType | ByteType | ShortType | IntegerType | LongType if sortOrder.isAscending => + PrefixComparators.LONG + case BooleanType | ByteType | ShortType | IntegerType | LongType if !sortOrder.isAscending => + PrefixComparators.LONG_DESC + case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE + case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC case _ => NoOpPrefixComparator } } - - def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { - sortOrder.dataType match { - case StringType => (row: InternalRow) => { - PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) - } - case BooleanType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 - else 0 - } - case ByteType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Byte] - } - case ShortType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Short] - } - case IntegerType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Int] - } - case LongType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Long] - } - case FloatType => 
(row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX - else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) - } - case DoubleType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX - else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) - } - case _ => (row: InternalRow) => 0L - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 52a9b02d373c7..03d24a88d4ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -341,8 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - UnsafeExternalSort.supportsSchema(child.schema)) { - execution.UnsafeExternalSort(sortExprs, global, child) + TungstenSort.supportsSchema(child.schema)) { + execution.TungstenSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 26dbc911e9521..f88a45f48aee9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -229,7 +229,7 @@ private[joins] final class UnsafeHashedRelation( // write all the values as single byte array var totalSize = 0L var i = 0 - while (i < values.size) { + while (i < values.length) { totalSize += values(i).getSizeInBytes + 4 + 4 i += 1 } @@ -240,7 +240,7 @@ private[joins] final class UnsafeHashedRelation( out.writeInt(totalSize.toInt) out.write(key.getBytes) i = 0 - while (i < values.size) { + while (i < values.length) { // [num of fields] [num of bytes] [row bytes] // write the integer in native order, so they can be read by UNSAFE.getInt() if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index f82208868c3e3..6d903ab23c57f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Descending, BindReferences, Attribute, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator 
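The prefix machinery above (SortPrefix plus the per-type prefix comparators) exists so the sorter can order records by a cheap 8-byte prefix and fall back to the full row ordering only on ties. A rough sketch of that idea, not the actual UnsafeExternalSorter internals; `prefixComparator` and `rowOrdering` are assumed to be in scope:

    // Sketch of prefix-accelerated comparison: cheap 8-byte compare first,
    // full record comparison only when the two prefixes tie.
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    case class Entry(row: UnsafeRow, prefix: Long)

    def compareEntries(a: Entry, b: Entry): Int = {
      val byPrefix = prefixComparator.compare(a.prefix, b.prefix)  // PrefixComparator.compare(Long, Long)
      if (byPrefix != 0) byPrefix else rowOrdering.compare(a.row, b.row)
    }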
//////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines various sort operators. @@ -97,59 +95,53 @@ case class ExternalSort( * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ -case class UnsafeExternalSort( +case class TungstenSort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan, testSpillFrequency: Int = 0) extends UnaryNode { - private[this] val schema: StructType = child.schema + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") - def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { - val ordering = newOrdering(sortOrder, child.output) - val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) - // Hack until we generate separate comparator implementations for ascending vs. descending - // (or choose to codegen them): - val prefixComparator = { - val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) - if (sortOrder.head.direction == Descending) { - new PrefixComparator { - override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) - } - } else { - comp - } - } - val prefixComputer = { - val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) - new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = prefixComputer(row) + protected override def doExecute(): RDD[InternalRow] = { + val schema = child.schema + val childOutput = child.output + child.execute().mapPartitions({ iter => + val ordering = newOrdering(sortOrder, childOutput) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) } } + val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - sorter.sort(iterator) - } - child.execute().mapPartitions(doSort, preservesPartitioning = true) + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + }, preservesPartitioning = true) } - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def outputsUnsafeRows: Boolean = true } -@DeveloperApi -object UnsafeExternalSort { +object TungstenSort { /** * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 7b75f755918c1..707cd9c6d939b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} import org.apache.spark.sql.test.TestSQLContext class RowFormatConvertersSuite extends SparkPlanTest { @@ -31,7 +30,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { @@ -41,14 +40,14 @@ class RowFormatConvertersSuite extends SparkPlanTest { } test("filter can process unsafe rows") { - val plan = Filter(IsNull(null), outputsUnsafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) + assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { - val plan = Filter(IsNull(null), outputsSafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala similarity index 87% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 138636b0c65b8..450963547c798 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { +class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { override def beforeAll(): Unit = { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -39,7 +39,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -50,7 +50,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll 
{ val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) @@ -70,11 +70,11 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) - assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) + assert(TungstenSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, plan => ConvertToSafe( - UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) From 0b1a464b6e061580a75b99a91b042069d76bbbfd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 30 Jul 2015 17:18:32 -0700 Subject: [PATCH 018/340] [SPARK-9425] [SQL] support DecimalType in UnsafeRow This PR brings the support of DecimalType in UnsafeRow, for precision <= 18, it's settable, otherwise it's not settable. Author: Davies Liu Closes #7758 from davies/unsafe_decimal and squashes the following commits: 478b1ba [Davies Liu] address comments 536314c [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_decimal 7c2e77a [Davies Liu] fix JoinedRow 76d6fa4 [Davies Liu] fix tests 99d3151 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_decimal d49c6ae [Davies Liu] support DecimalType in UnsafeRow --- .../expressions/SpecializedGetters.java | 2 +- .../UnsafeFixedWidthAggregationMap.java | 22 ++-- .../sql/catalyst/expressions/UnsafeRow.java | 53 +++++--- .../expressions/UnsafeRowWriters.java | 42 +++++++ .../sql/catalyst/CatalystTypeConverters.scala | 9 +- .../spark/sql/catalyst/InternalRow.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 7 +- .../expressions/codegen/CodeGenerator.scala | 9 +- .../codegen/GenerateUnsafeProjection.scala | 115 ++++++++++-------- .../spark/sql/catalyst/expressions/rows.scala | 3 +- .../org/apache/spark/sql/types/Decimal.scala | 6 +- .../spark/sql/types/GenericArrayData.scala | 2 +- .../sql/catalyst/expressions/CastSuite.scala | 5 +- .../expressions/DateExpressionsSuite.scala | 2 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 8 +- .../expressions/UnsafeRowConverterSuite.scala | 17 +-- .../spark/sql/columnar/ColumnBuilder.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 4 +- .../spark/sql/columnar/ColumnType.scala | 2 +- .../sql/execution/GeneratedAggregate.scala | 2 +- .../sql/execution/SparkSqlSerializer2.scala | 2 +- .../sql/parquet/ParquetTableSupport.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 40 +++++- 23 files changed, 237 insertions(+), 125 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index f7cea13688876..e3d3ba7a9ccc0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -41,7 +41,7 @@ public interface SpecializedGetters { double 
getDouble(int ordinal); - Decimal getDecimal(int ordinal); + Decimal getDecimal(int ordinal, int precision, int scale); UTF8String getUTF8String(int ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 03f4c3ed8e6bb..f3b462778dc10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -20,6 +20,8 @@ import java.util.Iterator; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; @@ -61,26 +63,18 @@ public final class UnsafeFixedWidthAggregationMap { private final boolean enablePerfMetrics; - /** - * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, - * false otherwise. - */ - public static boolean supportsGroupKeySchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } - /** * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given * schema, false otherwise. */ public static boolean supportsAggregationBufferSchema(StructType schema) { for (StructField field: schema.fields()) { - if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + if (field.dataType() instanceof DecimalType) { + DecimalType dt = (DecimalType) field.dataType(); + if (dt.precision() > Decimal.MAX_LONG_DIGITS()) { + return false; + } + } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { return false; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 6d684bac37573..e7088edced1a1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -19,6 +19,8 @@ import java.io.IOException; import java.io.OutputStream; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -65,12 +67,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { */ public static final Set settableFieldTypes; - /** - * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). 
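The precision-18 cut-off corresponds to Decimal.MAX_LONG_DIGITS: a decimal that narrow fits entirely in the row's fixed 8-byte field slot as an unscaled long, which is what makes it settable in place, while wider decimals spill into the variable-length region and are read-only. A small round-trip sketch of the compact form; the literal values are only illustrative:

    // Illustrative values only: precision 9, scale 4 fits well under MAX_LONG_DIGITS (18).
    import org.apache.spark.sql.types.Decimal

    val d = Decimal(BigDecimal("12345.6789"))
    val unscaled: Long = d.toUnscaledLong    // 123456789L -- what setDecimal stores in the 8-byte slot
    val back = Decimal(unscaled, 9, 4)       // what getDecimal(ordinal, 9, 4) reconstructs
    // back carries the same unscaled value, precision and scale as d; no variable-length bytes needed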
- */ - public static final Set readableFieldTypes; - - // TODO: support DecimalType + // DecimalType(precision <= 18) is settable static { settableFieldTypes = Collections.unmodifiableSet( new HashSet<>( @@ -86,16 +83,6 @@ public static int calculateBitSetWidthInBytes(int numFields) { DateType, TimestampType }))); - - // We support get() on a superset of the types for which we support set(): - final Set _readableFieldTypes = new HashSet<>( - Arrays.asList(new DataType[]{ - StringType, - BinaryType, - CalendarIntervalType - })); - _readableFieldTypes.addAll(settableFieldTypes); - readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); } ////////////////////////////////////////////////////////////////////////////// @@ -232,6 +219,21 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } + @Override + public void setDecimal(int ordinal, Decimal value, int precision) { + assertIndexIsValid(ordinal); + if (value == null) { + setNullAt(ordinal); + } else { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + setLong(ordinal, value.toUnscaledLong()); + } else { + // TODO(davies): support update decimal (hold a bounded space even it's null) + throw new UnsupportedOperationException(); + } + } + } + @Override public Object get(int ordinal) { throw new UnsupportedOperationException(); @@ -256,7 +258,8 @@ public Object get(int ordinal, DataType dataType) { } else if (dataType instanceof DoubleType) { return getDouble(ordinal); } else if (dataType instanceof DecimalType) { - return getDecimal(ordinal); + DecimalType dt = (DecimalType) dataType; + return getDecimal(ordinal, dt.precision(), dt.scale()); } else if (dataType instanceof DateType) { return getInt(ordinal); } else if (dataType instanceof TimestampType) { @@ -322,6 +325,22 @@ public double getDouble(int ordinal) { return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); } + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + assertIndexIsValid(ordinal); + if (isNullAt(ordinal)) { + return null; + } + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); + } + } + @Override public UTF8String getUTF8String(int ordinal) { assertIndexIsValid(ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index c3259e21c4a78..f43a285cd6cad 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; @@ -30,6 +31,47 @@ */ public class UnsafeRowWriters { + /** Writer for Decimal with precision under 18. 
*/ + public static class CompactDecimalWriter { + + public static int getSize(Decimal input) { + return 0; + } + + public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + target.setLong(ordinal, input.toUnscaledLong()); + return 0; + } + } + + /** Writer for Decimal with precision larger than 18. */ + public static class DecimalWriter { + + public static int getSize(Decimal input) { + // bounded size + return 16; + } + + public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + final long offset = target.getBaseOffset() + cursor; + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + final int numBytes = bytes.length; + assert(numBytes <= 16); + + // zero-out the bytes + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L); + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L); + + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, + target.getBaseObject(), offset, numBytes); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + return 16; + } + } + /** Writer for UTF8String. */ public static class UTF8StringWriter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 22452c0f201ef..7ca20fe97fbef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -68,7 +68,7 @@ object CatalystTypeConverters { case StringType => StringConverter case DateType => DateConverter case TimestampType => TimestampConverter - case dt: DecimalType => BigDecimalConverter + case dt: DecimalType => new DecimalConverter(dt) case BooleanType => BooleanConverter case ByteType => ByteConverter case ShortType => ShortConverter @@ -306,7 +306,8 @@ object CatalystTypeConverters { DateTimeUtils.toJavaTimestamp(row.getLong(column)) } - private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + private class DecimalConverter(dataType: DecimalType) + extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) @@ -314,9 +315,11 @@ object CatalystTypeConverters { } override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = - row.getDecimal(column).toJavaBigDecimal + row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal } + private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT) + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { final override def toScala(catalystValue: Any): Any = catalystValue final override def toCatalystImpl(scalaValue: T): Any = scalaValue diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 486ba036548c8..b19bf4386b0ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -58,8 +58,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters { override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) - override def getDecimal(ordinal: Int): Decimal = - getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + getAs[Decimal](ordinal, DecimalType(precision, scale)) override def getInterval(ordinal: Int): CalendarInterval = getAs[CalendarInterval](ordinal, CalendarIntervalType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index b3beb7e28f208..7c7664e4c1a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} -import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.sql.types.{Decimal, StructType, DataType} import org.apache.spark.unsafe.types.UTF8String /** @@ -225,6 +225,11 @@ class JoinedRow extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) + override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { + if (i < row1.numFields) row1.getDecimal(i, precision, scale) + else row2.getDecimal(i - row1.numFields, precision, scale) + } + override def getStruct(i: Int, numFields: Int): InternalRow = { if (i < row1.numFields) { row1.getStruct(i, numFields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c39e0df6fae2a..60e2863f7bbb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -106,6 +106,7 @@ class CodeGenContext { val jt = javaType(dataType) dataType match { case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$getter.getDecimal($ordinal, ${t.precision}, ${t.scale})" case StringType => s"$getter.getUTF8String($ordinal)" case BinaryType => s"$getter.getBinary($ordinal)" case CalendarIntervalType => s"$getter.getInterval($ordinal)" @@ -120,10 +121,10 @@ class CodeGenContext { */ def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { val jt = javaType(dataType) - if (isPrimitiveType(jt)) { - s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - } else { - s"$row.update($ordinal, $value)" + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case _ => s"$row.update($ordinal, $value)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index a662357fb6cf9..1d223986d9441 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -35,6 +35,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName + private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName + private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { @@ -42,9 +44,64 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case NullType => true + case t: DecimalType => true case _ => false } + def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))" + case StringType => + s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" + case BinaryType => + s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" + case CalendarIntervalType => + s" + (${ev.isNull} ? 0 : 16)" + case _: StructType => + s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + case _ => "" + } + + def genFieldWriter( + ctx: CodeGenContext, + fieldType: DataType, + ev: GeneratedExpressionCode, + primitive: String, + index: Int, + cursor: String): String = fieldType match { + case _ if ctx.isPrimitiveType(fieldType) => + s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}" + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => + s""" + // make sure Decimal object has the same scale as DecimalType + if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { + $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + } else { + $primitive.setNullAt($index); + } + """ + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + s""" + // make sure Decimal object has the same scale as DecimalType + if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { + $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + } else { + $primitive.setNullAt($index); + } + """ + case StringType => + s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case BinaryType => + s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case CalendarIntervalType => + s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case t: StructType => + s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") + } + /** * Generates the code to create an [[UnsafeRow]] object based on the input expressions. 
* @param ctx context for code generation @@ -69,36 +126,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val allExprs = exprs.map(_.code).mkString("\n") val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val additionalSize = expressions.zipWithIndex.map { case (e, i) => - e.dataType match { - case StringType => - s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" - case BinaryType => - s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" - case CalendarIntervalType => - s" + (${exprs(i).isNull} ? 0 : 16)" - case _: StructType => - s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))" - case _ => "" - } + val additionalSize = expressions.zipWithIndex.map { + case (e, i) => genAdditionalSize(e.dataType, exprs(i)) }.mkString("") val writers = expressions.zipWithIndex.map { case (e, i) => - val update = e.dataType match { - case dt if ctx.isPrimitiveType(dt) => - s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}" - case StringType => - s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case BinaryType => - s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case t: StructType => - s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") - } + val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor) s"""if (${exprs(i).isNull}) { $ret.setNullAt($i); } else { @@ -168,35 +201,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => - dt match { - case StringType => - s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" - case BinaryType => - s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" - case CalendarIntervalType => - s" + (${ev.isNull} ? 0 : 16)" - case _: StructType => - s" + (${ev.isNull} ? 
0 : $StructWriter.getSize(${ev.primitive}))" - case _ => "" - } + genAdditionalSize(dt, ev) }.mkString("") val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => - val update = dt match { - case _ if ctx.isPrimitiveType(dt) => - s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}" - case StringType => - s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case BinaryType => - s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case t: StructType => - s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: $dt") - } + val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor) s""" if (${exprs(i).isNull}) { $primitive.setNullAt($i); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index b7c4ece4a16fe..df6ea586c87ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DataType, StructType, AtomicType} +import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType} import org.apache.spark.unsafe.types.UTF8String /** @@ -39,6 +39,7 @@ abstract class MutableRow extends InternalRow { def setShort(i: Int, value: Short): Unit = { update(i, value) } def setByte(i: Int, value: Byte): Unit = { update(i, value) } def setFloat(i: Int, value: Float): Unit = { update(i, value) } + def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } def setString(i: Int, value: String): Unit = { update(i, UTF8String.fromString(value)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index bc689810bc292..c0155eeb450a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -188,6 +188,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { * @return true if successful, false if overflow would occur */ def changePrecision(precision: Int, scale: Int): Boolean = { + // fast path for UnsafeProjection + if (precision == this.precision && scale == this.scale) { + return true + } // First, update our longVal if we can, or transfer over to using a BigDecimal if (decimalVal.eq(null)) { if (scale < _scale) { @@ -224,7 +228,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { decimalVal = newVal } else { // We're still using Longs, but we should check whether we match the new precision - val p = POW_10(math.min(_precision, MAX_LONG_DIGITS)) + val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) if (longVal <= -p || longVal >= p) { // Note that we shouldn't have been able to fix this by switching to BigDecimal return false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala index 7992ba947c069..35ace673fb3da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -43,7 +43,7 @@ class GenericArrayData(array: Array[Any]) extends ArrayData { override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int): Decimal = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 4f35b653d73c0..1ad70733eae03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -242,10 +242,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) - checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0)) + checkEvaluation(cast(123L, DecimalType(3, 1)), null) - // TODO: Fix the following bug and re-enable it. - // checkEvaluation(cast(123L, DecimalType(2, 0)), null) + checkEvaluation(cast(123L, DecimalType(2, 0)), null) } test("cast from boolean") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index fd1d6c1d25497..887e43621a941 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Calendar diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 6a907290f2dbe..c6b4c729de2f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -55,13 +55,13 @@ class UnsafeFixedWidthAggregationMapSuite } test("supported schemas") { + assert(supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) + assert(!supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) - assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) - assert( !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - assert( - !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) } test("empty map") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index b7bc17f89e82f..a0e1701339ea7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -46,7 +46,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) - // We can copy UnsafeRows as long as they don't reference ObjectPools val unsafeRowCopy = unsafeRow.copy() assert(unsafeRowCopy.getLong(0) === 0) assert(unsafeRowCopy.getLong(1) === 1) @@ -122,8 +121,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { FloatType, DoubleType, StringType, - BinaryType - // DecimalType.Default, + BinaryType, + DecimalType.USER_DEFAULT // ArrayType(IntegerType) ) val converter = UnsafeProjection.create(fieldTypes) @@ -150,7 +149,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getDouble(7) === 0.0d) assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) - // assert(createdFromNull.get(10) === null) + assert(createdFromNull.getDecimal(10, 10, 0) === null) // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those @@ -168,7 +167,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setDouble(7, 700) r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) - // r.update(10, Decimal(10)) + r.setDecimal(10, Decimal(10), 10) // r.update(11, Array(11)) r } @@ -184,7 +183,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9)) - // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.getDecimal(10, 10, 0) === + rowWithNoNullColumns.getDecimal(10, 10, 0)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- fieldTypes.indices) { @@ -203,7 +203,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { setToNullAfterCreation.setDouble(7, 700) // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) // setToNullAfterCreation.update(9, "world".getBytes) - // setToNullAfterCreation.update(10, Decimal(10)) + setToNullAfterCreation.setDecimal(10, Decimal(10), 10) // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) @@ -216,7 +216,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) // assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) - // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.getDecimal(10, 10, 0) === + rowWithNoNullColumns.getDecimal(10, 10, 0)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 454b7b91a63f5..1620fc401ba6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -114,7 +114,7 @@ private[sql] class FixedDecimalColumnBuilder( precision: Int, scale: Int) extends NativeColumnBuilder( - new FixedDecimalColumnStats, + new FixedDecimalColumnStats(precision, scale), FIXED_DECIMAL(precision, scale)) // TODO (lian) Add support for array, struct and map diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 32a84b2676e07..af1a8ecca9b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -234,14 +234,14 @@ private[sql] class BinaryColumnStats extends ColumnStats { InternalRow(null, null, nullCount, count, sizeInBytes) } -private[sql] class FixedDecimalColumnStats extends ColumnStats { +private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { protected var upper: Decimal = null protected var lower: Decimal = null override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getDecimal(ordinal) + val value = row.getDecimal(ordinal, precision, scale) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += FIXED_DECIMAL.defaultSize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 2863f6c230a9d..30f8fe320db3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -392,7 +392,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) } override def getField(row: InternalRow, ordinal: Int): Decimal = { - row.getDecimal(ordinal) + row.getDecimal(ordinal, precision, scale) } override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b85aada9d9d4c..d851eae3fcc71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -202,7 +202,7 @@ case class GeneratedAggregate( val schemaSupportsUnsafe: Boolean = { UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) + UnsafeProjection.canSupport(groupKeySchema) } child.execute().mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index c808442a4849b..e5bbd0aaed0a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -298,7 
+298,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.getDecimal(i) + val value = row.getDecimal(i, decimal.precision, decimal.scale) val javaBigDecimal = value.toJavaBigDecimal // First, write out the unscaled value. val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 79dd16b7b0c39..ec8da38a3d427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -293,8 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes)) case BinaryType => writer.addBinary(Binary.fromByteArray(record.getBinary(index))) - case DecimalType.Fixed(precision, _) => - writeDecimal(record.getDecimal(index), precision) + case DecimalType.Fixed(precision, scale) => + writeDecimal(record.getDecimal(index, precision, scale), precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 4499a7207031d..66014ddca0596 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -34,8 +34,7 @@ class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[DoubleColumnStats], DOUBLE, InternalRow(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testColumnStats(classOf[FixedDecimalColumnStats], - FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) + testDecimalColumnStats(InternalRow(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], @@ -52,7 +51,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import ColumnarTestUtils._ + import org.apache.spark.sql.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) @@ -73,4 +72,39 @@ class ColumnStatsSuite extends SparkFunSuite { } } } + + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { + + val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName + val columnType = FIXED_DECIMAL(15, 10) + + test(s"$columnStatsName: empty") { + val columnStats = new FixedDecimalColumnStats(15, 10) + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"$columnStatsName: non-empty") { + import org.apache.spark.sql.columnar.ColumnarTestUtils._ + + val columnStats = new FixedDecimalColumnStats(15, 10) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(columnStats.gatherStats(_, 0)) + + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] + val stats = columnStats.collectedStatistics + + 
assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1)) + assertResult(10, "Wrong null count")(stats.genericGet(2)) + assertResult(20, "Wrong row count")(stats.genericGet(3)) + assertResult(stats.genericGet(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } + } + } } From 351eda0e2fd47c183c4298469970032097ad07a0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Jul 2015 17:22:51 -0700 Subject: [PATCH 019/340] [SPARK-6319][SQL] Throw AnalysisException when using BinaryType on Join and Aggregate JIRA: https://issues.apache.org/jira/browse/SPARK-6319 Spark SQL uses plain byte arrays to represent binary values. However, the arrays are compared by reference rather than by values. Thus, we should not use BinaryType on Join and Aggregate in current implementation. Author: Liang-Chi Hsieh Closes #7787 from viirya/agg_no_binary_type and squashes the following commits: 4f76cac [Liang-Chi Hsieh] Throw AnalysisException when using BinaryType on Join and Aggregate. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 20 +++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 11 +++++++++- .../org/apache/spark/sql/JoinSuite.scala | 9 +++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a373714832962..0ebc3d180a780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -87,6 +87,18 @@ trait CheckAnalysis { s"join condition '${condition.prettyString}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") + case j @ Join(_, _, _, Some(condition)) => + def checkValidJoinConditionExprs(expr: Expression): Unit = expr match { + case p: Predicate => + p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) + case e if e.dataType.isInstanceOf[BinaryType] => + failAnalysis(s"expression ${e.prettyString} in join condition " + + s"'${condition.prettyString}' can't be binary type.") + case _ => // OK + } + + checkValidJoinConditionExprs(condition) + case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK @@ -100,7 +112,15 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } + def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match { + case BinaryType => + failAnalysis(s"grouping expression '${expr.prettyString}' in aggregate can " + + s"not be binary type.") + case _ => // OK + } + aggregateExprs.foreach(checkValidAggregateExpression) + aggregateExprs.foreach(checkValidGroupingExprs) case Sort(orders, _, _) => orders.foreach { order => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b26d3ab253a1d..228ece8065151 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import 
org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.{BinaryType, DecimalType} class DataFrameAggregateSuite extends QueryTest { @@ -191,4 +191,13 @@ class DataFrameAggregateSuite extends QueryTest { Row(null)) } + test("aggregation can't work on binary type") { + val df = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType) + intercept[AnalysisException] { + df.groupBy("c").agg(count("*")) + } + intercept[AnalysisException] { + df.distinct + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 666f26bf620e1..27c08f64649ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.types.BinaryType class JoinSuite extends QueryTest with BeforeAndAfterEach { @@ -489,4 +490,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(3, 2) :: Nil) } + + test("Join can't work on binary type") { + val left = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType) + val right = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("d").select($"d" cast BinaryType) + intercept[AnalysisException] { + left.join(right, ($"left.N" === $"right.N"), "full") + } + } } From 65fa4181c35135080870c1e4c1f904ada3a8cf59 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 30 Jul 2015 17:26:18 -0700 Subject: [PATCH 020/340] [SPARK-9077] [MLLIB] Improve error message for decision trees when numExamples < maxCategoriesPerFeature Improve error message when number of examples is less than arity of high-arity categorical feature CC jkbradley is this about what you had in mind? I know it's a starter, but was on my list to close out in the short term. Author: Sean Owen Closes #7800 from srowen/SPARK-9077 and squashes the following commits: b8f6cdb [Sean Owen] Improve error message when number of examples is less than arity of high-arity categorical feature --- .../spark/mllib/tree/impl/DecisionTreeMetadata.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 380291ac22bd3..9fe264656ede7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -128,9 +128,13 @@ private[spark] object DecisionTreeMetadata extends Logging { // based on the number of training examples. 
if (strategy.categoricalFeaturesInfo.nonEmpty) { val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + val maxCategory = + strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 require(maxCategoriesPerFeature <= maxPossibleBins, - s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " + - s"in categorical features (= $maxCategoriesPerFeature)") + s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + + s"number of values in each categorical feature, but categorical feature $maxCategory " + + s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " + + "features with a large number of values, or add more training examples.") } val unorderedFeatures = new mutable.HashSet[Int]() From 3c66ff727d4b47220e1ff363cea215189ed64f36 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 30 Jul 2015 17:38:48 -0700 Subject: [PATCH 021/340] [SPARK-9489] Remove unnecessary compatibility and requirements checks from Exchange While reviewing yhuai's patch for SPARK-2205 (#7773), I noticed that Exchange's `compatible` check may be incorrectly returning `false` in many cases. As far as I know, this is not actually a problem because the `compatible`, `meetsRequirements`, and `needsAnySort` checks are serving only as short-circuit performance optimizations that are not necessary for correctness. In order to reduce code complexity, I think that we should remove these checks and unconditionally rewrite the operator's children. This should be safe because we rewrite the tree in a single bottom-up pass. Author: Josh Rosen Closes #7807 from JoshRosen/SPARK-9489 and squashes the following commits: 9d76ce9 [Josh Rosen] [SPARK-9489] Remove compatibleWith, meetsRequirements, and needsAnySort checks from Exchange --- .../plans/physical/partitioning.scala | 35 --------- .../apache/spark/sql/execution/Exchange.scala | 76 +++++-------------- 2 files changed, 17 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 2dcfa19fec383..f4d1dbaf28efe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -86,14 +86,6 @@ sealed trait Partitioning { */ def satisfies(required: Distribution): Boolean - /** - * Returns true iff all distribution guarantees made by this partitioning can also be made - * for the `other` specified partitioning. - * For example, two [[HashPartitioning HashPartitioning]]s are - * only compatible if the `numPartitions` of them is the same. - */ - def compatibleWith(other: Partitioning): Boolean - /** Returns the expressions that are used to key the partitioning. 
*/ def keyExpressions: Seq[Expression] } @@ -104,11 +96,6 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case UnknownPartitioning(_) => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -117,11 +104,6 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -130,11 +112,6 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -159,12 +136,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case h: HashPartitioning if h == this => true - case _ => false - } - override def keyExpressions: Seq[Expression] = expressions } @@ -199,11 +170,5 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case r: RangePartitioning if r == this => true - case _ => false - } - override def keyExpressions: Seq[Expression] = ordering.map(_.child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 70e5031fb63c0..6bd57f010a990 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -202,41 +202,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // True iff every child's outputPartitioning satisfies the corresponding - // required data distribution. - def meetsRequirements: Boolean = - operator.requiredChildDistribution.zip(operator.children).forall { - case (required, child) => - val valid = child.outputPartitioning.satisfies(required) - logDebug( - s"${if (valid) "Valid" else "Invalid"} distribution," + - s"required: $required current: ${child.outputPartitioning}") - valid - } - - // True iff any of the children are incorrectly sorted. - def needsAnySort: Boolean = - operator.requiredChildOrdering.zip(operator.children).exists { - case (required, child) => required.nonEmpty && required != child.outputOrdering - } - - // True iff outputPartitionings of children are compatible with each other. - // It is possible that every child satisfies its required data distribution - // but two children have incompatible outputPartitionings. For example, - // A dataset is range partitioned by "a.asc" (RangePartitioning) and another - // dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two - // datasets are both clustered by "a", but these two outputPartitionings are not - // compatible. - // TODO: ASSUMES TRANSITIVITY? 
- def compatible: Boolean = - operator.children - .map(_.outputPartitioning) - .sliding(2) - .forall { - case Seq(a) => true - case Seq(a, b) => a.compatibleWith(b) - } - // Adds Exchange or Sort operators as required def addOperatorsIfNecessary( partitioning: Partitioning, @@ -269,33 +234,26 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ addSortIfNecessary(addShuffleIfNecessary(child)) } - if (meetsRequirements && compatible && !needsAnySort) { - operator - } else { - // At least one child does not satisfies its required data distribution or - // at least one child's outputPartitioning is not compatible with another child's - // outputPartitioning. In this case, we need to add Exchange operators. - val requirements = - (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) + val requirements = + (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) - val fixedChildren = requirements.zipped.map { - case (AllTuples, rowOrdering, child) => - addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (ClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) - case (OrderedDistribution(ordering), rowOrdering, child) => - addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + val fixedChildren = requirements.zipped.map { + case (AllTuples, rowOrdering, child) => + addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (OrderedDistribution(ordering), rowOrdering, child) => + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) - case (UnspecifiedDistribution, Seq(), child) => - child - case (UnspecifiedDistribution, rowOrdering, child) => - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + case (UnspecifiedDistribution, Seq(), child) => + child + case (UnspecifiedDistribution, rowOrdering, child) => + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) - case (dist, ordering, _) => - sys.error(s"Don't know how to ensure $dist with ordering $ordering") - } - - operator.withNewChildren(fixedChildren) + case (dist, ordering, _) => + sys.error(s"Don't know how to ensure $dist with ordering $ordering") } + + operator.withNewChildren(fixedChildren) } } From 9307f5653d19a6a2fda355a675ca9ea97e35611b Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Thu, 30 Jul 2015 17:44:20 -0700 Subject: [PATCH 022/340] [SPARK-9472] [STREAMING] consistent hadoop configuration, streaming only Author: cody koeninger Closes #7772 from koeninger/streaming-hadoop-config and squashes the following commits: 5267284 [cody koeninger] [SPARK-4229][Streaming] consistent hadoop configuration, streaming only --- .../main/scala/org/apache/spark/streaming/Checkpoint.scala | 3 ++- .../org/apache/spark/streaming/StreamingContext.scala | 7 ++++--- .../apache/spark/streaming/api/java/JavaPairDStream.scala | 2 +- .../spark/streaming/api/java/JavaStreamingContext.scala | 3 ++- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 65d4e933bf8e9..2780d5b6adbcf 100644 --- 
a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, SparkConf, Logging} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator @@ -100,7 +101,7 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) - val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) + val fs = fsOption.getOrElse(path.getFileSystem(SparkHadoopUtil.get.conf)) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 92438f1b1fbf7..177e710ace54b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.input.FixedLengthBinaryInputFormat import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.serializer.SerializationDebugger @@ -110,7 +111,7 @@ class StreamingContext private[streaming] ( * Recreate a StreamingContext from a checkpoint file. * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this(path, new Configuration) + def this(path: String) = this(path, SparkHadoopUtil.get.conf) /** * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. 
@@ -803,7 +804,7 @@ object StreamingContext extends Logging { def getActiveOrCreate( checkpointPath: String, creatingFunc: () => StreamingContext, - hadoopConf: Configuration = new Configuration(), + hadoopConf: Configuration = SparkHadoopUtil.get.conf, createOnError: Boolean = false ): StreamingContext = { ACTIVATION_LOCK.synchronized { @@ -828,7 +829,7 @@ object StreamingContext extends Logging { def getOrCreate( checkpointPath: String, creatingFunc: () => StreamingContext, - hadoopConf: Configuration = new Configuration(), + hadoopConf: Configuration = SparkHadoopUtil.get.conf, createOnError: Boolean = false ): StreamingContext = { val checkpointOption = CheckpointReader.read( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 959ac9c177f81..26383e420101e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -788,7 +788,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[F], - conf: Configuration = new Configuration) { + conf: Configuration = dstream.context.sparkContext.hadoopConfiguration) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 40deb6d7ea79a..35cc3ce5cf468 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -33,6 +33,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import org.apache.spark.api.java.function.{Function0 => JFunction0} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ @@ -136,7 +137,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Recreate a JavaStreamingContext from a checkpoint file. * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this(new StreamingContext(path, new Configuration)) + def this(path: String) = this(new StreamingContext(path, SparkHadoopUtil.get.conf)) /** * Re-creates a JavaStreamingContext from a checkpoint file. 
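[Editor's note: the sketch below is an illustration added during editing, not part of patch 022. The checkpoint path and the spark.hadoop.* option are hypothetical examples. The practical effect of SPARK-9472 is that Hadoop settings carried by SparkHadoopUtil.get.conf (for instance options supplied at submit time as --conf spark.hadoop.fs.s3a.connection.maximum=100 or via spark-defaults.conf) are now also applied when a StreamingContext is recovered from its checkpoint directory, instead of being dropped by a bare `new Configuration()`.]

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val checkpointDir = "hdfs:///tmp/streaming-checkpoints"  // hypothetical path

    def createContext(): StreamingContext = {
      val conf = new SparkConf().setAppName("CheckpointRecoveryExample")
      val ssc = new StreamingContext(conf, Seconds(1))
      ssc.checkpoint(checkpointDir)
      ssc
    }

    // On restart, the checkpoint is read with SparkHadoopUtil.get.conf by default,
    // so spark.hadoop.* settings supplied to the application are honored during recovery.
    val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _)
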
From 0ec234c2a6ffb5d87fb073fa7d245fce49d466d2 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 30 Jul 2015 17:31:43 -0700 Subject: [PATCH 023/340] Fixing some typos and unnecessary comments Removing original implementation that was commented out --- .../src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 3 --- .../scala/org/apache/spark/storage/DiskBlockObjectWriter.scala | 1 - 2 files changed, 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index ac02d861ed463..6927f2aecec34 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -482,9 +482,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } shuffledRdd.mapPartitions(groupOnPartition(_), preservesPartitioning = true) - //val bufs = combineByKey[ExternalList[V]]( - // createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) - //bufs.asInstanceOf[RDD[(K, Iterable[V])]] } /** diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 7ffa9230f3bbb..49d9154f95a5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -231,5 +231,4 @@ private[spark] class DiskBlockObjectWriter( objOut.flush() bs.flush() } - } From 83670fc9e6fc9c7a6ae68dfdd3f9335ea72f4ab0 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 30 Jul 2015 19:22:38 -0700 Subject: [PATCH 024/340] [SPARK-8176] [SPARK-8197] [SQL] function to_date/ trunc This PR is based on #6988 , thanks to adrian-wang . This brings two SQL functions: to_date() and trunc(). Closes #6988 Author: Daoyuan Wang Author: Davies Liu Closes #7805 from davies/to_date and squashes the following commits: 2c7beba [Davies Liu] Merge branch 'master' of github.com:apache/spark into to_date 310dd55 [Daoyuan Wang] remove dup test in rebase 980b092 [Daoyuan Wang] resolve rebase conflict a476c5a [Daoyuan Wang] address comments from davies d44ea5f [Daoyuan Wang] function to_date, trunc --- python/pyspark/sql/functions.py | 30 +++++++ .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/datetimeFunctions.scala | 88 ++++++++++++++++++- .../sql/catalyst/util/DateTimeUtils.scala | 34 +++++++ .../expressions/DateExpressionsSuite.scala | 29 +++++- .../expressions/NonFoldableLiteral.scala | 4 + .../org/apache/spark/sql/functions.scala | 16 ++++ .../apache/spark/sql/DateFunctionsSuite.scala | 44 ++++++++++ 8 files changed, 245 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a7295e25f0aa5..8024a8de07c98 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -888,6 +888,36 @@ def months_between(date1, date2): return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) +@since(1.5) +def to_date(col): + """ + Converts the column of StringType or TimestampType into DateType. 
+ + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_date(df.t).alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.to_date(_to_java_column(col))) + + +@since(1.5) +def trunc(date, format): + """ + Returns date truncated to the unit specified by the format. + + :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' + + >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d']) + >>> df.select(trunc(df.d, 'year').alias('year')).collect() + [Row(year=datetime.date(1997, 1, 1))] + >>> df.select(trunc(df.d, 'mon').alias('month')).collect() + [Row(month=datetime.date(1997, 2, 1))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) + + @since(1.5) def size(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6c7c481fab8db..1bf7204a2515c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -223,6 +223,8 @@ object FunctionRegistry { expression[NextDay]("next_day"), expression[Quarter]("quarter"), expression[Second]("second"), + expression[ToDate]("to_date"), + expression[TruncDate]("trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index 9795673ee0664..6e7613340c032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -507,7 +507,6 @@ case class FromUnixTime(sec: Expression, format: Expression) }) } } - } /** @@ -696,3 +695,90 @@ case class MonthsBetween(date1: Expression, date2: Expression) }) } } + +/** + * Returns the date part of a timestamp or string. + */ +case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Implicit casting of spark will accept string in both date and timestamp format, as + // well as TimestampType. + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = DateType + + override def eval(input: InternalRow): Any = child.eval(input) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, d => d) + } +} + +/* + * Returns date truncated to the unit specified by the format. 
+ */ +case class TruncDate(date: Expression, format: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + override def left: Expression = date + override def right: Expression = format + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def dataType: DataType = DateType + override def prettyName: String = "trunc" + + lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + + override def eval(input: InternalRow): Any = { + val minItem = if (format.foldable) { + minItemConst + } else { + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + } + if (minItem == -1) { + // unknown format + null + } else { + val d = date.eval(input) + if (d == null) { + null + } else { + DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + + if (format.foldable) { + if (minItemConst == -1) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val d = date.gen(ctx) + s""" + ${d.code} + boolean ${ev.isNull} = ${d.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst); + } + """ + } + } else { + nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { + val form = ctx.freshName("form") + s""" + int $form = $dtu.parseTruncLevel($fmt); + if ($form == -1) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $dtu.truncDate($dateVal, $form); + } + """ + }) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 53abdf6618eac..5a7c25b8d508d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -779,4 +779,38 @@ object DateTimeUtils { } date + (lastDayOfMonthInYear - dayInYear) } + + private val TRUNC_TO_YEAR = 1 + private val TRUNC_TO_MONTH = 2 + private val TRUNC_INVALID = -1 + + /** + * Returns the trunc date from original date and trunc level. + * Trunc level should be generated using `parseTruncLevel()`, should only be 1 or 2. + */ + def truncDate(d: Int, level: Int): Int = { + if (level == TRUNC_TO_YEAR) { + d - DateTimeUtils.getDayInYear(d) + 1 + } else if (level == TRUNC_TO_MONTH) { + d - DateTimeUtils.getDayOfMonth(d) + 1 + } else { + throw new Exception(s"Invalid trunc level: $level") + } + } + + /** + * Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, or TRUNC_INVALID, + * TRUNC_INVALID means unsupported truncate level. 
+ */ + def parseTruncLevel(format: UTF8String): Int = { + if (format == null) { + TRUNC_INVALID + } else { + format.toString.toUpperCase match { + case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR + case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH + case _ => TRUNC_INVALID + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 887e43621a941..6c15c05da3094 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -351,6 +351,34 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) } + test("function to_date") { + checkEvaluation( + ToDate(Literal(Date.valueOf("2015-07-22"))), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22"))) + checkEvaluation(ToDate(Literal.create(null, DateType)), null) + } + + test("function trunc") { + def testTrunc(input: Date, fmt: String, expected: Date): Unit = { + checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)), + expected) + checkEvaluation( + TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)), + expected) + } + val date = Date.valueOf("2015-07-22") + Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach{ fmt => + testTrunc(date, fmt, Date.valueOf("2015-01-01")) + } + Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt => + testTrunc(date, fmt, Date.valueOf("2015-07-01")) + } + testTrunc(date, "DD", null) + testTrunc(date, null, null) + testTrunc(null, "MON", null) + testTrunc(null, null, null) + } + test("from_unixtime") { val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" @@ -405,5 +433,4 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index 0559fb80e7fce..31ecf4a9e810a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -47,4 +47,8 @@ object NonFoldableLiteral { val lit = Literal(value) NonFoldableLiteral(lit.value, lit.dataType) } + def create(value: Any, dataType: DataType): NonFoldableLiteral = { + val lit = Literal.create(value, dataType) + NonFoldableLiteral(lit.value, lit.dataType) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 168894d66117d..46dc4605a5ccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2181,6 +2181,22 @@ object functions { */ def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + /* + * Converts the column into DateType. 
+ * + * @group datetime_funcs + * @since 1.5.0 + */ + def to_date(e: Column): Column = ToDate(e.expr) + + /** + * Returns date truncated to the unit specified by the format. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index b7267c413165a..8c596fad74ee4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -345,6 +345,50 @@ class DateFunctionsSuite extends QueryTest { Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) } + test("function to_date") { + val d1 = Date.valueOf("2015-07-22") + val d2 = Date.valueOf("2015-07-01") + val t1 = Timestamp.valueOf("2015-07-22 10:00:00") + val t2 = Timestamp.valueOf("2014-12-31 23:59:59") + val s1 = "2015-07-22 10:00:00" + val s2 = "2014-12-31" + val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s") + + checkAnswer( + df.select(to_date(col("t"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.select(to_date(col("d"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.select(to_date(col("s"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + + checkAnswer( + df.selectExpr("to_date(t)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.selectExpr("to_date(d)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.selectExpr("to_date(s)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + } + + test("function trunc") { + val df = Seq( + (1, Timestamp.valueOf("2015-07-22 10:00:00")), + (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select(trunc(col("t"), "YY")), + Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01")))) + + checkAnswer( + df.selectExpr("trunc(t, 'Month')"), + Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01")))) + } + test("from_unixtime") { val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" From 4e5919bfb47a58bcbda90ae01c1bed2128ded983 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Thu, 30 Jul 2015 23:02:11 -0700 Subject: [PATCH 025/340] [SPARK-7690] [ML] Multiclass classification Evaluator Multiclass Classification Evaluator for ML Pipelines. F1 score, precision, recall, weighted precision and weighted recall are supported as available metrics. 
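[Editor's note: a minimal Scala usage sketch added for illustration; it is not part of the original commit message. The toy (prediction, label) data, the column names, and the in-scope `sqlContext` are assumptions; the evaluator API itself (setLabelCol, setPredictionCol, setMetricName, evaluate) is the one introduced by this patch.]

    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

    // Toy predictions; in practice this DataFrame would come from a fitted
    // classifier's transform(), with DoubleType "prediction" and "label" columns.
    val predictions = sqlContext.createDataFrame(Seq(
      (0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)
    )).toDF("prediction", "label")

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("f1")  // also supports "precision", "recall", "weightedPrecision", "weightedRecall"

    val f1: Double = evaluator.evaluate(predictions)
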
Author: Ram Sriharsha Closes #7475 from harsha2010/SPARK-7690 and squashes the following commits: 9bf4ec7 [Ram Sriharsha] fix indentation 3f09a85 [Ram Sriharsha] cleanup doc 16115ae [Ram Sriharsha] code review fixes 032d2a3 [Ram Sriharsha] fix test eec9865 [Ram Sriharsha] Fix Python Indentation 1dbeffd [Ram Sriharsha] Merge branch 'master' into SPARK-7690 68cea85 [Ram Sriharsha] Merge branch 'master' into SPARK-7690 54c03de [Ram Sriharsha] [SPARK-7690][ml][WIP] Multiclass Evaluator for ML Pipeline --- .../MulticlassClassificationEvaluator.scala | 85 +++++++++++++++++++ ...lticlassClassificationEvaluatorSuite.scala | 28 ++++++ python/pyspark/ml/evaluation.py | 66 ++++++++++++++ 3 files changed, 179 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala new file mode 100644 index 0000000000000..44f779c1908d7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.types.DoubleType + +/** + * :: Experimental :: + * Evaluator for multiclass classification, which expects two input columns: score and label. 
+ */ +@Experimental +class MulticlassClassificationEvaluator (override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("mcEval")) + + /** + * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`, + * `"weightedPrecision"`, `"weightedRecall"`) + * @group param + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("f1", "precision", + "recall", "weightedPrecision", "weightedRecall")) + new Param(this, "metricName", "metric name in evaluation " + + "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "f1") + + override def evaluate(dataset: DataFrame): Double = { + val schema = dataset.schema + SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label: Double) => + (prediction, label) + } + val metrics = new MulticlassMetrics(predictionAndLabels) + val metric = $(metricName) match { + case "f1" => metrics.weightedFMeasure + case "precision" => metrics.precision + case "recall" => metrics.recall + case "weightedPrecision" => metrics.weightedPrecision + case "weightedRecall" => metrics.weightedRecall + } + metric + } + + override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala new file mode 100644 index 0000000000000..6d8412b0b3701 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite + +class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new MulticlassClassificationEvaluator) + } +} diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 595593a7f2cde..06e809352225b 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -214,6 +214,72 @@ def setParams(self, predictionCol="prediction", labelCol="label", kwargs = self.setParams._input_kwargs return self._set(**kwargs) + +@inherit_doc +class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): + """ + Evaluator for Multiclass Classification, which expects two input + columns: prediction and label. + >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), + ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)] + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"]) + ... + >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction") + >>> evaluator.evaluate(dataset) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"}) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"}) + 0.66... + """ + # a placeholder to make it appear in the generated doc + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation " + "(f1|precision|recall|weightedPrecision|weightedRecall)") + + @keyword_only + def __init__(self, predictionCol="prediction", labelCol="label", + metricName="f1"): + """ + __init__(self, predictionCol="prediction", labelCol="label", \ + metricName="f1") + """ + super(MulticlassClassificationEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) + # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall) + self.metricName = Param(self, "metricName", + "metric name in evaluation" + " (f1|precision|recall|weightedPrecision|weightedRecall)") + self._setDefault(predictionCol="prediction", labelCol="label", + metricName="f1") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self._paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, predictionCol="prediction", labelCol="label", + metricName="f1"): + """ + setParams(self, predictionCol="prediction", labelCol="label", \ + metricName="f1") + Sets params for multiclass classification evaluator. 
+ """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 69b62f76fced18efa35a107c9be4bc22eba72878 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 30 Jul 2015 23:03:48 -0700 Subject: [PATCH 026/340] [SPARK-9214] [ML] [PySpark] support ml.NaiveBayes for Python support ml.NaiveBayes for Python Author: Yanbo Liang Closes #7568 from yanboliang/spark-9214 and squashes the following commits: 5ee3fd6 [Yanbo Liang] fix typos 3ecd046 [Yanbo Liang] fix typos f9c94d1 [Yanbo Liang] change lambda_ to smoothing and fix other issues 180452a [Yanbo Liang] fix typos 7dda1f4 [Yanbo Liang] support ml.NaiveBayes for Python --- .../spark/ml/classification/NaiveBayes.scala | 10 +- .../classification/JavaNaiveBayesSuite.java | 4 +- .../ml/classification/NaiveBayesSuite.scala | 6 +- python/pyspark/ml/classification.py | 116 +++++++++++++++++- 4 files changed, 125 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 1f547e4a98af7..5be35fe209291 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -38,11 +38,11 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * (default = 1.0). * @group param */ - final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.", + final val smoothing: DoubleParam = new DoubleParam(this, "smoothing", "The smoothing parameter.", ParamValidators.gtEq(0)) /** @group getParam */ - final def getLambda: Double = $(lambda) + final def getSmoothing: Double = $(smoothing) /** * The model type which is a string (case-sensitive). @@ -79,8 +79,8 @@ class NaiveBayes(override val uid: String) * Default is 1.0. * @group setParam */ - def setLambda(value: Double): this.type = set(lambda, value) - setDefault(lambda -> 1.0) + def setSmoothing(value: Double): this.type = set(smoothing, value) + setDefault(smoothing -> 1.0) /** * Set the model type using a string (case-sensitive). 
@@ -92,7 +92,7 @@ class NaiveBayes(override val uid: String) override protected def train(dataset: DataFrame): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType)) + val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 09a9fba0c19cf..a700c9cddb206 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -68,7 +68,7 @@ public void naiveBayesDefaultParams() { assert(nb.getLabelCol() == "label"); assert(nb.getFeaturesCol() == "features"); assert(nb.getPredictionCol() == "prediction"); - assert(nb.getLambda() == 1.0); + assert(nb.getSmoothing() == 1.0); assert(nb.getModelType() == "multinomial"); } @@ -89,7 +89,7 @@ public void testNaiveBayes() { }); DataFrame dataset = jsql.createDataFrame(jrdd, schema); - NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial"); + NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 76381a2741296..264bde3703c5f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -58,7 +58,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { assert(nb.getLabelCol === "label") assert(nb.getFeaturesCol === "features") assert(nb.getPredictionCol === "prediction") - assert(nb.getLambda === 1.0) + assert(nb.getSmoothing === 1.0) assert(nb.getModelType === "multinomial") } @@ -75,7 +75,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 42, "multinomial")) - val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial") + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) @@ -101,7 +101,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 45, "bernoulli")) - val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli") + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 5a82bc286d1e8..93ffcd40949b3 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -25,7 +25,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', - 'RandomForestClassifier', 'RandomForestClassificationModel'] + 'RandomForestClassifier', 'RandomForestClassificationModel', 
'NaiveBayes', + 'NaiveBayesModel'] @inherit_doc @@ -576,6 +577,119 @@ class GBTClassificationModel(TreeEnsembleModels): """ +@inherit_doc +class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): + """ + Naive Bayes Classifiers. + + >>> from pyspark.sql import Row + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... Row(label=0.0, features=Vectors.dense([0.0, 0.0])), + ... Row(label=0.0, features=Vectors.dense([0.0, 1.0])), + ... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))]) + >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial") + >>> model = nb.fit(df) + >>> model.pi + DenseVector([-0.51..., -0.91...]) + >>> model.theta + DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1) + >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF() + >>> model.transform(test0).head().prediction + 1.0 + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " + + "default is 1.0") + modelType = Param(Params._dummy(), "modelType", "The model type which is a string " + + "(case-sensitive). Supported options: multinomial (default) and bernoulli.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial"): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial") + """ + super(NaiveBayes, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.NaiveBayes", self.uid) + #: param for the smoothing parameter. + self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " + + "default is 1.0") + #: param for the model type. + self.modelType = Param(self, "modelType", "The model type which is a string " + + "(case-sensitive). Supported options: multinomial (default) " + + "and bernoulli.") + self._setDefault(smoothing=1.0, modelType="multinomial") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial") + Sets params for Naive Bayes. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return NaiveBayesModel(java_model) + + def setSmoothing(self, value): + """ + Sets the value of :py:attr:`smoothing`. + """ + self._paramMap[self.smoothing] = value + return self + + def getSmoothing(self): + """ + Gets the value of smoothing or its default value. + """ + return self.getOrDefault(self.smoothing) + + def setModelType(self, value): + """ + Sets the value of :py:attr:`modelType`. + """ + self._paramMap[self.modelType] = value + return self + + def getModelType(self): + """ + Gets the value of modelType or its default value. + """ + return self.getOrDefault(self.modelType) + + +class NaiveBayesModel(JavaModel): + """ + Model fitted by NaiveBayes. + """ + + @property + def pi(self): + """ + log of class priors. 
+ """ + return self._call_java("pi") + + @property + def theta(self): + """ + log of class conditional probabilities. + """ + return self._call_java("theta") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 0244170b66476abc4a39ed609a852f1a6fa455e7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Jul 2015 23:05:58 -0700 Subject: [PATCH 027/340] [SPARK-9152][SQL] Implement code generation for Like and RLike JIRA: https://issues.apache.org/jira/browse/SPARK-9152 This PR implements code generation for `Like` and `RLike`. Author: Liang-Chi Hsieh Closes #7561 from viirya/like_rlike_codegen and squashes the following commits: fe5641b [Liang-Chi Hsieh] Add test for NonFoldableLiteral. ccd1b43 [Liang-Chi Hsieh] For comments. 0086723 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen 50df9a8 [Liang-Chi Hsieh] Use nullSafeCodeGen. 8092a68 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen 696d451 [Liang-Chi Hsieh] Check expression foldable. 48e5536 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen aea58e0 [Liang-Chi Hsieh] For comments. 46d946f [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen a0fb76e [Liang-Chi Hsieh] For comments. 6cffe3c [Liang-Chi Hsieh] For comments. 69f0fb6 [Liang-Chi Hsieh] Add code generation for Like and RLike. --- .../expressions/stringOperations.scala | 105 ++++++++++++++---- .../spark/sql/catalyst/util/StringUtils.scala | 47 ++++++++ .../expressions/StringExpressionsSuite.scala | 16 +++ .../sql/catalyst/util/StringUtilsSuite.scala | 34 ++++++ 4 files changed, 180 insertions(+), 22 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 79c0ca56a8e79..99a62343f138d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -21,8 +21,11 @@ import java.text.DecimalFormat import java.util.Locale import java.util.regex.{MatchResult, Pattern} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -160,32 +163,51 @@ trait StringRegexExpression extends ImplicitCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression with CodegenFallback { - // replace the _ with .{1} exactly match 1 time of any character - // replace the % with .*, match 0 or more times with any character - override def escape(v: String): String = - if (!v.isEmpty) { - "(?s)" + (' ' +: v.init).zip(v).flatMap { - case (prev, '\\') => "" - case ('\\', c) => - c match { - case '_' => "_" - case '%' => "%" - case _ => Pattern.quote("\\" + c) - } - case (prev, c) => - c match { - case '_' => "." 
- case '%' => ".*" - case _ => Pattern.quote(Character.toString(c)) - } - }.mkString - } else { - v - } + override def escape(v: String): String = StringUtils.escapeLikeRegex(v) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() override def toString: String = s"$left LIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).matches(); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).matches(); + """ + }) + } + } } @@ -195,6 +217,45 @@ case class RLike(left: Expression, right: Expression) override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
+ val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).find(0); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile(rightStr); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).find(0); + """ + }) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala new file mode 100644 index 0000000000000..9ddfb3a0d3759 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.regex.Pattern + +object StringUtils { + + // replace the _ with .{1} exactly match 1 time of any character + // replace the % with .*, match 0 or more times with any character + def escapeLikeRegex(v: String): String = { + if (!v.isEmpty) { + "(?s)" + (' ' +: v.init).zip(v).flatMap { + case (prev, '\\') => "" + case ('\\', c) => + c match { + case '_' => "_" + case '%' => "%" + case _ => Pattern.quote("\\" + c) + } + case (prev, c) => + c match { + case '_' => "." 
+ case '%' => ".*" + case _ => Pattern.quote(Character.toString(c)) + } + }.mkString + } else { + v + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 07b952531ec2e..3ecd0d374c46b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -191,6 +191,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, StringType).like("a"), null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) + checkEvaluation( + Literal.create("a", StringType).like(NonFoldableLiteral.create("a", StringType)), true) + checkEvaluation( + Literal.create("a", StringType).like(NonFoldableLiteral.create(null, StringType)), null) + checkEvaluation( + Literal.create(null, StringType).like(NonFoldableLiteral.create("a", StringType)), null) + checkEvaluation( + Literal.create(null, StringType).like(NonFoldableLiteral.create(null, StringType)), null) + checkEvaluation("abdef" like "abdef", true) checkEvaluation("a_%b" like "a\\__b", true) checkEvaluation("addb" like "a_%b", true) @@ -232,6 +241,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) checkEvaluation("abdef" rlike Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) + checkEvaluation("abdef" rlike NonFoldableLiteral.create("abdef", StringType), true) + checkEvaluation("abdef" rlike NonFoldableLiteral.create(null, StringType), null) + checkEvaluation( + Literal.create(null, StringType) rlike NonFoldableLiteral.create("abdef", StringType), null) + checkEvaluation( + Literal.create(null, StringType) rlike NonFoldableLiteral.create(null, StringType), null) + checkEvaluation("abdef" rlike "abdef", true) checkEvaluation("abbbbc" rlike "a.*c", true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala new file mode 100644 index 0000000000000..d6f273f9e568a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.StringUtils._ + +class StringUtilsSuite extends SparkFunSuite { + + test("escapeLikeRegex") { + assert(escapeLikeRegex("abdef") === "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E") + assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E_.\\Qb\\E") + assert(escapeLikeRegex("a_%b") === "(?s)\\Qa\\E..*\\Qb\\E") + assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*%\\Qb\\E") + assert(escapeLikeRegex("a%") === "(?s)\\Qa\\E.*") + assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") + assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") + } +} From a3a85d73da053c8e2830759fbc68b734081fa4f3 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Thu, 30 Jul 2015 23:50:06 -0700 Subject: [PATCH 028/340] [SPARK-9496][SQL]do not print the password in config https://issues.apache.org/jira/browse/SPARK-9496 We better do not print the password in log. Author: WangTaoTheTonic Closes #7815 from WangTaoTheTonic/master and squashes the following commits: c7a5145 [WangTaoTheTonic] do not print the password in config --- .../org/apache/spark/sql/hive/client/ClientWrapper.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 8adda54754230..6e0912da5862d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -91,7 +91,11 @@ private[hive] class ClientWrapper( // this action explicit. initialConf.setClassLoader(initClassLoader) config.foreach { case (k, v) => - logDebug(s"Hive Config: $k=$v") + if (k.toLowerCase.contains("password")) { + logDebug(s"Hive Config: $k=xxx") + } else { + logDebug(s"Hive Config: $k=$v") + } initialConf.set(k, v) } val newState = new SessionState(initialConf) From 6bba7509a932aa4d39266df2d15b1370b7aabbec Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 31 Jul 2015 08:28:05 -0700 Subject: [PATCH 029/340] [SPARK-9500] add TernaryExpression to simplify ternary expressions There lots of duplicated code in ternary expressions, create a TernaryExpression for them to reduce duplicated code. cc chenghao-intel Author: Davies Liu Closes #7816 from davies/ternary and squashes the following commits: ed2bf76 [Davies Liu] add TernaryExpression --- .../sql/catalyst/expressions/Expression.scala | 85 +++++ .../expressions/codegen/CodeGenerator.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 66 +--- .../expressions/stringOperations.scala | 356 +++++------------- 4 files changed, 183 insertions(+), 326 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 8fc182607ce68..2842b3ec5a0c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -432,3 +432,88 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { private[sql] object BinaryOperator { def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } + +/** + * An expression with three inputs and one output. 
The output is by default evaluated to null + * if any input is evaluated to null. + */ +abstract class TernaryExpression extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + /** + * Default behavior of evaluation according to the default nullability of BinaryExpression. + * If subclass of BinaryExpression override nullable, probably should also override this. + */ + override def eval(input: InternalRow): Any = { + val exprs = children + val value1 = exprs(0).eval(input) + if (value1 != null) { + val value2 = exprs(1).eval(input) + if (value2 != null) { + val value3 = exprs(2).eval(input) + if (value3 != null) { + return nullSafeEval(value1, value2, value3) + } + } + } + null + } + + /** + * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default + * nullability, they can override this method to save null-check code. If we need full control + * of evaluation process, we should override [[eval]]. + */ + protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = + sys.error(s"BinaryExpressions must override either eval or nullSafeEval") + + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts two variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { + s"${ev.primitive} = ${f(eval1, eval2, eval3)};" + }) + } + + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f function that accepts the 2 non-null evaluation result names of children + * and returns Java code to compute the output. 
+ */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val evals = children.map(_.gen(ctx)) + val resultCode = f(evals(0).primitive, evals(1).primitive, evals(2).primitive) + s""" + ${evals(0).code} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${evals(0).isNull}) { + ${evals(1).code} + if (!${evals(1).isNull}) { + ${evals(2).code} + if (!${evals(2).isNull}) { + ${ev.isNull} = false; // resultCode could change nullability + $resultCode + } + } + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 60e2863f7bbb0..e50ec27fc2eb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -305,7 +305,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin evaluator.cook(code) } catch { case e: Exception => - val msg = "failed to compile:\n " + CodeFormatter.format(code) + val msg = s"failed to compile: $e\n" + CodeFormatter.format(code) logError(msg, e) throw new Exception(msg, e) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index e6d807f6d897b..15ceb9193a8c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -165,69 +165,29 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" * @param toBaseExpr to which base */ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends Expression with ImplicitCastInputTypes { - - override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable - - override def nullable: Boolean = numExpr.nullable || fromBaseExpr.nullable || toBaseExpr.nullable + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) - override def dataType: DataType = StringType - /** Returns the result of evaluating this expression on a given input Row */ - override def eval(input: InternalRow): Any = { - val num = numExpr.eval(input) - if (num != null) { - val fromBase = fromBaseExpr.eval(input) - if (fromBase != null) { - val toBase = toBaseExpr.eval(input) - if (toBase != null) { - NumberConverter.convert( - num.asInstanceOf[UTF8String].getBytes, - fromBase.asInstanceOf[Int], - toBase.asInstanceOf[Int]) - } else { - null - } - } else { - null - } - } else { - null - } + override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { + NumberConverter.convert( + num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], + toBase.asInstanceOf[Int]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val numGen = numExpr.gen(ctx) - val from = fromBaseExpr.gen(ctx) - val to = toBaseExpr.gen(ctx) - val numconv = NumberConverter.getClass.getName.stripSuffix("$") - s""" - 
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${numGen.code} - boolean ${ev.isNull} = ${numGen.isNull}; - if (!${ev.isNull}) { - ${from.code} - if (!${from.isNull}) { - ${to.code} - if (!${to.isNull}) { - ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(), - ${from.primitive}, ${to.primitive}); - if (${ev.primitive} == null) { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } + nullSafeCodeGen(ctx, ev, (num, from, to) => + s""" + ${ev.primitive} = $numconv.convert($num.getBytes(), $from, $to); + if (${ev.primitive} == null) { + ${ev.isNull} = true; } - """ + """ + ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 99a62343f138d..684eac12bd6f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -426,15 +426,13 @@ case class StringInstr(str: Expression, substr: Expression) * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) } override def children: Seq[Expression] = substr :: str :: start :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -467,60 +465,18 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. 
*/ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = children.exists(_.nullable) override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - override def eval(input: InternalRow): Any = { - val s = str.eval(input) - if (s == null) { - null - } else { - val l = len.eval(input) - if (l == null) { - null - } else { - val p = pad.eval(input) - if (p == null) { - null - } else { - val len = l.asInstanceOf[Int] - val str = s.asInstanceOf[UTF8String] - val pad = p.asInstanceOf[UTF8String] - - str.lpad(len, pad) - } - } - } + override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { + str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val lenGen = len.gen(ctx) - val strGen = str.gen(ctx) - val padGen = pad.gen(ctx) - - s""" - ${lenGen.code} - boolean ${ev.isNull} = ${lenGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${strGen.code} - if (!${strGen.isNull}) { - ${padGen.code} - if (!${padGen.isNull}) { - ${ev.primitive} = ${strGen.primitive}.lpad(${lenGen.primitive}, ${padGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") } override def prettyName: String = "lpad" @@ -530,60 +486,18 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. 
*/ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = children.exists(_.nullable) override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - override def eval(input: InternalRow): Any = { - val s = str.eval(input) - if (s == null) { - null - } else { - val l = len.eval(input) - if (l == null) { - null - } else { - val p = pad.eval(input) - if (p == null) { - null - } else { - val len = l.asInstanceOf[Int] - val str = s.asInstanceOf[UTF8String] - val pad = p.asInstanceOf[UTF8String] - - str.rpad(len, pad) - } - } - } + override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { + str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val lenGen = len.gen(ctx) - val strGen = str.gen(ctx) - val padGen = pad.gen(ctx) - - s""" - ${lenGen.code} - boolean ${ev.isNull} = ${lenGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${strGen.code} - if (!${strGen.isNull}) { - ${padGen.code} - if (!${padGen.isNull}) { - ${ev.primitive} = ${strGen.primitive}.rpad(${lenGen.primitive}, ${padGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") } override def prettyName: String = "rpad" @@ -745,68 +659,24 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. 
*/ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) } - override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - override def nullable: Boolean = str.nullable || pos.nullable || len.nullable - override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil - override def eval(input: InternalRow): Any = { - val stringEval = str.eval(input) - if (stringEval != null) { - val posEval = pos.eval(input) - if (posEval != null) { - val lenEval = len.eval(input) - if (lenEval != null) { - stringEval.asInstanceOf[UTF8String] - .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int]) - } else { - null - } - } else { - null - } - } else { - null - } + override def nullSafeEval(string: Any, pos: Any, len: Any): Any = { + string.asInstanceOf[UTF8String].substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val strGen = str.gen(ctx) - val posGen = pos.gen(ctx) - val lenGen = len.gen(ctx) - - val start = ctx.freshName("start") - val end = ctx.freshName("end") - - s""" - ${strGen.code} - boolean ${ev.isNull} = ${strGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${posGen.code} - if (!${posGen.isNull}) { - ${lenGen.code} - if (!${lenGen.isNull}) { - ${ev.primitive} = ${strGen.primitive} - .substringSQL(${posGen.primitive}, ${lenGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, pos, len) => s"$str.substringSQL($pos, $len)") } } @@ -986,7 +856,7 @@ case class Encode(value: Expression, charset: Expression) * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { // last regex in string, we will update the pattern iff regexp value changed. 
@transient private var lastRegex: UTF8String = _ @@ -998,40 +868,26 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio // result buffer write by Matcher @transient private val result: StringBuffer = new StringBuffer - override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable - override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable - - override def eval(input: InternalRow): Any = { - val s = subject.eval(input) - if (null != s) { - val p = regexp.eval(input) - if (null != p) { - val r = rep.eval(input) - if (null != r) { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String] - pattern = Pattern.compile(lastRegex.toString) - } - if (!r.equals(lastReplacementInUTF8)) { - // replacement string changed - lastReplacementInUTF8 = r.asInstanceOf[UTF8String] - lastReplacement = lastReplacementInUTF8.toString - } - val m = pattern.matcher(s.toString()) - result.delete(0, result.length()) - - while (m.find) { - m.appendReplacement(result, lastReplacement) - } - m.appendTail(result) + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String] + lastReplacement = lastReplacementInUTF8.toString + } + val m = pattern.matcher(s.toString()) + result.delete(0, result.length()) - return UTF8String.fromString(result.toString) - } - } + while (m.find) { + m.appendReplacement(result, lastReplacement) } + m.appendTail(result) - null + UTF8String.fromString(result.toString) } override def dataType: DataType = StringType @@ -1048,59 +904,43 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val termResult = ctx.freshName("result") - val classNameUTF8String = classOf[UTF8String].getCanonicalName val classNamePattern = classOf[Pattern].getCanonicalName - val classNameString = classOf[java.lang.String].getCanonicalName val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName - ctx.addMutableState(classNameUTF8String, + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - ctx.addMutableState(classNameString, + ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState(classNameUTF8String, + ctx.addMutableState("UTF8String", termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") ctx.addMutableState(classNameStringBuffer, termResult, s"${termResult} = new $classNameStringBuffer();") - val evalSubject = subject.gen(ctx) - val evalRegexp = regexp.gen(ctx) - val evalRep = rep.gen(ctx) - + nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { s""" - ${evalSubject.code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${evalSubject.isNull}) { - ${evalRegexp.code} - if (!${evalRegexp.isNull}) { - ${evalRep.code} - if (!${evalRep.isNull}) { - if (!${evalRegexp.primitive}.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = ${evalRegexp.primitive}; - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) { - // replacement string changed - 
${termLastReplacementInUTF8} = ${evalRep.primitive}; - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); - } - ${termResult}.delete(0, ${termResult}.length()); - ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + if (!$rep.equals(${termLastReplacementInUTF8})) { + // replacement string changed + ${termLastReplacementInUTF8} = $rep; + ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + } + ${termResult}.delete(0, ${termResult}.length()); + java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); - while (m.find()) { - m.appendReplacement(${termResult}, ${termLastReplacement}); - } - m.appendTail(${termResult}); - ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString()); - ${ev.isNull} = false; - } - } + while (m.find()) { + m.appendReplacement(${termResult}, ${termLastReplacement}); } + m.appendTail(${termResult}); + ${ev.primitive} = UTF8String.fromString(${termResult}.toString()); + ${ev.isNull} = false; """ + }) } } @@ -1110,7 +950,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) // last regex in string, we will update the pattern iff regexp value changed. @@ -1118,32 +958,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio // last regex pattern, we cache it for performance concern @transient private var pattern: Pattern = _ - override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable - override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable - - override def eval(input: InternalRow): Any = { - val s = subject.eval(input) - if (null != s) { - val p = regexp.eval(input) - if (null != p) { - val r = idx.eval(input) - if (null != r) { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String] - pattern = Pattern.compile(lastRegex.toString) - } - val m = pattern.matcher(s.toString()) - if (m.find) { - val mr: MatchResult = m.toMatchResult - return UTF8String.fromString(mr.group(r.asInstanceOf[Int])) - } - return UTF8String.EMPTY_UTF8 - } - } + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + val m = pattern.matcher(s.toString()) + if (m.find) { + val mr: MatchResult = m.toMatchResult + UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + } else { + UTF8String.EMPTY_UTF8 } - - null } override def dataType: DataType = StringType @@ -1154,44 +981,29 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") - val classNameUTF8String = classOf[UTF8String].getCanonicalName val classNamePattern = 
classOf[Pattern].getCanonicalName - ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - val evalSubject = subject.gen(ctx) - val evalRegexp = regexp.gen(ctx) - val evalIdx = idx.gen(ctx) - - s""" - ${evalSubject.code} - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - boolean ${ev.isNull} = true; - if (!${evalSubject.isNull}) { - ${evalRegexp.code} - if (!${evalRegexp.isNull}) { - ${evalIdx.code} - if (!${evalIdx.isNull}) { - if (!${evalRegexp.primitive}.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = ${evalRegexp.primitive}; - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); - if (m.find()) { - ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); - ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); - ${ev.isNull} = false; - } else { - ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8; - ${ev.isNull} = false; - } - } - } + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { + s""" + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } - """ + java.util.regex.Matcher m = + ${termPattern}.matcher($subject.toString()); + if (m.find()) { + java.util.regex.MatchResult mr = m.toMatchResult(); + ${ev.primitive} = UTF8String.fromString(mr.group($idx)); + ${ev.isNull} = false; + } else { + ${ev.primitive} = UTF8String.EMPTY_UTF8; + ${ev.isNull} = false; + }""" + }) } } From fc0e57e5aba82a3f227fef05a843283e2ec893fc Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 31 Jul 2015 09:33:38 -0700 Subject: [PATCH 030/340] [SPARK-9053] [SPARKR] Fix spaces around parens, infix operators etc. ### JIRA [[SPARK-9053] Fix spaces around parens, infix operators etc. - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9053) ### The Result of `lint-r` [The result of lint-r at the rivision:a4c83cb1e4b066cd60264b6572fd3e51d160d26a](https://gist.github.com/yu-iskw/d253d7f8ef351f86443d) Author: Yu ISHIKAWA Closes #7584 from yu-iskw/SPARK-9053 and squashes the following commits: 613170f [Yu ISHIKAWA] Ignore a warning about a space before a left parentheses ede61e1 [Yu ISHIKAWA] Ignores two warnings about a space before a left parentheses. TODO: After updating `lintr`, we will remove the ignores de3e0db [Yu ISHIKAWA] Add '## nolint start' & '## nolint end' statement to ignore infix space warnings e233ea8 [Yu ISHIKAWA] [SPARK-9053][SparkR] Fix spaces around parens, infix operators etc. 
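Stepping back to the RegExpReplace and RegExpExtract changes in the previous patch: both keep the last regex value and its compiled Pattern in mutable state, so the pattern is recompiled only when the regex actually changes between rows. The standalone sketch below illustrates that caching idea; CachedRegexMatcher and its members are illustrative names only, not Spark internals.

```scala
import java.util.regex.Pattern

// Recompile the Pattern only when the incoming regex string differs from the last one seen.
class CachedRegexMatcher {
  private var lastRegex: String = null
  private var pattern: Pattern = null

  def find(input: String, regex: String): Boolean = {
    if (regex != lastRegex) {   // regex value changed, so compile a fresh Pattern
      lastRegex = regex
      pattern = Pattern.compile(regex)
    }
    pattern.matcher(input).find()
  }
}

object CachedRegexMatcherDemo {
  def main(args: Array[String]): Unit = {
    val m = new CachedRegexMatcher
    println(m.find("abdef", "a.d"))   // compiles the pattern on first use
    println(m.find("axdef", "a.d"))   // reuses the cached Pattern
  }
}
```

When the regex column is a constant (the common case), the compile branch is taken only once per expression instance rather than once per row, which is why the generated code above hoists lastRegex and pattern into mutable fields instead of locals.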
--- R/pkg/R/DataFrame.R | 4 ++++ R/pkg/R/RDD.R | 7 +++++-- R/pkg/R/column.R | 2 +- R/pkg/R/context.R | 2 +- R/pkg/R/pairRDD.R | 2 +- R/pkg/R/utils.R | 4 ++-- R/pkg/inst/tests/test_binary_function.R | 2 +- R/pkg/inst/tests/test_rdd.R | 6 +++--- R/pkg/inst/tests/test_sparkSQL.R | 4 +++- 9 files changed, 21 insertions(+), 12 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index f4c93d3c7dd67..b31ad3729e09b 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1322,9 +1322,11 @@ setMethod("write.df", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") + # nolint start if (!(mode %in% allModes)) { stop('mode should be one of "append", "overwrite", "error", "ignore"') } + # nolint end jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { @@ -1384,9 +1386,11 @@ setMethod("saveAsTable", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") + # nolint start if (!(mode %in% allModes)) { stop('mode should be one of "append", "overwrite", "error", "ignore"') } + # nolint end jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index d2d096709245d..2a013b3dbb968 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -85,7 +85,9 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) isPipelinable <- function(rdd) { e <- rdd@env + # nolint start !(e$isCached || e$isCheckpointed) + # nolint end } if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { @@ -97,7 +99,8 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) # prev_serializedMode is used during the delayed computation of JRDD in getJRDD } else { pipelinedFunc <- function(partIndex, part) { - func(partIndex, prev@func(partIndex, part)) + f <- prev@func + func(partIndex, f(partIndex, part)) } .Object@func <- cleanClosure(pipelinedFunc) .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline @@ -841,7 +844,7 @@ setMethod("sampleRDD", if (withReplacement) { count <- rpois(1, fraction) if (count > 0) { - res[(len + 1):(len + count)] <- rep(list(elem), count) + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 2892e1416cc65..eeaf9f193b728 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -65,7 +65,7 @@ functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", "expm1", "floor", "log", "log10", "log1p", "rint", "sign", "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") -binary_mathfunctions<- c("atan2", "hypot") +binary_mathfunctions <- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 43be9c904fdf6..720990e1c6087 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -121,7 +121,7 @@ parallelize <- function(sc, coll, numSlices = 1) { numSlices <- length(coll) sliceLen <- ceiling(length(coll) / numSlices) - slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)]) # Serialize each slice: obtain a list of raws, or a list of lists (slices) of # 2-tuples of raws diff --git a/R/pkg/R/pairRDD.R 
b/R/pkg/R/pairRDD.R index 83801d3209700..199c3fd6ab1b2 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -879,7 +879,7 @@ setMethod("sampleByKey", if (withReplacement) { count <- rpois(1, frac) if (count > 0) { - res[(len + 1):(len + count)] <- rep(list(elem), count) + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 3f45589a50443..4f9f4d9cad2a8 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -32,7 +32,7 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, } results <- if (arrSize > 0) { - lapply(0:(arrSize - 1), + lapply(0 : (arrSize - 1), function(index) { obj <- callJMethod(jList, "get", as.integer(index)) @@ -572,7 +572,7 @@ mergePartitions <- function(rdd, zip) { keys <- list() } if (lengthOfValues > 1) { - values <- part[(lengthOfKeys + 1) : (len - 1)] + values <- part[ (lengthOfKeys + 1) : (len - 1) ] } else { values <- list() } diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index dca0657c57e0d..f054ac9a87d61 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -40,7 +40,7 @@ test_that("union on two RDDs", { expect_equal(actual, c(as.list(nums), mockFile)) expect_equal(getSerializedMode(union.rdd), "byte") - rdd<- map(text.rdd, function(x) {x}) + rdd <- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 6c3aaab8c711e..71aed2bb9d6a8 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -250,7 +250,7 @@ test_that("flatMapValues() on pairwise RDDs", { expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) # Generate x to x+1 for every value - actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) })) expect_equal(actual, list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) @@ -293,7 +293,7 @@ test_that("sumRDD() on RDDs", { }) test_that("keyBy on RDDs", { - func <- function(x) { x*x } + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collect(keys) expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) @@ -311,7 +311,7 @@ test_that("repartition/coalesce on RDDs", { r2 <- repartition(rdd, 6) expect_equal(numPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) - expect_true(count >=0 && count <= 4) + expect_true(count >= 0 && count <= 4) # coalesce r3 <- coalesce(rdd, 1) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 61c8a7ec7d837..aca41aa6dcf24 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -666,10 +666,12 @@ test_that("column binary mathfunctions", { expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + ## nolint start expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) 
expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) + ## nolint end }) test_that("string operators", { @@ -876,7 +878,7 @@ test_that("parquetFile works with multiple input paths", { write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) expect_is(parquetDF, "DataFrame") - expect_equal(count(parquetDF), count(df)*2) + expect_equal(count(parquetDF), count(df) * 2) }) test_that("describe() on a DataFrame", { From 04a49edfdb606c01fa4f8ae6e730ec4f9bd0cb6d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 31 Jul 2015 09:34:10 -0700 Subject: [PATCH 031/340] [SPARK-9497] [SPARK-9509] [CORE] Use ask instead of askWithRetry `RpcEndpointRef.askWithRetry` throws `SparkException` rather than `TimeoutException`. Use ask to replace it because we don't need to retry here. Author: zsxwing Closes #7824 from zsxwing/SPARK-9497 and squashes the following commits: 7bfc2b4 [zsxwing] Use ask instead of askWithRetry --- .../scala/org/apache/spark/deploy/client/AppClient.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 79b251e7e62fe..a659abf70395d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -27,7 +27,7 @@ import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master import org.apache.spark.rpc._ -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -248,7 +248,8 @@ private[spark] class AppClient( def stop() { if (endpoint != null) { try { - endpoint.askWithRetry[Boolean](StopAppClient) + val timeout = RpcUtils.askRpcTimeout(conf) + timeout.awaitResult(endpoint.ask[Boolean](StopAppClient)) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") From 27ae851ce16082775ffbcb5b8fc6bdbe65dc70fc Mon Sep 17 00:00:00 2001 From: tedyu Date: Fri, 31 Jul 2015 18:16:55 +0100 Subject: [PATCH 032/340] [SPARK-9446] Clear Active SparkContext in stop() method In thread 'stopped SparkContext remaining active' on mailing list, Andres observed the following in driver log: ``` 15/07/29 15:17:09 WARN YarnSchedulerBackend$YarnSchedulerEndpoint: ApplicationMaster has disassociated:
15/07/29 15:17:09 INFO YarnClientSchedulerBackend: Shutting down all executors Exception in thread "Yarn application state monitor" org.apache.spark.SparkException: Error asking standalone scheduler to shut down executors at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stopExecutors(CoarseGrainedSchedulerBackend.scala:261) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stop(CoarseGrainedSchedulerBackend.scala:266) at org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend.stop(YarnClientSchedulerBackend.scala:158) at org.apache.spark.scheduler.TaskSchedulerImpl.stop(TaskSchedulerImpl.scala:416) at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:1411) at org.apache.spark.SparkContext.stop(SparkContext.scala:1644) at org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend$$anon$1.run(YarnClientSchedulerBackend.scala:139) Caused by: java.lang.InterruptedException at java.util.concurrent.locks.AbstractQueuedSynchronizer.tryAcquireSharedNanos(AbstractQueuedSynchronizer.java:1325) at scala.concurrent.impl.Promise$DefaultPromise.tryAwait(Promise.scala:208) at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:218) at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223) at scala.concurrent.Await$$anonfun$result$1.apply(package.scala:190) at scala.concurrent.BlockContext$DefaultBlockContext$.blockOn(BlockContext.scala:53) at scala.concurrent.Await$.result(package.scala:190)15/07/29 15:17:09 INFO YarnClientSchedulerBackend: Asking each executor to shut down at org.apache.spark.rpc.RpcEndpointRef.askWithRetry(RpcEndpointRef.scala:102) at org.apache.spark.rpc.RpcEndpointRef.askWithRetry(RpcEndpointRef.scala:78) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stopExecutors(CoarseGrainedSchedulerBackend.scala:257) ... 6 more ``` Effect of the above exception is that a stopped SparkContext is returned to user since SparkContext.clearActiveContext() is not called. 
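The fix isolates each teardown step so that one failing step (such as the interrupted executor shutdown above) can no longer abort stop() before the remaining steps, including clearing the active context, have run. A minimal, self-contained sketch of that guarded-teardown pattern is shown below; tryLogNonFatalError here is a local stand-in modelled on the Utils helper the patch uses, and the step bodies are placeholders.

```scala
import scala.util.control.NonFatal

object GuardedShutdownSketch {
  // Stand-in for Utils.tryLogNonFatalError: run one teardown step and report,
  // but do not rethrow, any non-fatal failure so the later steps still execute.
  def tryLogNonFatalError(block: => Unit): Unit = {
    try block catch {
      case NonFatal(e) => System.err.println(s"Ignoring failure during shutdown: $e")
    }
  }

  def main(args: Array[String]): Unit = {
    tryLogNonFatalError { println("stopping UI") }
    tryLogNonFatalError { throw new RuntimeException("executor shutdown timed out") }
    tryLogNonFatalError { println("stopping DAG scheduler") }
    // The bookkeeping that previously could be skipped now always runs.
    println("clearing active context")
  }
}
```

Running the sketch still reaches the final bookkeeping line even though the second step throws, which is the behaviour the patch wants from SparkContext.stop().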
Author: tedyu Closes #7756 from tedyu/master and squashes the following commits: 7339ff2 [tedyu] Move null assignment out of tryLogNonFatalError block 6e02cd9 [tedyu] Use Utils.tryLogNonFatalError to guard resource release f5fb519 [tedyu] Clear Active SparkContext in stop() method using finally --- .../scala/org/apache/spark/SparkContext.scala | 50 ++++++++++++++----- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ac6ac6c216767..2d8aa25d81daa 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1689,33 +1689,57 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli Utils.removeShutdownHook(_shutdownHookRef) } - postApplicationEnd() - _ui.foreach(_.stop()) + Utils.tryLogNonFatalError { + postApplicationEnd() + } + Utils.tryLogNonFatalError { + _ui.foreach(_.stop()) + } if (env != null) { - env.metricsSystem.report() + Utils.tryLogNonFatalError { + env.metricsSystem.report() + } } if (metadataCleaner != null) { - metadataCleaner.cancel() + Utils.tryLogNonFatalError { + metadataCleaner.cancel() + } + } + Utils.tryLogNonFatalError { + _cleaner.foreach(_.stop()) + } + Utils.tryLogNonFatalError { + _executorAllocationManager.foreach(_.stop()) } - _cleaner.foreach(_.stop()) - _executorAllocationManager.foreach(_.stop()) if (_dagScheduler != null) { - _dagScheduler.stop() + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } _dagScheduler = null } if (_listenerBusStarted) { - listenerBus.stop() - _listenerBusStarted = false + Utils.tryLogNonFatalError { + listenerBus.stop() + _listenerBusStarted = false + } + } + Utils.tryLogNonFatalError { + _eventLogger.foreach(_.stop()) } - _eventLogger.foreach(_.stop()) if (env != null && _heartbeatReceiver != null) { - env.rpcEnv.stop(_heartbeatReceiver) + Utils.tryLogNonFatalError { + env.rpcEnv.stop(_heartbeatReceiver) + } + } + Utils.tryLogNonFatalError { + _progressBar.foreach(_.stop()) } - _progressBar.foreach(_.stop()) _taskScheduler = null // TODO: Cache.stop()? 
if (_env != null) { - _env.stop() + Utils.tryLogNonFatalError { + _env.stop() + } SparkEnv.set(null) } SparkContext.clearActiveContext() From 0024da9157ba12ec84883a78441fa6835c1d0042 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 31 Jul 2015 11:07:34 -0700 Subject: [PATCH 033/340] [SQL] address comments for to_date/trunc This PR address the comments in #7805 cc rxin Author: Davies Liu Closes #7817 from davies/trunc and squashes the following commits: f729d5f [Davies Liu] rollback cb7f7832 [Davies Liu] genCode() is protected 31e52ef [Davies Liu] fix style ed1edc7 [Davies Liu] address comments for #7805 --- .../catalyst/expressions/datetimeFunctions.scala | 15 ++++++++------- .../spark/sql/catalyst/util/DateTimeUtils.scala | 3 ++- .../expressions/ExpressionEvalHelper.scala | 4 +--- .../scala/org/apache/spark/sql/functions.scala | 3 +++ 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index 6e7613340c032..07dea5b470b5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -726,15 +726,16 @@ case class TruncDate(date: Expression, format: Expression) override def dataType: DataType = DateType override def prettyName: String = "trunc" - lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + private lazy val truncLevel: Int = + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) override def eval(input: InternalRow): Any = { - val minItem = if (format.foldable) { - minItemConst + val level = if (format.foldable) { + truncLevel } else { DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) } - if (minItem == -1) { + if (level == -1) { // unknown format null } else { @@ -742,7 +743,7 @@ case class TruncDate(date: Expression, format: Expression) if (d == null) { null } else { - DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem) + DateTimeUtils.truncDate(d.asInstanceOf[Int], level) } } } @@ -751,7 +752,7 @@ case class TruncDate(date: Expression, format: Expression) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (format.foldable) { - if (minItemConst == -1) { + if (truncLevel == -1) { s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; @@ -763,7 +764,7 @@ case class TruncDate(date: Expression, format: Expression) boolean ${ev.isNull} = ${d.isNull}; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst); + ${ev.primitive} = $dtu.truncDate(${d.primitive}, $truncLevel); } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 5a7c25b8d508d..032ed8a56a50e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -794,7 +794,8 @@ object DateTimeUtils { } else if (level == TRUNC_TO_MONTH) { d - DateTimeUtils.getDayOfMonth(d) + 1 } else { - throw new Exception(s"Invalid trunc level: $level") + // caller make sure that this should 
never be reached + sys.error(s"Invalid trunc level: $level") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3c05e5c3b833c..a41185b4d8754 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.scalactic.TripleEqualsSupport.Spread -import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 46dc4605a5ccb..5d82a5eadd94d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2192,6 +2192,9 @@ object functions { /** * Returns date truncated to the unit specified by the format. * + * @param format: 'year', 'yyyy', 'yy' for truncate by year, + * or 'month', 'mon', 'mm' for truncate by month + * * @group datetime_funcs * @since 1.5.0 */ From 6add4eddb39e7748a87da3e921ea3c7881d30a82 Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Fri, 31 Jul 2015 11:22:40 -0700 Subject: [PATCH 034/340] [SPARK-9471] [ML] Multilayer Perceptron This pull request contains the following feature for ML: - Multilayer Perceptron classifier This implementation is based on our initial pull request with bgreeven: https://github.com/apache/spark/pull/1290 and inspired by very insightful suggestions from mengxr and witgo (I would like to thank all other people from the mentioned thread for useful discussions). The original code was extensively tested and benchmarked. Since then, I've addressed two main requirements that prevented the code from merging into the main branch: - Extensible interface, so it will be easy to implement new types of networks - Main building blocks are traits `Layer` and `LayerModel`. They are used for constructing layers of ANN. New layers can be added by extending the `Layer` and `LayerModel` traits. These traits are private in this release in order to save path to improve them based on community feedback - Back propagation is implemented in general form, so there is no need to change it (optimization algorithm) when new layers are implemented - Speed and scalability: this implementation has to be comparable in terms of speed to the state of the art single node implementations. - The developed benchmark for large ANN shows that the proposed code is on par with C++ CPU implementation and scales nicely with the number of workers. Details can be found here: https://github.com/avulanov/ann-benchmark - DBN and RBM by witgo https://github.com/witgo/spark/tree/ann-interface-gemm-dbn - Dropout https://github.com/avulanov/spark/tree/ann-interface-gemm mengxr and dbtsai kindly agreed to perform code review. 
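Before the code itself, here is a small hypothetical usage sketch of the new estimator, assuming the setters follow the usual spark.ml Param conventions for the parameters mentioned above (layers, blockSize, seed, maxIter); the dataset and numbers are made up and this snippet is not part of the patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

object MlpUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("mlp-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Tiny XOR-style dataset with the label/features columns ml classifiers expect.
    val train = sqlContext.createDataFrame(Seq(
      (0.0, Vectors.dense(0.0, 0.0)),
      (1.0, Vectors.dense(0.0, 1.0)),
      (1.0, Vectors.dense(1.0, 0.0)),
      (0.0, Vectors.dense(1.0, 1.0))
    )).toDF("label", "features")

    // layers: input size, then hidden layer sizes, then output size (number of classes).
    val trainer = new MultilayerPerceptronClassifier()
      .setLayers(Array(2, 5, 2))
      .setBlockSize(100)
      .setSeed(11L)
      .setMaxIter(100)

    val model = trainer.fit(train)
    model.transform(train).select("features", "prediction").show()

    sc.stop()
  }
}
```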
Author: Alexander Ulanov Author: Bert Greevenbosch Closes #7621 from avulanov/SPARK-2352-ann and squashes the following commits: 4806b6f [Alexander Ulanov] Addressing reviewers comments. a7e7951 [Alexander Ulanov] Default blockSize: 100. Added documentation to blockSize parameter and DataStacker class f69bb3d [Alexander Ulanov] Addressing reviewers comments. 374bea6 [Alexander Ulanov] Moving ANN to ML package. GradientDescent constructor is now spark private. 43b0ae2 [Alexander Ulanov] Addressing reviewers comments. Adding multiclass test. 9d18469 [Alexander Ulanov] Addressing reviewers comments: unnecessary copy of data in predict 35125ab [Alexander Ulanov] Style fix in tests e191301 [Alexander Ulanov] Apache header a226133 [Alexander Ulanov] Multilayer Perceptron regressor and classifier --- .../org/apache/spark/ml/ann/BreezeUtil.scala | 63 ++ .../scala/org/apache/spark/ml/ann/Layer.scala | 882 ++++++++++++++++++ .../MultilayerPerceptronClassifier.scala | 193 ++++ .../org/apache/spark/ml/param/params.scala | 5 + .../mllib/optimization/GradientDescent.scala | 2 +- .../org/apache/spark/ml/ann/ANNSuite.scala | 91 ++ .../MultilayerPerceptronClassifierSuite.scala | 91 ++ 7 files changed, 1326 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala new file mode 100644 index 0000000000000..7429f9d652ac5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} +import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} + +/** + * In-place DGEMM and DGEMV for Breeze + */ +private[ann] object BreezeUtil { + + // TODO: switch to MLlib BLAS interface + private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N" + + /** + * DGEMM: C := alpha * A * B + beta * C + * @param alpha alpha + * @param a A + * @param b B + * @param beta beta + * @param c C + */ + def dgemm(alpha: Double, a: BDM[Double], b: BDM[Double], beta: Double, c: BDM[Double]): Unit = { + // TODO: add code if matrices isTranspose!!! 
+ require(a.cols == b.rows, "A & B Dimension mismatch!") + require(a.rows == c.rows, "A & C Dimension mismatch!") + require(b.cols == c.cols, "A & C Dimension mismatch!") + NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols, + alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride, + beta, c.data, c.offset, c.rows) + } + + /** + * DGEMV: y := alpha * A * x + beta * y + * @param alpha alpha + * @param a A + * @param x x + * @param beta beta + * @param y y + */ + def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = { + require(a.cols == x.length, "A & b Dimension mismatch!") + NativeBLAS.dgemv(transposeString(a), a.rows, a.cols, + alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride, + beta, y.data, y.offset, y.stride) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala new file mode 100644 index 0000000000000..b5258ff348477 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -0,0 +1,882 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy, + sum => Bsum} +import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom + +/** + * Trait that holds Layer properties, that are needed to instantiate it. + * Implements Layer instantiation. + * + */ +private[ann] trait Layer extends Serializable { + /** + * Returns the instance of the layer based on weights provided + * @param weights vector with layer weights + * @param position position of weights in the vector + * @return the layer model + */ + def getInstance(weights: Vector, position: Int): LayerModel + + /** + * Returns the instance of the layer with random generated weights + * @param seed seed + * @return the layer model + */ + def getInstance(seed: Long): LayerModel +} + +/** + * Trait that holds Layer weights (or parameters). + * Implements functions needed for forward propagation, computing delta and gradient. + * Can return weights in Vector format. 
+ */ +private[ann] trait LayerModel extends Serializable { + /** + * number of weights + */ + val size: Int + + /** + * Evaluates the data (process the data through the layer) + * @param data data + * @return processed data + */ + def eval(data: BDM[Double]): BDM[Double] + + /** + * Computes the delta for back propagation + * @param nextDelta delta of the next layer + * @param input input data + * @return delta + */ + def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] + + /** + * Computes the gradient + * @param delta delta for this layer + * @param input input data + * @return gradient + */ + def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] + + /** + * Returns weights for the layer in a single vector + * @return layer weights + */ + def weights(): Vector +} + +/** + * Layer properties of affine transformations, that is y=A*x+b + * @param numIn number of inputs + * @param numOut number of outputs + */ +private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer { + + override def getInstance(weights: Vector, position: Int): LayerModel = { + AffineLayerModel(this, weights, position) + } + + override def getInstance(seed: Long = 11L): LayerModel = { + AffineLayerModel(this, seed) + } +} + +/** + * Model of Affine layer y=A*x+b + * @param w weights (matrix A) + * @param b bias (vector b) + */ +private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel { + val size = w.size + b.length + val gwb = new Array[Double](size) + private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb) + private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size) + private var z: BDM[Double] = null + private var d: BDM[Double] = null + private var ones: BDV[Double] = null + + override def eval(data: BDM[Double]): BDM[Double] = { + if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols) + z(::, *) := b + BreezeUtil.dgemm(1.0, w, data, 1.0, z) + z + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols) + BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d) + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = { + BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw) + if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols) + BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb) + gwb + } + + override def weights(): Vector = AffineLayerModel.roll(w, b) +} + +/** + * Fabric for Affine layer models + */ +private[ann] object AffineLayerModel { + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param weights vector with weights + * @param position position of weights in the vector + * @return model of Affine layer + */ + def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = { + val (w, b) = unroll(weights, position, layer.numIn, layer.numOut) + new AffineLayerModel(w, b) + } + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param seed seed + * @return model of Affine layer + */ + def apply(layer: AffineLayer, seed: Long): AffineLayerModel = { + val (w, b) = randomWeights(layer.numIn, layer.numOut, seed) + new AffineLayerModel(w, b) + } + + /** + * Unrolls the weights from the vector + * @param weights vector with weights + * @param position position of weights for this layer + * @param numIn number of layer inputs + * 
@param numOut number of layer outputs + * @return matrix A and vector b + */ + def unroll( + weights: Vector, + position: Int, + numIn: Int, + numOut: Int): (BDM[Double], BDV[Double]) = { + val weightsCopy = weights.toArray + // TODO: the array is not copied to BDMs, make sure this is OK! + val a = new BDM[Double](numOut, numIn, weightsCopy, position) + val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut) + (a, b) + } + + /** + * Roll the layer weights into a vector + * @param a matrix A + * @param b vector b + * @return vector of weights + */ + def roll(a: BDM[Double], b: BDV[Double]): Vector = { + val result = new Array[Double](a.size + b.length) + // TODO: make sure that we need to copy! + System.arraycopy(a.toArray, 0, result, 0, a.size) + System.arraycopy(b.toArray, 0, result, a.size, b.length) + Vectors.dense(result) + } + + /** + * Generate random weights for the layer + * @param numIn number of inputs + * @param numOut number of outputs + * @param seed seed + * @return (matrix A, vector b) + */ + def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = { + val rand: XORShiftRandom = new XORShiftRandom(seed) + val weights = BDM.fill[Double](numOut, numIn){ (rand.nextDouble * 4.8 - 2.4) / numIn } + val bias = BDV.fill[Double](numOut){ (rand.nextDouble * 4.8 - 2.4) / numIn } + (weights, bias) + } +} + +/** + * Trait for functions and their derivatives for functional layers + */ +private[ann] trait ActivationFunction extends Serializable { + + /** + * Implements a function + * @param x input data + * @param y output data + */ + def eval(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a derivative of a function (needed for the back propagation) + * @param x input data + * @param y output data + */ + def derivative(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a cross entropy error of a function. + * Needed if the functional layer that contains this function is the output layer + * of the network. 
+ * @param target target output + * @param output computed output + * @param result intermediate result + * @return cross-entropy + */ + def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double + + /** + * Implements a mean squared error of a function + * @param target target output + * @param output computed output + * @param result intermediate result + * @return mean squared error + */ + def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double +} + +/** + * Implements in-place application of functions + */ +private[ann] object ActivationFunction { + + def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = { + var i = 0 + while (i < x.rows) { + var j = 0 + while (j < x.cols) { + y(i, j) = func(x(i, j)) + j += 1 + } + i += 1 + } + } + + def apply( + x1: BDM[Double], + x2: BDM[Double], + y: BDM[Double], + func: (Double, Double) => Double): Unit = { + var i = 0 + while (i < x1.rows) { + var j = 0 + while (j < x1.cols) { + y(i, j) = func(x1(i, j), x2(i, j)) + j += 1 + } + i += 1 + } + } +} + +/** + * Implements SoftMax activation function + */ +private[ann] class SoftmaxFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + var j = 0 + // find max value to make sure later that exponent is computable + while (j < x.cols) { + var i = 0 + var max = Double.MinValue + while (i < x.rows) { + if (x(i, j) > max) { + max = x(i, j) + } + i += 1 + } + var sum = 0.0 + i = 0 + while (i < x.rows) { + val res = Math.exp(x(i, j) - max) + y(i, j) = res + sum += res + i += 1 + } + i = 0 + while (i < x.rows) { + y(i, j) /= sum + i += 1 + } + j += 1 + } + } + + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum( target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.") + } +} + +/** + * Implements Sigmoid activation function + */ +private[ann] class SigmoidFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + def s(z: Double): Double = Bsigmoid(z) + ActivationFunction(x, y, s) + } + + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum(target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + // TODO: make it readable + def m(o: Double, t: Double): Double = (o - t) + ActivationFunction(output, target, result, m) + val e = Bsum(result :* result) / 2 / output.cols + def m2(x: Double, o: Double) = x * (o - o * o) + ActivationFunction(result, output, result, m2) + e + } +} + +/** + * Functional layer properties, y = f(x) + * @param activationFunction activation function + */ +private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer { + override def 
getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L) + + override def getInstance(seed: Long): LayerModel = + FunctionalLayerModel(this) +} + +/** + * Functional layer model. Holds no weights. + * @param activationFunction activation function + */ +private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction) + extends LayerModel { + val size = 0 + // matrices for in-place computations + // outputs + private var f: BDM[Double] = null + // delta + private var d: BDM[Double] = null + // matrix for error computation + private var e: BDM[Double] = null + // delta gradient + private lazy val dg = new Array[Double](0) + + override def eval(data: BDM[Double]): BDM[Double] = { + if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols) + activationFunction.eval(data, f) + f + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols) + activationFunction.derivative(input, d) + d :*= nextDelta + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg + + override def weights(): Vector = Vectors.dense(new Array[Double](0)) + + def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.crossEntropy(output, target, e) + (e, error) + } + + def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.squared(output, target, e) + (e, error) + } + + def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + // TODO: allow user pick error + activationFunction match { + case sigmoid: SigmoidFunction => squared(output, target) + case softmax: SoftmaxFunction => crossEntropy(output, target) + } + } +} + +/** + * Fabric of functional layer models + */ +private[ann] object FunctionalLayerModel { + def apply(layer: FunctionalLayer): FunctionalLayerModel = + new FunctionalLayerModel(layer.activationFunction) +} + +/** + * Trait for the artificial neural network (ANN) topology properties + */ +private[ann] trait Topology extends Serializable{ + def getInstance(weights: Vector): TopologyModel + def getInstance(seed: Long): TopologyModel +} + +/** + * Trait for ANN topology model + */ +private[ann] trait TopologyModel extends Serializable{ + /** + * Forward propagation + * @param data input data + * @return array of outputs for each of the layers + */ + def forward(data: BDM[Double]): Array[BDM[Double]] + + /** + * Prediction of the model + * @param data input data + * @return prediction + */ + def predict(data: Vector): Vector + + /** + * Computes gradient for the network + * @param data input data + * @param target target output + * @param cumGradient cumulative gradient + * @param blockSize block size + * @return error + */ + def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector, + blockSize: Int): Double + + /** + * Returns the weights of the ANN + * @return weights + */ + def weights(): Vector +} + +/** + * Feed forward ANN + * @param layers + */ +private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { + override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights) + + override def 
getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed) +} + +/** + * Factory for some of the frequently-used topologies + */ +private[ml] object FeedForwardTopology { + /** + * Creates a feed forward topology from the array of layers + * @param layers array of layers + * @return feed forward topology + */ + def apply(layers: Array[Layer]): FeedForwardTopology = { + new FeedForwardTopology(layers) + } + + /** + * Creates a multi-layer perceptron + * @param layerSizes sizes of layers including input and output size + * @param softmax wether to use SoftMax or Sigmoid function for an output layer. + * Softmax is default + * @return multilayer perceptron topology + */ + def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = { + val layers = new Array[Layer]((layerSizes.length - 1) * 2) + for(i <- 0 until layerSizes.length - 1){ + layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1)) + layers(i * 2 + 1) = + if (softmax && i == layerSizes.length - 2) { + new FunctionalLayer(new SoftmaxFunction()) + } else { + new FunctionalLayer(new SigmoidFunction()) + } + } + FeedForwardTopology(layers) + } +} + +/** + * Model of Feed Forward Neural Network. + * Implements forward, gradient computation and can return weights in vector format. + * @param layerModels models of layers + * @param topology topology of the network + */ +private[ml] class FeedForwardModel private( + val layerModels: Array[LayerModel], + val topology: FeedForwardTopology) extends TopologyModel { + override def forward(data: BDM[Double]): Array[BDM[Double]] = { + val outputs = new Array[BDM[Double]](layerModels.length) + outputs(0) = layerModels(0).eval(data) + for (i <- 1 until layerModels.length) { + outputs(i) = layerModels(i).eval(outputs(i-1)) + } + outputs + } + + override def computeGradient( + data: BDM[Double], + target: BDM[Double], + cumGradient: Vector, + realBatchSize: Int): Double = { + val outputs = forward(data) + val deltas = new Array[BDM[Double]](layerModels.length) + val L = layerModels.length - 1 + val (newE, newError) = layerModels.last match { + case flm: FunctionalLayerModel => flm.error(outputs.last, target) + case _ => + throw new UnsupportedOperationException("Non-functional layer not supported at the top") + } + deltas(L) = new BDM[Double](0, 0) + deltas(L - 1) = newE + for (i <- (L - 2) to (0, -1)) { + deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1)) + } + val grads = new Array[Array[Double]](layerModels.length) + for (i <- 0 until layerModels.length) { + val input = if (i==0) data else outputs(i - 1) + grads(i) = layerModels(i).grad(deltas(i), input) + } + // update cumGradient + val cumGradientArray = cumGradient.toArray + var offset = 0 + // TODO: extract roll + for (i <- 0 until grads.length) { + val gradArray = grads(i) + var k = 0 + while (k < gradArray.length) { + cumGradientArray(offset + k) += gradArray(k) + k += 1 + } + offset += gradArray.length + } + newError + } + + // TODO: do we really need to copy the weights? 
they should be read-only + override def weights(): Vector = { + // TODO: extract roll + var size = 0 + for (i <- 0 until layerModels.length) { + size += layerModels(i).size + } + val array = new Array[Double](size) + var offset = 0 + for (i <- 0 until layerModels.length) { + val layerWeights = layerModels(i).weights().toArray + System.arraycopy(layerWeights, 0, array, offset, layerWeights.length) + offset += layerWeights.length + } + Vectors.dense(array) + } + + override def predict(data: Vector): Vector = { + val size = data.size + val result = forward(new BDM[Double](size, 1, data.toArray)) + Vectors.dense(result.last.toArray) + } +} + +/** + * Fabric for feed forward ANN models + */ +private[ann] object FeedForwardModel { + + /** + * Creates a model from a topology and weights + * @param topology topology + * @param weights weights + * @return model + */ + def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = { + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + var offset = 0 + for (i <- 0 until layers.length) { + layerModels(i) = layers(i).getInstance(weights, offset) + offset += layerModels(i).size + } + new FeedForwardModel(layerModels, topology) + } + + /** + * Creates a model given a topology and seed + * @param topology topology + * @param seed seed for generating the weights + * @return model + */ + def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = { + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + var offset = 0 + for(i <- 0 until layers.length){ + layerModels(i) = layers(i).getInstance(seed) + offset += layerModels(i).size + } + new FeedForwardModel(layerModels, topology) + } +} + +/** + * Neural network gradient. Does nothing but calling Model's gradient + * @param topology topology + * @param dataStacker data stacker + */ +private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient { + + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } + + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { + val (input, target, realBatchSize) = dataStacker.unstack(data) + val model = topology.getInstance(weights) + model.computeGradient(input, target, cumGradient, realBatchSize) + } +} + +/** + * Stacks pairs of training samples (input, output) in one vector allowing them to pass + * through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks + * or matrices of inputs and outputs and then stack them in one vector. + * This can be used for further batch computations after unstacking. 
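+ * For example, with stackSize = 2, inputSize = 3 and outputSize = 1, the two pairs
+ * ([x11, x12, x13], [y1]) and ([x21, x22, x23], [y2]) are packed into the single vector
+ * [x11, x12, x13, x21, x22, x23, y1, y2], from which unstack rebuilds a 3x2 input matrix
+ * and a 1x2 output matrix (one sample per column).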
+ * @param stackSize stack size + * @param inputSize size of the input vectors + * @param outputSize size of the output vectors + */ +private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) + extends Serializable { + + /** + * Stacks the data + * @param data RDD of vector pairs + * @return RDD of double (always zero) and vector that contains the stacked vectors + */ + def stack(data: RDD[(Vector, Vector)]): RDD[(Double, Vector)] = { + val stackedData = if (stackSize == 1) { + data.map { v => + (0.0, + Vectors.fromBreeze(BDV.vertcat( + v._1.toBreeze.toDenseVector, + v._2.toBreeze.toDenseVector)) + ) } + } else { + data.mapPartitions { it => + it.grouped(stackSize).map { seq => + val size = seq.size + val bigVector = new Array[Double](inputSize * size + outputSize * size) + var i = 0 + seq.foreach { case (in, out) => + System.arraycopy(in.toArray, 0, bigVector, i * inputSize, inputSize) + System.arraycopy(out.toArray, 0, bigVector, + inputSize * size + i * outputSize, outputSize) + i += 1 + } + (0.0, Vectors.dense(bigVector)) + } + } + } + stackedData + } + + /** + * Unstack the stacked vectors into matrices for batch operations + * @param data stacked vector + * @return pair of matrices holding input and output data and the real stack size + */ + def unstack(data: Vector): (BDM[Double], BDM[Double], Int) = { + val arrData = data.toArray + val realStackSize = arrData.length / (inputSize + outputSize) + val input = new BDM(inputSize, realStackSize, arrData) + val target = new BDM(outputSize, realStackSize, arrData, inputSize * realStackSize) + (input, target, realStackSize) + } +} + +/** + * Simple updater + */ +private[ann] class ANNUpdater extends Updater { + + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { + val thisIterStepSize = stepSize + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + Baxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + (Vectors.fromBreeze(brzWeights), 0) + } +} + +/** + * MLlib-style trainer class that trains a network given the data and topology + * @param topology topology of ANN + * @param inputSize input size + * @param outputSize output size + */ +private[ml] class FeedForwardTrainer( + topology: Topology, + val inputSize: Int, + val outputSize: Int) extends Serializable { + + // TODO: what if we need to pass random seed? 
+ private var _weights = topology.getInstance(11L).weights() + private var _stackSize = 128 + private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize) + private var _gradient: Gradient = new ANNGradient(topology, dataStacker) + private var _updater: Updater = new ANNUpdater() + private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100) + + /** + * Returns weights + * @return weights + */ + def getWeights: Vector = _weights + + /** + * Sets weights + * @param value weights + * @return trainer + */ + def setWeights(value: Vector): FeedForwardTrainer = { + _weights = value + this + } + + /** + * Sets the stack size + * @param value stack size + * @return trainer + */ + def setStackSize(value: Int): FeedForwardTrainer = { + _stackSize = value + dataStacker = new DataStacker(value, inputSize, outputSize) + this + } + + /** + * Sets the SGD optimizer + * @return SGD optimizer + */ + def SGDOptimizer: GradientDescent = { + val sgd = new GradientDescent(_gradient, _updater) + optimizer = sgd + sgd + } + + /** + * Sets the LBFGS optimizer + * @return LBGS optimizer + */ + def LBFGSOptimizer: LBFGS = { + val lbfgs = new LBFGS(_gradient, _updater) + optimizer = lbfgs + lbfgs + } + + /** + * Sets the updater + * @param value updater + * @return trainer + */ + def setUpdater(value: Updater): FeedForwardTrainer = { + _updater = value + updateUpdater(value) + this + } + + /** + * Sets the gradient + * @param value gradient + * @return trainer + */ + def setGradient(value: Gradient): FeedForwardTrainer = { + _gradient = value + updateGradient(value) + this + } + + private[this] def updateGradient(gradient: Gradient): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setGradient(gradient) + case sgd: GradientDescent => sgd.setGradient(gradient) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + private[this] def updateUpdater(updater: Updater): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setUpdater(updater) + case sgd: GradientDescent => sgd.setUpdater(updater) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + /** + * Trains the ANN + * @param data RDD of input and output vector pairs + * @return model + */ + def train(data: RDD[(Vector, Vector)]): TopologyModel = { + val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights) + topology.getInstance(newWeights) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala new file mode 100644 index 0000000000000..8cd2103d7d5e6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} +import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology} +import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.DataFrame + +/** Params for Multilayer Perceptron. */ +private[ml] trait MultilayerPerceptronParams extends PredictorParams + with HasSeed with HasMaxIter with HasTol { + /** + * Layer sizes including input size and output size. + * @group param + */ + final val layers: IntArrayParam = new IntArrayParam(this, "layers", + "Sizes of layers from input layer to output layer" + + " E.g., Array(780, 100, 10) means 780 inputs, " + + "one hidden layer with 100 neurons and output layer of 10 neurons.", + // TODO: how to check ALSO that all elements are greater than 0? + ParamValidators.arrayLengthGt(1) + ) + + /** @group setParam */ + def setLayers(value: Array[Int]): this.type = set(layers, value) + + /** @group getParam */ + final def getLayers: Array[Int] = $(layers) + + /** + * Block size for stacking input data in matrices to speed up the computation. + * Data is stacked within partitions. If block size is more than remaining data in + * a partition then it is adjusted to the size of this data. + * Recommended size is between 10 and 1000. + * @group expertParam + */ + final val blockSize: IntParam = new IntParam(this, "blockSize", + "Block size for stacking input data in matrices. Data is stacked within partitions." + + " If block size is more than remaining data in a partition then " + + "it is adjusted to the size of this data. Recommended size is between 10 and 1000", + ParamValidators.gt(0)) + + /** @group setParam */ + def setBlockSize(value: Int): this.type = set(blockSize, value) + + /** @group getParam */ + final def getBlockSize: Int = $(blockSize) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-4. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + + /** + * Set the seed for weights initialization. + * @group setParam + */ + def setSeed(value: Long): this.type = set(seed, value) + + setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) +} + +/** Label to vector converter. */ +private object LabelConverter { + // TODO: Use OneHotEncoder instead + /** + * Encodes a label as a vector. + * Returns a vector of given length with zeroes at all positions + * and value 1.0 at the position that corresponds to the label. 
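+ * For example, with labelCount = 3 a point labeled 1.0 is encoded as [0.0, 1.0, 0.0].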
+ * + * @param labeledPoint labeled point + * @param labelCount total number of labels + * @return pair of features and vector encoding of a label + */ + def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { + val output = Array.fill(labelCount)(0.0) + output(labeledPoint.label.toInt) = 1.0 + (labeledPoint.features, Vectors.dense(output)) + } + + /** + * Converts a vector to a label. + * Returns the position of the maximal element of a vector. + * + * @param output label encoded with a vector + * @return label + */ + def decodeLabel(output: Vector): Double = { + output.argmax.toDouble + } +} + +/** + * :: Experimental :: + * Classifier trainer based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * Number of inputs has to be equal to the size of feature vectors. + * Number of outputs has to be equal to the total number of labels. + * + */ +@Experimental +class MultilayerPerceptronClassifier(override val uid: String) + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] + with MultilayerPerceptronParams { + + def this() = this(Identifiable.randomUID("mlpc")) + + override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) + + /** + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @return Fitted model + */ + override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { + val myLayers = $(layers) + val labels = myLayers.last + val lpData = extractLabeledPoints(dataset) + val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) + val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) + val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) + FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) + FeedForwardTrainer.setStackSize($(blockSize)) + val mlpModel = FeedForwardTrainer.train(data) + new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) + } +} + +/** + * :: Experimental :: + * Classifier model based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * @param uid uid + * @param layers array of layer sizes including input and output layers + * @param weights vector of initial weights for the model that consists of the weights of layers + * @return prediction model + */ +@Experimental +class MultilayerPerceptronClassifierModel private[ml] ( + override val uid: String, + layers: Array[Int], + weights: Vector) + extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] + with Serializable { + + private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + + /** + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
+ */ + override protected def predict(features: Vector): Double = { + LabelConverter.decodeLabel(mlpModel.predict(features)) + } + + override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { + copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 954aa17e26a02..d68f5ff0053c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -166,6 +166,11 @@ object ParamValidators { def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) => allowed.contains(value) } + + /** Check that the array length is greater than lowerBound. */ + def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => + value.length > lowerBound + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index ab7611fd077ef..8f0d1e4aa010a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * @param gradient Gradient function to be used. * @param updater Updater to be used to update weights after every iteration. */ -class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater) +class GradientDescent private[spark] (private var gradient: Gradient, private var updater: Updater) extends Optimizer with Logging { private var stepSize: Double = 1.0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala new file mode 100644 index 0000000000000..1292e57d7c01a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.ann + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + + +class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { + + // TODO: test for weights comparison with Weka MLP + test("ANN with Sigmoid learns XOR function with LBFGS optimizer") { + val inputs = Array( + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) + ) + val outputs = Array(0.0, 1.0, 1.0, 0.0) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array(5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 1) + trainer.setWeights(initialWeights) + trainer.LBFGSOptimizer.setNumIterations(20) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input)(0), label(0)) + }.collect() + predictionAndLabels.foreach { case (p, l) => + assert(math.round(p) === l) + } + } + + test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") { + val inputs = Array( + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) + ) + val outputs = Array( + Array(1.0, 0.0), + Array(0.0, 1.0), + Array(0.0, 1.0), + Array(1.0, 0.0) + ) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array(5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 2) + trainer.SGDOptimizer.setNumIterations(2000) + trainer.setWeights(initialWeights) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input), label) + }.collect() + predictionAndLabels.foreach { case (p, l) => + assert(p ~== l absTol 0.5) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala new file mode 100644 index 0000000000000..ddc948f65df45 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row + +class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("XOR function learning as binary classification problem with two outputs.") { + val dataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0)) + ).toDF("features", "label") + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(100) + val model = trainer.fit(dataFrame) + val result = model.transform(dataFrame) + val predictionAndLabels = result.select("prediction", "label").collect() + predictionAndLabels.foreach { case Row(p: Double, l: Double) => + assert(p == l) + } + } + + // TODO: implement a more rigorous test + test("3 class classification with 2 hidden layers") { + val nPoints = 1000 + + // The following weights are taken from OneVsRestSuite.scala + // they represent 3-class iris dataset + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val rdd = sc.parallelize(generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 2) + val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") + val numClasses = 3 + val numIterations = 100 + val layers = Array[Int](4, 5, 4, numClasses) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(numIterations) + val model = trainer.fit(dataFrame) + val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") + .map { case Row(p: Double, l: Double) => (p, l) } + // train multinomial logistic regression + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(true) + .setNumClasses(numClasses) + lr.optimizer.setRegParam(0.0) + .setNumIterations(numIterations) + val lrModel = lr.run(rdd) + val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + // MLP's predictions should not differ a lot from LR's. 
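+ // Note: with 1000 generated points, absTol 100 allows each confusion-matrix cell to
+ // differ by at most 10% of the data between the two models.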
+ val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) + val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) + assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100) + } +} From 4011a947154d97a9ffb5a71f077481a12534d36b Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 31 Jul 2015 11:50:15 -0700 Subject: [PATCH 035/340] [SPARK-9231] [MLLIB] DistributedLDAModel method for top topics per document jira: https://issues.apache.org/jira/browse/SPARK-9231 Helper method in DistributedLDAModel of this form: ``` /** * For each document, return the top k weighted topics for that document. * return RDD of (doc ID, topic indices, topic weights) */ def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] ``` Author: Yuhao Yang Closes #7785 from hhbyyh/topTopicsPerdoc and squashes the following commits: 30ad153 [Yuhao Yang] small fix fd24580 [Yuhao Yang] add topTopics per document to DistributedLDAModel --- .../spark/mllib/clustering/LDAModel.scala | 19 ++++++++++++++++++- .../spark/mllib/clustering/LDASuite.scala | 13 ++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6cfad3fbbdb87..82281a0daf008 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -591,6 +591,23 @@ class DistributedLDAModel private[clustering] ( JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) } + /** + * For each document, return the top k weighted topics for that document and their weights. + * @return RDD of (doc ID, topic indices, topic weights) + */ + def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { + graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => + val topIndices = argtopk(topicCounts, k) + val sumCounts = sum(topicCounts) + val weights = if (sumCounts != 0) { + topicCounts(topIndices) / sumCounts + } else { + topicCounts(topIndices) + } + (docID.toLong, topIndices.toArray, weights.toArray) + } + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? 
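As a usage sketch of the new helper (`ldaModel` is a hypothetical, already-trained DistributedLDAModel):

```
val top3 = ldaModel.topTopicsPerDocument(3)
top3.take(5).foreach { case (docId, topicIndices, topicWeights) =>
  println(s"doc $docId -> topics [${topicIndices.mkString(", ")}], " +
    s"weights [${topicWeights.mkString(", ")}]")
}
```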
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index c43e1e575c09c..695ee3b82efc5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, max, argmax} +import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.Edge @@ -108,6 +108,17 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5) } + val top2TopicsPerDoc = model.topTopicsPerDocument(2).map(t => (t._1, (t._2, t._3))) + model.topicDistributions.join(top2TopicsPerDoc).collect().foreach { + case (docId, (topicDistribution, (indices, weights))) => + assert(indices.length == 2) + assert(weights.length == 2) + val bdvTopicDist = topicDistribution.toBreeze + val top2Indices = argtopk(bdvTopicDist, 2) + assert(top2Indices.toArray === indices) + assert(bdvTopicDist(top2Indices).toArray === weights) + } + // Check: log probabilities assert(model.logLikelihood < 0.0) assert(model.logPrior < 0.0) From e8bdcdeabb2df139a656f86686cdb53c891b1f4b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 31 Jul 2015 11:56:52 -0700 Subject: [PATCH 036/340] [SPARK-6885] [ML] decision tree support predict class probabilities Decision tree support predict class probabilities. Implement the prediction probabilities function referred the old DecisionTree API and the [sklean API](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L593). I make the DecisionTreeClassificationModel inherit from ProbabilisticClassificationModel, make the predictRaw to return the raw counts vector and make raw2probabilityInPlace/predictProbability return the probabilities for each prediction. 
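For illustration: if the leaf reached by a sample holds training class counts of 3 and 7, predictRaw returns [3.0, 7.0], raw2probabilityInPlace rescales it to [0.3, 0.7], and the predicted label is 1.0. A hypothetical end-to-end use (trainingDF and testDF are assumed DataFrames with "label" and "features" columns; the output column names are the ML pipeline defaults):

```
val dt = new DecisionTreeClassifier().setMaxDepth(5)
val model = dt.fit(trainingDF)
model.transform(testDF)
  .select("features", "rawPrediction", "probability", "prediction")
  .show()
```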
Author: Yanbo Liang Closes #7694 from yanboliang/spark-6885 and squashes the following commits: 08d5b7f [Yanbo Liang] fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 issue 2174278 [Yanbo Liang] solve merge conflicts 7e90ba8 [Yanbo Liang] fix typos 33ae183 [Yanbo Liang] fix annotation ff043d3 [Yanbo Liang] raw2probabilityInPlace should operate in-place c32d6ce [Yanbo Liang] optimize calculateImpurityStats function again 6167fb0 [Yanbo Liang] optimize calculateImpurityStats function fbbe2ec [Yanbo Liang] eliminate duplicated struct and code beb1634 [Yanbo Liang] try to eliminate impurityStats for each LearningNode 99e8943 [Yanbo Liang] code optimization 5ec3323 [Yanbo Liang] implement InformationGainAndImpurityStats 227c91b [Yanbo Liang] refactor LearningNode to store ImpurityCalculator d746ffc [Yanbo Liang] decision tree support predict class probabilities --- .../DecisionTreeClassifier.scala | 40 ++++-- .../ml/classification/GBTClassifier.scala | 2 +- .../RandomForestClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../scala/org/apache/spark/ml/tree/Node.scala | 80 ++++++----- .../spark/ml/tree/impl/RandomForest.scala | 126 ++++++++---------- .../spark/mllib/tree/impurity/Entropy.scala | 2 +- .../spark/mllib/tree/impurity/Gini.scala | 2 +- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 2 +- .../tree/model/InformationGainStats.scala | 61 ++++++++- .../DecisionTreeClassifierSuite.scala | 30 ++++- .../classification/GBTClassifierSuite.scala | 2 +- .../RandomForestClassifierSuite.scala | 2 +- 16 files changed, 229 insertions(+), 130 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 36fe1bd40469c..f27cfd0331419 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -18,12 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} @@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame */ @Experimental final class DecisionTreeClassifier(override val uid: String) - extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] + extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("dtc")) @@ -106,8 +105,9 @@ object DecisionTreeClassifier { @Experimental final class DecisionTreeClassificationModel private[ml] ( override val uid: String, - override val 
rootNode: Node) - extends PredictionModel[Vector, DecisionTreeClassificationModel] + override val rootNode: Node, + override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { require(rootNode != null, @@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode) + def this(rootNode: Node, numClasses: Int) = + this(Identifiable.randomUID("dtc"), rootNode, numClasses) override protected def predict(features: Vector): Double = { - rootNode.predict(features) + rootNode.predictImpl(features).prediction + } + + override protected def predictRaw(features: Vector): Vector = { + Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone()) + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val sum = dv.values.sum + while (i < size) { + dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0 + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } } override def copy(extra: ParamMap): DecisionTreeClassificationModel = { - copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra) + copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) } override def toString: String = { @@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel { s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") - new DecisionTreeClassificationModel(uid, rootNode) + new DecisionTreeClassificationModel(uid, rootNode, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index eb0b1a0a405fc..c3891a9599262 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -190,7 +190,7 @@ final class GBTClassificationModel( override protected def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) if (prediction > 0.0) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bc19bd6df894f..0c7eb4a662fdb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] ( // Ignore the weights since all are 1.0 for now. 
val votes = new Array[Double](numClasses) _trees.view.foreach { tree => - val prediction = tree.rootNode.predict(features).toInt + val prediction = tree.rootNode.predictImpl(features).prediction.toInt votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } Vectors.dense(votes) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 6f3340c2f02be..4d30e4b5548aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] ( def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) override protected def predict(features: Vector): Double = { - rootNode.predict(features) + rootNode.predictImpl(features).prediction } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index e38dc73ee0ba7..5633bc320273a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -180,7 +180,7 @@ final class GBTRegressionModel( override protected def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 506a878c2553b..17fb1ad5e15d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] ( // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. 
- _trees.map(_.rootNode.predict(features)).sum / numTrees + _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees } override def copy(extra: ParamMap): RandomForestRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index bbc2427ca7d3d..8879352a600a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -19,8 +19,9 @@ package org.apache.spark.ml.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, - Node => OldNode, Predict => OldPredict} + Node => OldNode, Predict => OldPredict, ImpurityStats} /** * :: DeveloperApi :: @@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable { /** Impurity measure at this node (for training data) */ def impurity: Double + /** + * Statistics aggregated from training data at this node, used to compute prediction, impurity, + * and probabilities. + * For classification, the array of class counts must be normalized to a probability distribution. + */ + private[tree] def impurityStats: ImpurityCalculator + /** Recursive prediction helper method */ - private[ml] def predict(features: Vector): Double = prediction + private[ml] def predictImpl(features: Vector): LeafNode /** * Get the number of nodes in tree below this node, including leaf nodes. @@ -75,7 +83,8 @@ private[ml] object Node { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity) + new LeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain @@ -85,7 +94,7 @@ private[ml] object Node { new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), - split = Split.fromOld(oldNode.split.get, categoricalFeatures)) + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) } } } @@ -99,11 +108,13 @@ private[ml] object Node { @DeveloperApi final class LeafNode private[ml] ( override val prediction: Double, - override val impurity: Double) extends Node { + override val impurity: Double, + override val impurityStats: ImpurityCalculator) extends Node { - override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" + override def toString: String = + s"LeafNode(prediction = $prediction, impurity = $impurity)" - override private[ml] def predict(features: Vector): Double = prediction + override private[ml] def predictImpl(features: Vector): LeafNode = this override private[tree] def numDescendants: Int = 0 @@ -115,9 +126,8 @@ final class LeafNode private[ml] ( override private[tree] def subtreeDepth: Int = 0 override private[ml] def toOld(id: Int): OldNode = { - // NOTE: We do NOT store 'prob' in the new API currently. 
- new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true, - None, None, None, None) + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), + impurity, isLeaf = true, None, None, None, None) } } @@ -139,17 +149,18 @@ final class InternalNode private[ml] ( val gain: Double, val leftChild: Node, val rightChild: Node, - val split: Split) extends Node { + val split: Split, + override val impurityStats: ImpurityCalculator) extends Node { override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" } - override private[ml] def predict(features: Vector): Double = { + override private[ml] def predictImpl(features: Vector): LeafNode = { if (split.shouldGoLeft(features)) { - leftChild.predict(features) + leftChild.predictImpl(features) } else { - rightChild.predict(features) + rightChild.predictImpl(features) } } @@ -172,9 +183,8 @@ final class InternalNode private[ml] ( override private[ml] def toOld(id: Int): OldNode = { assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + " since the old API does not support deep trees.") - // NOTE: We do NOT store 'prob' in the new API currently. - new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false, - Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity, + isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), Some(rightChild.toOld(OldNode.rightChildIndex(id))), Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, new OldPredict(leftChild.prediction, prob = 0.0), @@ -223,36 +233,36 @@ private object InternalNode { * * @param id We currently use the same indexing as the old implementation in * [[org.apache.spark.mllib.tree.model.Node]], but this will change later. - * @param predictionStats Predicted label + class probability (for classification). - * We will later modify this to store aggregate statistics for labels - * to provide all class probabilities (for classification) and maybe a - * distribution (for regression). * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree, * so that we do not need to consider splitting it further. - * @param stats Old structure for storing stats about information gain, prediction, etc. - * This is legacy and will be modified in the future. + * @param stats Impurity statistics for this node. */ private[tree] class LearningNode( var id: Int, - var predictionStats: OldPredict, - var impurity: Double, var leftChild: Option[LearningNode], var rightChild: Option[LearningNode], var split: Option[Split], var isLeaf: Boolean, - var stats: Option[OldInformationGainStats]) extends Serializable { + var stats: ImpurityStats) extends Serializable { /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ def toNode: Node = { if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty, + assert(rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. 
Could not convert LearningNode to Node.") - new InternalNode(predictionStats.predict, impurity, stats.get.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get) + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) } else { - new LeafNode(predictionStats.predict, impurity) + if (stats.valid) { + new LeafNode(stats.impurityCalculator.predict, stats.impurity, + stats.impurityCalculator) + } else { + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + } + } } @@ -263,16 +273,14 @@ private[tree] object LearningNode { /** Create a node with some of its fields set. */ def apply( id: Int, - predictionStats: OldPredict, - impurity: Double, - isLeaf: Boolean): LearningNode = { - new LearningNode(id, predictionStats, impurity, None, None, None, false, None) + isLeaf: Boolean, + stats: ImpurityStats): LearningNode = { + new LearningNode(id, None, None, None, false, stats) } /** Create an empty node with the given node index. Values must be set later on. */ def emptyNode(nodeIndex: Int): LearningNode = { - new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN, - None, None, None, false, None) + new LearningNode(nodeIndex, None, None, None, false, null) } // The below indexing methods were copied from spark.mllib.tree.model.Node diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 15b56bd844bad..a8b90d9d266a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -31,7 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict} +import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} @@ -180,13 +180,17 @@ private[ml] object RandomForest extends Logging { parentUID match { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { - topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses) + } } else { topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode)) } case None => if (strategy.algo == OldAlgo.Classification) { - topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses) + } } else { topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode)) } @@ -549,9 +553,9 @@ private[ml] object RandomForest extends Logging { } // find best split for each node - val (split: Split, stats: InformationGainStats, predict: Predict) = + val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) - (nodeIndex, (split, stats, predict)) + (nodeIndex, (split, stats)) }.collectAsMap() 
timer.stop("chooseSplits") @@ -568,17 +572,15 @@ private[ml] object RandomForest extends Logging { val nodeIndex = node.id val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) val aggNodeIndex = nodeInfo.nodeIndexInGroup - val (split: Split, stats: InformationGainStats, predict: Predict) = + val (split: Split, stats: ImpurityStats) = nodeToBestSplits(aggNodeIndex) logDebug("best split = " + split) // Extract info for this node. Create children if not leaf. val isLeaf = (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) - node.predictionStats = predict node.isLeaf = isLeaf - node.stats = Some(stats) - node.impurity = stats.impurity + node.stats = stats logDebug("Node = " + node) if (!isLeaf) { @@ -587,9 +589,9 @@ private[ml] object RandomForest extends Logging { val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), - stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), - stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) if (nodeIdCache.nonEmpty) { val nodeIndexUpdater = NodeIndexUpdater( @@ -621,28 +623,44 @@ private[ml] object RandomForest extends Logging { } /** - * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * Calculate the impurity statistics for a give (feature, split) based upon left/right aggregates. + * @param stats the recycle impurity statistics for this feature's all splits, + * only 'impurity' and 'impurityCalculator' are valid between each iteration * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for split + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) */ - private def calculateGainForSplit( + private def calculateImpurityStats( + stats: ImpurityStats, leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata, - impurity: Double): InformationGainStats = { + metadata: DecisionTreeMetadata): ImpurityStats = { + + val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { + leftImpurityCalculator.copy.add(rightImpurityCalculator) + } else { + stats.impurityCalculator + } + + val impurity: Double = if (stats == null) { + parentImpurityCalculator.calculate() + } else { + stats.impurity + } + val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count + val totalCount = leftCount + rightCount + // If left child or right child doesn't satisfy minimum instances per node, // then this split is invalid, return invalid information gain stats. 
if ((leftCount < metadata.minInstancesPerNode) || (rightCount < metadata.minInstancesPerNode)) { - return InformationGainStats.invalidInformationGainStats + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - val totalCount = leftCount + rightCount - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -654,39 +672,11 @@ private[ml] object RandomForest extends Logging { // if information gain doesn't satisfy minimum information gain, // then this split is invalid, return invalid information gain stats. if (gain < metadata.minInfoGain) { - return InformationGainStats.invalidInformationGainStats + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - // calculate left and right predict - val leftPredict = calculatePredict(leftImpurityCalculator) - val rightPredict = calculatePredict(rightImpurityCalculator) - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, - leftPredict, rightPredict) - } - - private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { - val predict = impurityCalculator.predict - val prob = impurityCalculator.prob(predict) - new Predict(predict, prob) - } - - /** - * Calculate predict value for current node, given stats of any split. - * Note that this function is called only once for each node. - * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a split - * @return predict value and impurity for current node - */ - private def calculatePredictImpurity( - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - val predict = calculatePredict(parentNodeAgg) - val impurity = parentNodeAgg.calculate() - - (predict, impurity) + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) } /** @@ -698,14 +688,14 @@ private[ml] object RandomForest extends Logging { binAggregates: DTStatsAggregator, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]], - node: LearningNode): (Split, InformationGainStats, Predict) = { + node: LearningNode): (Split, ImpurityStats) = { - // Calculate prediction and impurity if current node is top node + // Calculate InformationGain and ImpurityStats if current node is top node val level = LearningNode.indexToLevel(node.id) - var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) { - None + var gainAndImpurityStats: ImpurityStats = if (level ==0) { + null } else { - Some((node.predictionStats, node.impurity)) + node.stats } // For each (feature, split), calculate the gain, and select the best (feature, split). 
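      // For a rough sense of scale (made-up statistics): a node holding 100 rows with parent
      // Gini impurity 0.48, whose candidate split sends 40 rows left (Gini 0.20) and 60 rows
      // right (Gini 0.30), gets gain = 0.48 - 0.4 * 0.20 - 0.6 * 0.30 = 0.22. The split only
      // survives if that gain reaches metadata.minInfoGain and both children hold at least
      // metadata.minInstancesPerNode rows (both checks live in calculateImpurityStats above).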
@@ -734,11 +724,9 @@ private[ml] object RandomForest extends Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIdx, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIdx, gainAndImpurityStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { @@ -750,11 +738,9 @@ private[ml] object RandomForest extends Logging { val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIndex, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { @@ -825,11 +811,9 @@ private[ml] object RandomForest extends Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIndex, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) val categoriesForSplit = categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) @@ -839,7 +823,7 @@ private[ml] object RandomForest extends Logging { } }.maxBy(_._2.gain) - (bestSplit, bestSplitStats, predictionAndImpurity.get._1) + (bestSplit, bestSplitStats) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 5ac10f3fd32dd..0768204c33914 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. 
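Since each learned Node now carries one of these calculators as its impurityStats, the values they expose end up driving prediction, raw scores and class probabilities. A minimal sketch with made-up class counts, using the Gini variant that receives the same visibility change just below:

    val calc = new GiniCalculator(Array(2.0, 6.0, 2.0))  // per-class counts seen at a node
    calc.calculate()   // Gini impurity: 1 - (0.2^2 + 0.6^2 + 0.2^2) = 0.56
    calc.predict       // majority class: 1.0
    calc.prob(1.0)     // class probability, now surfaced through OldPredict: 0.6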
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 19d318203c344..d0077db6832e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 578749d85a4e6..86cee7e430b0a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) { +private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 7104a7fa4dd4c..04d0cd24e6632 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -98,7 +98,7 @@ private[tree] class VarianceAggregator() * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). 
*/ -private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { require(stats.size == 3, s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index dc9e0f9f51ffb..508bf9c1bdb47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator /** * :: DeveloperApi :: @@ -66,7 +67,6 @@ class InformationGainStats( } } - private[spark] object InformationGainStats { /** * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to @@ -76,3 +76,62 @@ private[spark] object InformationGainStats { val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } + +/** + * :: DeveloperApi :: + * Impurity statistics for each split + * @param gain information gain value + * @param impurity current node impurity + * @param impurityCalculator impurity statistics for current node + * @param leftImpurityCalculator impurity statistics for left child node + * @param rightImpurityCalculator impurity statistics for right child node + * @param valid whether the current split satisfies minimum info gain or + * minimum number of instances per node + */ +@DeveloperApi +private[spark] class ImpurityStats( + val gain: Double, + val impurity: Double, + val impurityCalculator: ImpurityCalculator, + val leftImpurityCalculator: ImpurityCalculator, + val rightImpurityCalculator: ImpurityCalculator, + val valid: Boolean = true) extends Serializable { + + override def toString: String = { + s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " + + s"right impurity = $rightImpurity" + } + + def leftImpurity: Double = if (leftImpurityCalculator != null) { + leftImpurityCalculator.calculate() + } else { + -1.0 + } + + def rightImpurity: Double = if (rightImpurityCalculator != null) { + rightImpurityCalculator.calculate() + } else { + -1.0 + } +} + +private[spark] object ImpurityStats { + + /** + * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to + * denote that current split doesn't satisfies minimum info gain or + * minimum number of instances per node. + */ + def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.MinValue, impurityCalculator.calculate(), + impurityCalculator, null, null, false) + } + + /** + * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object + * that only 'impurity' and 'impurityCalculator' are defined. 
+ */ + def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 73b4805c4c597..c7bbf1ce07a23 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2) ParamsSuite.checkParams(model) } @@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) } + test("predictRaw and predictProbability") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + val predictions = newTree.transform(newData) + .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) + .collect() + + predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index a7bc77965fefd..d4b5896c12c06 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new 
DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))), + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))), Array(1.0)) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ab711c8e4b215..dbb2577c6204d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2) ParamsSuite.checkParams(model) } From 0a1d2ca42c8b31d6b0e70163795f0185d4622f87 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Fri, 31 Jul 2015 12:04:03 -0700 Subject: [PATCH 037/340] [SPARK-8979] Add a PID based rate estimator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on #7600 /cc tdas Author: Iulian Dragos Author: François Garillot Closes #7648 from dragos/topic/streaming-bp/pid and squashes the following commits: aa5b097 [Iulian Dragos] Add more comments, made all PID constant parameters positive, a couple more tests. 93b74f8 [Iulian Dragos] Better explanation of historicalError. 7975b0c [Iulian Dragos] Add configuration for PID. 26cfd78 [Iulian Dragos] A couple of variable renames. d0bdf7c [Iulian Dragos] Update to latest version of the code, various style and name improvements. 
d58b845 [François Garillot] [SPARK-8979][Streaming] Implements a PIDRateEstimator --- .../dstream/ReceiverInputDStream.scala | 2 +- .../scheduler/rate/PIDRateEstimator.scala | 124 ++++++++++++++++ .../scheduler/rate/RateEstimator.scala | 18 ++- .../rate/PIDRateEstimatorSuite.scala | 137 ++++++++++++++++++ 4 files changed, 276 insertions(+), 5 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 646a8c3530a62..670ef8d296a0b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -46,7 +46,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont */ override protected[streaming] val rateController: Option[RateController] = { if (RateController.isBackPressureEnabled(ssc.conf)) { - RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) } + Some(new ReceiverRateController(id, RateEstimator.create(ssc.conf, ssc.graph.batchDuration))) } else { None } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala new file mode 100644 index 0000000000000..6ae56a68ad88c --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +/** + * Implements a proportional-integral-derivative (PID) controller which acts on + * the speed of ingestion of elements into Spark Streaming. A PID controller works + * by calculating an '''error''' between a measured output and a desired value. In the + * case of Spark Streaming the error is the difference between the measured processing + * rate (number of elements/processing delay) and the previous rate. + * + * @see https://en.wikipedia.org/wiki/PID_controller + * + * @param batchDurationMillis the batch duration, in milliseconds + * @param proportional how much the correction should depend on the current + * error. This term usually provides the bulk of correction and should be positive or zero. + * A value too large would make the controller overshoot the setpoint, while a small value + * would make the controller too insensitive. The default value is 1. 
+ * @param integral how much the correction should depend on the accumulation + * of past errors. This value should be positive or 0. This term accelerates the movement + * towards the desired value, but a large value may lead to overshooting. The default value + * is 0.2. + * @param derivative how much the correction should depend on a prediction + * of future errors, based on current rate of change. This value should be positive or 0. + * This term is not used very often, as it impacts stability of the system. The default + * value is 0. + */ +private[streaming] class PIDRateEstimator( + batchIntervalMillis: Long, + proportional: Double = 1D, + integral: Double = .2D, + derivative: Double = 0D) + extends RateEstimator { + + private var firstRun: Boolean = true + private var latestTime: Long = -1L + private var latestRate: Double = -1D + private var latestError: Double = -1L + + require( + batchIntervalMillis > 0, + s"Specified batch interval $batchIntervalMillis in PIDRateEstimator is invalid.") + require( + proportional >= 0, + s"Proportional term $proportional in PIDRateEstimator should be >= 0.") + require( + integral >= 0, + s"Integral term $integral in PIDRateEstimator should be >= 0.") + require( + derivative >= 0, + s"Derivative term $derivative in PIDRateEstimator should be >= 0.") + + + def compute(time: Long, // in milliseconds + numElements: Long, + processingDelay: Long, // in milliseconds + schedulingDelay: Long // in milliseconds + ): Option[Double] = { + + this.synchronized { + if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) { + + // in seconds, should be close to batchDuration + val delaySinceUpdate = (time - latestTime).toDouble / 1000 + + // in elements/second + val processingRate = numElements.toDouble / processingDelay * 1000 + + // In our system `error` is the difference between the desired rate and the measured rate + // based on the latest batch information. We consider the desired rate to be latest rate, + // which is what this estimator calculated for the previous batch. + // in elements/second + val error = latestRate - processingRate + + // The error integral, based on schedulingDelay as an indicator for accumulated errors. + // A scheduling delay s corresponds to s * processingRate overflowing elements. Those + // are elements that couldn't be processed in previous batches, leading to this delay. + // In the following, we assume the processingRate didn't change too much. + // From the number of overflowing elements we can calculate the rate at which they would be + // processed by dividing it by the batch interval. This rate is our "historical" error, + // or integral part, since if we subtracted this rate from the previous "calculated rate", + // there wouldn't have been any overflowing elements, and the scheduling delay would have + // been zero. 
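+      // Rough illustration with made-up numbers: for a 100 ms batch interval, if the
+      // previous rate was 1000 elements/s and the last batch was processed at only
+      // 800 elements/s while sitting behind 200 ms of scheduling delay, then
+      // error = 1000 - 800 = 200, historicalError = 200 * 800 / 100 = 1600, and with the
+      // default gains (proportional = 1, integral = 0.2, derivative = 0) the new rate
+      // comes out as 1000 - 1 * 200 - 0.2 * 1600 = 480 elements/s.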
+ // (in elements/second) + val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis + + // in elements/(second ^ 2) + val dError = (error - latestError) / delaySinceUpdate + + val newRate = (latestRate - proportional * error - + integral * historicalError - + derivative * dError).max(0.0) + latestTime = time + if (firstRun) { + latestRate = processingRate + latestError = 0D + firstRun = false + + None + } else { + latestRate = newRate + latestError = error + + Some(newRate) + } + } else None + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index a08685119e5d5..17ccebc1ed41b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.scheduler.rate import org.apache.spark.SparkConf import org.apache.spark.SparkException +import org.apache.spark.streaming.Duration /** * A component that estimates the rate at wich an InputDStream should ingest @@ -48,12 +49,21 @@ object RateEstimator { /** * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`. * - * @return None if there is no configured estimator, otherwise an instance of RateEstimator + * The only known estimator right now is `pid`. + * + * @return An instance of RateEstimator * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any * known estimators. */ - def create(conf: SparkConf): Option[RateEstimator] = - conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator => - throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") + def create(conf: SparkConf, batchInterval: Duration): RateEstimator = + conf.get("spark.streaming.backpressure.rateEstimator", "pid") match { + case "pid" => + val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0) + val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2) + val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0) + new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived) + + case estimator => + throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala new file mode 100644 index 0000000000000..97c32d8f2d59e --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import scala.util.Random + +import org.scalatest.Inspectors.forAll +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.streaming.Seconds + +class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { + + test("the right estimator is created") { + val conf = new SparkConf + conf.set("spark.streaming.backpressure.rateEstimator", "pid") + val pid = RateEstimator.create(conf, Seconds(1)) + pid.getClass should equal(classOf[PIDRateEstimator]) + } + + test("estimator checks ranges") { + intercept[IllegalArgumentException] { + new PIDRateEstimator(0, 1, 2, 3) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, -1, 2, 3) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, -1, 3) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, -1) + } + } + + private def createDefaultEstimator: PIDRateEstimator = { + new PIDRateEstimator(20, 1D, 0D, 0D) + } + + test("first bound is None") { + val p = createDefaultEstimator + p.compute(0, 10, 10, 0) should equal(None) + } + + test("second bound is rate") { + val p = createDefaultEstimator + p.compute(0, 10, 10, 0) + // 1000 elements / s + p.compute(10, 10, 10, 0) should equal(Some(1000)) + } + + test("works even with no time between updates") { + val p = createDefaultEstimator + p.compute(0, 10, 10, 0) + p.compute(10, 10, 10, 0) + p.compute(10, 10, 10, 0) should equal(None) + } + + test("bound is never negative") { + val p = new PIDRateEstimator(20, 1D, 1D, 0D) + // prepare a series of batch updates, one every 20ms, 0 processed elements, 2ms of processing + // this might point the estimator to try and decrease the bound, but we test it never + // goes below zero, which would be nonsensical. + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.fill(50)(0) // no processing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(100) // strictly positive accumulation + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.fill(49)(Some(0D))) + } + + test("with no accumulated or positive error, |I| > 0, follow the processing speed") { + val p = new PIDRateEstimator(20, 1D, 1D, 0D) + // prepare a series of batch updates, one every 20ms with an increasing number of processed + // elements in each batch, but constant processing time, and no accumulated error. 
Even though + // the integral part is non-zero, the estimated rate should follow only the proportional term + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.tabulate(50)(x => x * 20) // increasing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(0) + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail) + } + + test("with no accumulated but some positive error, |I| > 0, follow the processing speed") { + val p = new PIDRateEstimator(20, 1D, 1D, 0D) + // prepare a series of batch updates, one every 20ms with an decreasing number of processed + // elements in each batch, but constant processing time, and no accumulated error. Even though + // the integral part is non-zero, the estimated rate should follow only the proportional term, + // asking for less and less elements + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.tabulate(50)(x => (50 - x) * 20) // decreasing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(0) + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.tabulate(50)(x => Some((50 - x) * 1000D)).tail) + } + + test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") { + val p = new PIDRateEstimator(20, 1D, .01D, 0D) + val times = List.tabulate(50)(x => x * 20) // every 20ms + val rng = new Random() + val elements = List.tabulate(50)(x => rng.nextInt(1000)) + val procDelayMs = 20 + val proc = List.fill(50)(procDelayMs) // 20ms of processing + val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait + val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000) + + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + forAll(List.range(1, 50)) { (n) => + res(n) should not be None + if (res(n).get > 0 && sched(n) > 0) { + res(n).get should be < speeds(n) + } + } + } +} From 39ab199a3f735b7658ab3331d3e2fb03441aec13 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 31 Jul 2015 12:07:18 -0700 Subject: [PATCH 038/340] [SPARK-8640] [SQL] Enable Processing of Multiple Window Frames in a Single Window Operator This PR enables the processing of multiple window frames in a single window operator. This should improve the performance of processing multiple window expressions wich share partition by/order by clauses, because it will be more efficient with respect to memory use and group processing. Author: Herman van Hovell Closes #7515 from hvanhovell/SPARK-8640 and squashes the following commits: f0e1c21 [Herman van Hovell] Changed Window Logical/Physical plans to use partition by/order by specs directly instead of using WindowSpec. e1711c2 [Herman van Hovell] Enabled the processing of multiple window frames in a single Window operator. 
--- .../sql/catalyst/analysis/Analyzer.scala | 12 +++++++----- .../plans/logical/basicOperators.scala | 3 ++- .../spark/sql/execution/SparkStrategies.scala | 5 +++-- .../apache/spark/sql/execution/Window.scala | 19 ++++++++++--------- .../sql/hive/execution/HivePlanTest.scala | 18 ++++++++++++++++++ 5 files changed, 40 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 265f3d1e41765..51d910b258647 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -347,7 +347,7 @@ class Analyzer( val newOutput = oldVersion.generatorOutput.map(_.newInstance()) (oldVersion, oldVersion.copy(generatorOutput = newOutput)) - case oldVersion @ Window(_, windowExpressions, _, child) + case oldVersion @ Window(_, windowExpressions, _, _, child) if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) .nonEmpty => (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) @@ -825,7 +825,7 @@ class Analyzer( }.asInstanceOf[NamedExpression] } - // Second, we group extractedWindowExprBuffer based on their Window Spec. + // Second, we group extractedWindowExprBuffer based on their Partition and Order Specs. val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr => val distinctWindowSpec = expr.collect { case window: WindowExpression => window.windowSpec @@ -841,7 +841,8 @@ class Analyzer( failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." + s"Please file a bug report with this error message, stack trace, and the query.") } else { - distinctWindowSpec.head + val spec = distinctWindowSpec.head + (spec.partitionSpec, spec.orderSpec) } }.toSeq @@ -850,9 +851,10 @@ class Analyzer( var currentChild = child var i = 0 while (i < groupedWindowExpressions.size) { - val (windowSpec, windowExpressions) = groupedWindowExpressions(i) + val ((partitionSpec, orderSpec), windowExpressions) = groupedWindowExpressions(i) // Set currentChild to the newly created Window operator. - currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild) + currentChild = Window(currentChild.output, windowExpressions, + partitionSpec, orderSpec, currentChild) // Move to next Window Spec. 
i += 1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a67f8de6b733a..aacfc86ab0e49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -228,7 +228,8 @@ case class Aggregate( case class Window( projectList: Seq[Attribute], windowExpressions: Seq[NamedExpression], - windowSpec: WindowSpecDefinition, + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 03d24a88d4ecd..4aff52d992e6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -389,8 +389,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } - case logical.Window(projectList, windowExpressions, spec, child) => - execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil + case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + execution.Window( + projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 91c8a02e2b5bc..fe9f2c7028171 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -80,23 +80,24 @@ import scala.collection.mutable case class Window( projectList: Seq[Attribute], windowExpression: Seq[NamedExpression], - windowSpec: WindowSpecDefinition, + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = { - if (windowSpec.partitionSpec.isEmpty) { + if (partitionSpec.isEmpty) { // Only show warning when the number of bytes is larger than 100 MB? logWarning("No Partition Defined for Window operation! Moving all data to a single " + "partition, this can cause serious performance degradation.") AllTuples :: Nil - } else ClusteredDistribution(windowSpec.partitionSpec) :: Nil + } else ClusteredDistribution(partitionSpec) :: Nil } override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(windowSpec.partitionSpec.map(SortOrder(_, Ascending)) ++ windowSpec.orderSpec) + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -115,12 +116,12 @@ case class Window( case RangeFrame => val (exprs, current, bound) = if (offset == 0) { // Use the entire order expression when the offset is 0. 
- val exprs = windowSpec.orderSpec.map(_.child) + val exprs = orderSpec.map(_.child) val projection = newMutableProjection(exprs, child.output) - (windowSpec.orderSpec, projection(), projection()) - } else if (windowSpec.orderSpec.size == 1) { + (orderSpec, projection(), projection()) + } else if (orderSpec.size == 1) { // Use only the first order expression when the offset is non-null. - val sortExpr = windowSpec.orderSpec.head + val sortExpr = orderSpec.head val expr = sortExpr.child // Create the projection which returns the current 'value'. val current = newMutableProjection(expr :: Nil, child.output)() @@ -250,7 +251,7 @@ case class Window( // Get all relevant projections. val result = createResultProjection(unboundExpressions) - val grouping = newProjection(windowSpec.partitionSpec, child.output) + val grouping = newProjection(partitionSpec, child.output) // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index bdb53ddf59c19..ba56a8a6b689c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.TestHive class HivePlanTest extends QueryTest { @@ -31,4 +34,19 @@ class HivePlanTest extends QueryTest { comparePlans(optimized, correctAnswer) } + + test("window expressions sharing the same partition by and order by clause") { + val df = Seq.empty[(Int, String, Int, Int)].toDF("id", "grp", "seq", "val") + val window = Window. + partitionBy($"grp"). + orderBy($"val") + val query = df.select( + $"id", + sum($"val").over(window.rowsBetween(-1, 1)), + sum($"val").over(window.rangeBetween(-1, 1)) + ) + val plan = query.queryExecution.analyzed + assert(plan.collect{ case w: logical.Window => w }.size === 1, + "Should have only 1 Window operator.") + } } From 3afc1de89cb4de9f8ea74003dd1e6b5b006d06f0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 31 Jul 2015 12:09:48 -0700 Subject: [PATCH 039/340] [SPARK-8564] [STREAMING] Add the Python API for Kinesis This PR adds the Python API for Kinesis, including a Python example and a simple unit test. 
Author: zsxwing Closes #6955 from zsxwing/kinesis-python and squashes the following commits: e42e471 [zsxwing] Merge branch 'master' into kinesis-python 455f7ea [zsxwing] Remove streaming_kinesis_asl_assembly module and simply add the source folder to streaming_kinesis_asl module 32e6451 [zsxwing] Merge remote-tracking branch 'origin/master' into kinesis-python 5082d28 [zsxwing] Fix the syntax error for Python 2.6 fca416b [zsxwing] Fix wrong comparison 96670ff [zsxwing] Fix the compilation error after merging master 756a128 [zsxwing] Merge branch 'master' into kinesis-python 6c37395 [zsxwing] Print stack trace for debug 7c5cfb0 [zsxwing] RUN_KINESIS_TESTS -> ENABLE_KINESIS_TESTS cc9d071 [zsxwing] Fix the python test errors 466b425 [zsxwing] Add python tests for Kinesis e33d505 [zsxwing] Merge remote-tracking branch 'origin/master' into kinesis-python 3da2601 [zsxwing] Fix the kinesis folder 687446b [zsxwing] Fix the error message and the maven output path add2beb [zsxwing] Merge branch 'master' into kinesis-python 4957c0b [zsxwing] Add the Python API for Kinesis --- dev/run-tests.py | 3 +- dev/sparktestsupport/modules.py | 9 +- docs/streaming-kinesis-integration.md | 19 +++ extras/kinesis-asl-assembly/pom.xml | 103 ++++++++++++++++ .../streaming/kinesis_wordcount_asl.py | 81 +++++++++++++ .../streaming/kinesis/KinesisTestUtils.scala | 19 ++- .../streaming/kinesis/KinesisUtils.scala | 78 +++++++++--- pom.xml | 1 + project/SparkBuild.scala | 6 +- python/pyspark/streaming/kinesis.py | 112 ++++++++++++++++++ python/pyspark/streaming/tests.py | 86 +++++++++++++- 11 files changed, 492 insertions(+), 25 deletions(-) create mode 100644 extras/kinesis-asl-assembly/pom.xml create mode 100644 extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py create mode 100644 python/pyspark/streaming/kinesis.py diff --git a/dev/run-tests.py b/dev/run-tests.py index 29420da9aa956..b6d181418f027 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -301,7 +301,8 @@ def build_spark_sbt(hadoop_version): sbt_goals = ["package", "assembly/assembly", "streaming-kafka-assembly/assembly", - "streaming-flume-assembly/assembly"] + "streaming-flume-assembly/assembly", + "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 44600cb9523c1..956dc81b62e93 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -138,6 +138,7 @@ def contains_file(self, filename): dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", + "extras/kinesis-asl-assembly/", ], build_profile_flags=[ "-Pkinesis-asl", @@ -300,7 +301,13 @@ def contains_file(self, filename): pyspark_streaming = Module( name="pyspark-streaming", - dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly], + dependencies=[ + pyspark_core, + streaming, + streaming_kafka, + streaming_flume_assembly, + streaming_kinesis_asl + ], source_file_regexes=[ "python/pyspark/streaming" ], diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index aa9749afbc867..a7bcaec6fcd84 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -51,6 +51,17 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) and the 
[example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. + +
+
+        from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream
+
+        kinesisStream = KinesisUtils.createStream(
+            streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL],
+            [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2)
+
+    See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kinesis.KinesisUtils)
+    and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the next subsection for instructions to run the example.
+
@@ -135,6 +146,14 @@ To run the example, bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + +
+
+
+        bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\
+          spark-streaming-kinesis-asl-assembly_*.jar \
+          extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \
+          [Kinesis app name] [Kinesis stream name] [endpoint URL] [region name]
+
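For illustration, a filled-in version of the template above might look like the following; every value here is a placeholder:

    bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\
      spark-streaming-kinesis-asl-assembly_*.jar \
      extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \
      myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com us-east-1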
diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml new file mode 100644 index 0000000000000..70d2c9c58f54e --- /dev/null +++ b/extras/kinesis-asl-assembly/pom.xml @@ -0,0 +1,103 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kinesis-asl-assembly_2.10 + jar + Spark Project Kinesis Assembly + http://spark.apache.org/ + + + streaming-kinesis-asl-assembly + + + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kinesis-asl-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py new file mode 100644 index 0000000000000..f428f64da3c42 --- /dev/null +++ b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Consumes messages from a Amazon Kinesis streams and does wordcount. + + This example spins up 1 Kinesis Receiver per shard for the given stream. + It then starts pulling from the last checkpointed sequence number of the given stream. + + Usage: kinesis_wordcount_asl.py + is the name of the consumer app, used to track the read data in DynamoDB + name of the Kinesis stream (ie. mySparkStream) + endpoint of the Kinesis service + (e.g. https://kinesis.us-east-1.amazonaws.com) + + + Example: + # export AWS keys if necessary + $ export AWS_ACCESS_KEY_ID= + $ export AWS_SECRET_KEY= + + # run the example + $ bin/spark-submit -jar extras/kinesis-asl/target/scala-*/\ + spark-streaming-kinesis-asl-assembly_*.jar \ + extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com + + There is a companion helper class called KinesisWordProducerASL which puts dummy data + onto the Kinesis stream. 
+ + This code uses the DefaultAWSCredentialsProviderChain to find credentials + in the following order: + Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + Java System Properties - aws.accessKeyId and aws.secretKey + Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + Instance profile credentials - delivered through the Amazon EC2 metadata service + For more information, see + http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + + See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + the Kinesis Spark Streaming integration. +""" +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream + +if __name__ == "__main__": + if len(sys.argv) != 5: + print( + "Usage: kinesis_wordcount_asl.py ", + file=sys.stderr) + sys.exit(-1) + + sc = SparkContext(appName="PythonStreamingKinesisWordCountAsl") + ssc = StreamingContext(sc, 1) + appName, streamName, endpointUrl, regionName = sys.argv[1:] + lines = KinesisUtils.createStream( + ssc, appName, streamName, endpointUrl, regionName, InitialPositionInStream.LATEST, 2) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index ca39358b75cb6..255ac27f793ba 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -36,9 +36,15 @@ import org.apache.spark.Logging /** * Shared utility methods for performing Kinesis tests that actually transfer data */ -private class KinesisTestUtils( - val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com", - _regionName: String = "") extends Logging { +private class KinesisTestUtils(val endpointUrl: String, _regionName: String) extends Logging { + + def this() { + this("https://kinesis.us-west-2.amazonaws.com", "") + } + + def this(endpointUrl: String) { + this(endpointUrl, "") + } val regionName = if (_regionName.length == 0) { RegionUtils.getRegionByEndpoint(endpointUrl).getName() @@ -117,6 +123,13 @@ private class KinesisTestUtils( shardIdToSeqNumbers.toMap } + /** + * Expose a Python friendly API. 
+ */ + def pushData(testData: java.util.List[Int]): Unit = { + pushData(scala.collection.JavaConversions.asScalaBuffer(testData)) + } + def deleteStream(): Unit = { try { if (streamCreated) { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index e5acab50181e1..7dab17eba8483 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -86,19 +86,19 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) */ def createStream( ssc: StreamingContext, @@ -130,7 +130,7 @@ object KinesisUtils { * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in * [[org.apache.spark.SparkConf]]. * - * @param ssc Java StreamingContext object + * @param ssc StreamingContext object * @param streamName Kinesis stream name * @param endpointUrl Endpoint url of Kinesis service * (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -175,15 +175,15 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. 
+ * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ @@ -206,8 +206,8 @@ object KinesisUtils { * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library @@ -216,19 +216,19 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) */ def createStream( jssc: JavaStreamingContext, @@ -297,3 +297,49 @@ object KinesisUtils { } } } + +/** + * This is a helper class that wraps the methods in KinesisUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's KinesisUtils. + */ +private class KinesisUtilsPythonHelper { + + def getInitialPositionInStream(initialPositionInStream: Int): InitialPositionInStream = { + initialPositionInStream match { + case 0 => InitialPositionInStream.LATEST + case 1 => InitialPositionInStream.TRIM_HORIZON + case _ => throw new IllegalArgumentException( + "Illegal InitialPositionInStream. 
Please use " + + "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON") + } + } + + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: Int, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + if (awsAccessKeyId == null && awsSecretKey != null) { + throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") + } + if (awsAccessKeyId != null && awsSecretKey == null) { + throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") + } + if (awsAccessKeyId == null && awsSecretKey == null) { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) + } else { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + awsAccessKeyId, awsSecretKey) + } + } + +} diff --git a/pom.xml b/pom.xml index 35fc8c44bc1b0..e351c7c19df96 100644 --- a/pom.xml +++ b/pom.xml @@ -1642,6 +1642,7 @@ kinesis-asl extras/kinesis-asl + extras/kinesis-asl-assembly diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 61a05d375d99e..9a33baa7c6ce1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -382,7 +382,7 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py new file mode 100644 index 0000000000000..bcfe2703fecf9 --- /dev/null +++ b/python/pyspark/streaming/kinesis.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import Py4JJavaError + +from pyspark.serializers import PairDeserializer, NoOpSerializer +from pyspark.storagelevel import StorageLevel +from pyspark.streaming import DStream + +__all__ = ['KinesisUtils', 'InitialPositionInStream', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + return s and s.decode('utf-8') + + +class KinesisUtils(object): + + @staticmethod + def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, + awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder): + """ + Create an input stream that pulls messages from a Kinesis stream. This uses the + Kinesis Client Library (KCL) to pull messages from Kinesis. + + Note: The given AWS credentials will get saved in DStream checkpoints if checkpointing is + enabled. Make sure that your checkpoint directory is secure. + + :param ssc: StreamingContext object + :param kinesisAppName: Kinesis application name used by the Kinesis Client Library (KCL) to + update DynamoDB + :param streamName: Kinesis stream name + :param endpointUrl: Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + :param regionName: Name of region used by the Kinesis Client Library (KCL) to update + DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + :param initialPositionInStream: In the absence of Kinesis checkpoint info, this is the + worker's initial starting position in the stream. The + values are either the beginning of the stream per Kinesis' + limit of 24 hours (InitialPositionInStream.TRIM_HORIZON) or + the tip of the stream (InitialPositionInStream.LATEST). + :param checkpointInterval: Checkpoint interval for Kinesis checkpointing. See the Kinesis + Spark Streaming documentation for more details on the different + types of checkpoints. + :param storageLevel: Storage level to use for storing the received objects (default is + StorageLevel.MEMORY_AND_DISK_2) + :param awsAccessKeyId: AWS AccessKeyId (default is None. If None, will use + DefaultAWSCredentialsProviderChain) + :param awsSecretKey: AWS SecretKey (default is None. 
If None, will use + DefaultAWSCredentialsProviderChain) + :param decoder: A function used to decode value (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + jduration = ssc._jduration(checkpointInterval) + + try: + # Use KinesisUtilsPythonHelper to access Scala's KinesisUtils + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, + regionName, initialPositionInStream, jduration, jlevel, + awsAccessKeyId, awsSecretKey) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KinesisUtils._printErrorMsg(ssc.sparkContext) + raise e + stream = DStream(jstream, ssc, NoOpSerializer()) + return stream.map(lambda v: decoder(v)) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Kinesis libraries not found in class path. Try one of the following. + + 1. Include the Kinesis library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-kinesis-asl:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-kinesis-asl-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (sc.version, sc.version)) + + +class InitialPositionInStream(object): + LATEST, TRIM_HORIZON = (0, 1) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 4ecae1e4bf282..5cd544b2144ef 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -36,9 +36,11 @@ import unittest from pyspark.context import SparkConf, SparkContext, RDD +from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition from pyspark.streaming.flume import FlumeUtils +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream class PySparkStreamingTestCase(unittest.TestCase): @@ -891,6 +893,67 @@ def test_flume_polling_multiple_hosts(self): self._testMultipleTimes(self._testFlumePollingMultipleHosts) +class KinesisStreamTests(PySparkStreamingTestCase): + + def test_kinesis_stream_api(self): + # Don't start the StreamingContext because we cannot test it in Jenkins + kinesisStream1 = KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2) + kinesisStream2 = KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + + def test_kinesis_stream(self): + if os.environ.get('ENABLE_KINESIS_TESTS') != '1': + print("Skip test_kinesis_stream") + return + + import random + kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) + kinesisTestUtilsClz = \ + 
self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kinesis.KinesisTestUtils") + kinesisTestUtils = kinesisTestUtilsClz.newInstance() + try: + kinesisTestUtils.createStream() + aWSCredentials = kinesisTestUtils.getAWSCredentials() + stream = KinesisUtils.createStream( + self.ssc, kinesisAppName, kinesisTestUtils.streamName(), + kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(), + InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY, + aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey()) + + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + stream.foreachRDD(get_output) + self.ssc.start() + + testData = [i for i in range(1, 11)] + expectedOutput = set([str(i) for i in testData]) + start_time = time.time() + while time.time() - start_time < 120: + kinesisTestUtils.pushData(testData) + if expectedOutput == set(outputBuffer): + break + time.sleep(10) + self.assertEqual(expectedOutput, set(outputBuffer)) + except: + import traceback + traceback.print_exc() + raise + finally: + kinesisTestUtils.deleteStream() + kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + + def search_kafka_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") @@ -926,10 +989,31 @@ def search_flume_assembly_jar(): else: return jars[0] + +def search_kinesis_asl_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "extras/kinesis-asl-assembly") + jars = glob.glob( + os.path.join(kinesis_asl_assembly_dir, + "target/scala-*/spark-streaming-kinesis-asl-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming Kinesis ASL assembly jar in %s. " % + kinesis_asl_assembly_dir) + "You need to build Spark with " + "'build/sbt -Pkinesis-asl assembly/assembly streaming-kinesis-asl-assembly/assembly' " + "or 'build/mvn -Pkinesis-asl package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs in %s; please " + "remove all but one") % kinesis_asl_assembly_dir) + else: + return jars[0] + + if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() - jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar) + kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() + jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars unittest.main() From d04634701413410938a133358fe1d9fbc077645e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 31 Jul 2015 12:10:55 -0700 Subject: [PATCH 040/340] [SPARK-9504] [STREAMING] [TESTS] Use eventually to fix the flaky test The previous code uses `ssc.awaitTerminationOrTimeout(500)`. Since nobody will stop it during `awaitTerminationOrTimeout`, it's just like `sleep(500)`. In a super overloaded Jenkins worker, the receiver may be not able to start in 500 milliseconds. Verified this in the log of https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/39149/ There is no log about starting the receiver before this failure. That's why `assert(runningCount > 0)` failed. This PR replaces `awaitTerminationOrTimeout` with `eventually` which should be more reliable. 
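For readers unfamiliar with it, ScalaTest's `eventually` retries the enclosed assertion on an interval until it passes or a timeout expires, instead of asserting once after a fixed sleep. A self-contained sketch of the pattern, with illustrative names rather than this suite's code:

    import org.scalatest.concurrent.Eventually._
    import org.scalatest.time.SpanSugar._

    object EventuallySketch {
      @volatile var runningCount = 0

      def main(args: Array[String]): Unit = {
        // Stand-in for the receiver: another thread bumps the counter a little later.
        new Thread(new Runnable {
          def run(): Unit = { Thread.sleep(50); runningCount += 1 }
        }).start()

        // Poll the assertion every 10 ms; fail only if it still does not hold after 10 seconds.
        eventually(timeout(10.seconds), interval(10.millis)) {
          assert(runningCount > 0)
        }
      }
    }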
Author: zsxwing Closes #7823 from zsxwing/SPARK-9504 and squashes the following commits: 7af66a6 [zsxwing] Remove wrong assertion 5ba2c99 [zsxwing] Use eventually to fix the flaky test --- .../apache/spark/streaming/StreamingContextSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 84a5fbb3d95eb..b7db280f63588 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -261,7 +261,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo for (i <- 1 to 4) { logInfo("==================================\n\n\n") ssc = new StreamingContext(sc, Milliseconds(100)) - var runningCount = 0 + @volatile var runningCount = 0 TestReceiver.counter.set(1) val input = ssc.receiverStream(new TestReceiver) input.count().foreachRDD { rdd => @@ -270,14 +270,14 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo logInfo("Count = " + count + ", Running count = " + runningCount) } ssc.start() - ssc.awaitTerminationOrTimeout(500) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(runningCount > 0) + } ssc.stop(stopSparkContext = false, stopGracefully = true) logInfo("Running count = " + runningCount) logInfo("TestReceiver.counter = " + TestReceiver.counter.get()) - assert(runningCount > 0) assert( - (TestReceiver.counter.get() == runningCount + 1) || - (TestReceiver.counter.get() == runningCount + 2), + TestReceiver.counter.get() == runningCount + 1, "Received records = " + TestReceiver.counter.get() + ", " + "processed records = " + runningCount ) From a8340fa7df17e3f0a3658f8b8045ab840845a72a Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 31 Jul 2015 12:12:22 -0700 Subject: [PATCH 041/340] [SPARK-9481] Add logLikelihood to LocalLDAModel jkbradley Exposes `bound` (variational log likelihood bound) through public API as `logLikelihood`. Also adds unit tests, some DRYing of `LDASuite`, and includes unit tests mentioned in #7760 Author: Feynman Liang Closes #7801 from feynmanliang/SPARK-9481-logLikelihood and squashes the following commits: 6d1b2c9 [Feynman Liang] Negate perplexity definition 5f62b20 [Feynman Liang] Add logLikelihood --- .../spark/mllib/clustering/LDAModel.scala | 20 ++- .../spark/mllib/clustering/LDASuite.scala | 129 +++++++++--------- 2 files changed, 78 insertions(+), 71 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 82281a0daf008..ff7035d2246c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -217,22 +217,28 @@ class LocalLDAModel private[clustering] ( LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, gammaShape) } - // TODO - // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + // TODO: declare in LDAModel and override once implemented in DistributedLDAModel + /** + * Calculates a lower bound on the log likelihood of the entire corpus. 
+ * @param documents test corpus to use for calculating log likelihood + * @return variational lower bound on the log likelihood of the entire corpus + */ + def logLikelihood(documents: RDD[(Long, Vector)]): Double = bound(documents, + docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, + vocabSize) /** - * Calculate the log variational bound on perplexity. See Equation (16) in original Online + * Calculate an upper bound bound on perplexity. See Equation (16) in original Online * LDA paper. * @param documents test corpus to use for calculating perplexity - * @return the log perplexity per word + * @return variational upper bound on log perplexity per word */ def logPerplexity(documents: RDD[(Long, Vector)]): Double = { val corpusWords = documents .map { case (_, termCounts) => termCounts.toArray.sum } .sum() - val batchVariationalBound = bound(documents, docConcentration, - topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) - val perWordBound = batchVariationalBound / corpusWords + val perWordBound = -logLikelihood(documents) / corpusWords perWordBound } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 695ee3b82efc5..79d2a1cafd1fa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -210,16 +210,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } test("OnlineLDAOptimizer with toy data") { - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - - val docs = sc.parallelize(toydata) + val docs = sc.parallelize(toyData) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) .setGammaShape(1e10) val lda = new LDA().setK(2) @@ -242,30 +233,45 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } - test("LocalLDAModel logPerplexity") { - val k = 2 - val vocabSize = 6 - val alpha = 0.01 - val eta = 0.01 - val gammaShape = 100 - // obtained from LDA model trained in gensim, see below - val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( - 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, - 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + test("LocalLDAModel logLikelihood") { + val ldaModel: LocalLDAModel = toyModel - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - val docs = sc.parallelize(toydata) + val docsSingleWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(1))) + .zipWithIndex + .map { case (wordCounts, docId) => (docId.toLong, wordCounts) }) + val docsRepeatedWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(5))) + .zipWithIndex + .map { 
case (wordCounts, docId) => (docId.toLong, wordCounts) }) + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + docsSingleWord = [[(0, 1.0)]] + docsRepeatedWord = [[(0, 5.0)]] + print(lda.bound(docsSingleWord)) + > -25.9706969833 + print(lda.bound(docsRepeatedWord)) + > -31.4413908227 + */ - val ldaModel: LocalLDAModel = new LocalLDAModel( - topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + assert(ldaModel.logLikelihood(docsSingleWord) ~== -25.971 relTol 1E-3D) + assert(ldaModel.logLikelihood(docsRepeatedWord) ~== -31.441 relTol 1E-3D) + } + + test("LocalLDAModel logPerplexity") { + val docs = sc.parallelize(toyData) + val ldaModel: LocalLDAModel = toyModel /* Verify results using gensim: import numpy as np @@ -285,32 +291,13 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { > -3.69051285096 */ - assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D) + // Gensim's definition of perplexity is negative our (and Stanford NLP's) definition + assert(ldaModel.logPerplexity(docs) ~== 3.690D relTol 1E-3D) } test("LocalLDAModel predict") { - val k = 2 - val vocabSize = 6 - val alpha = 0.01 - val eta = 0.01 - val gammaShape = 100 - // obtained from LDA model trained in gensim, see below - val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( - 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, - 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) - - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - val docs = sc.parallelize(toydata) - - val ldaModel: LocalLDAModel = new LocalLDAModel( - topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + val docs = sc.parallelize(toyData) + val ldaModel: LocalLDAModel = toyModel /* Verify results using gensim: import numpy as np @@ -351,16 +338,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } test("OnlineLDAOptimizer with asymmetric prior") { - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - - val docs = sc.parallelize(toydata) + val docs = sc.parallelize(toyData) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) .setGammaShape(1e10) val lda = new LDA().setK(2) @@ -531,4 +509,27 @@ private[clustering] object LDASuite { def getNonEmptyDoc(corpus: Array[(Long, Vector)]): Array[(Long, Vector)] = corpus.filter { case (_, wc: Vector) => Vectors.norm(wc, p = 1.0) != 0.0 } + + def toyData: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 
1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + + def toyModel: LocalLDAModel = { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + ldaModel + } } From c0686668ae6a92b6bb4801a55c3b78aedbee816a Mon Sep 17 00:00:00 2001 From: CodingCat Date: Fri, 31 Jul 2015 20:27:00 +0100 Subject: [PATCH 042/340] [SPARK-9202] capping maximum number of executor&driver information kept in Worker https://issues.apache.org/jira/browse/SPARK-9202 Author: CodingCat Closes #7714 from CodingCat/SPARK-9202 and squashes the following commits: 23977fb [CodingCat] add comments about why we don't synchronize finishedExecutors & finishedDrivers dc9772d [CodingCat] addressing the comments e125241 [CodingCat] stylistic fix 80bfe52 [CodingCat] fix JsonProtocolSuite d7d9485 [CodingCat] styistic fix and respect insert ordering 031755f [CodingCat] add license info & stylistic fix c3b5361 [CodingCat] test cases and docs c557b3a [CodingCat] applications are fine 9cac751 [CodingCat] application is fine... ad87ed7 [CodingCat] trimFinishedExecutorsAndDrivers --- .../apache/spark/deploy/worker/Worker.scala | 124 ++++++++++------ .../spark/deploy/worker/ui/WorkerWebUI.scala | 4 +- .../apache/spark/deploy/DeployTestUtils.scala | 89 ++++++++++++ .../spark/deploy/JsonProtocolSuite.scala | 59 ++------ .../spark/deploy/worker/WorkerSuite.scala | 133 +++++++++++++++++- docs/configuration.md | 14 ++ 6 files changed, 329 insertions(+), 94 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 82e9578bbcba5..0276c24f85368 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -25,7 +25,7 @@ import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, HashSet} +import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext import scala.util.Random import scala.util.control.NonFatal @@ -115,13 +115,18 @@ private[worker] class Worker( } var workDir: File = null - val finishedExecutors = new HashMap[String, ExecutorRunner] + val finishedExecutors = new LinkedHashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] val executors = new HashMap[String, ExecutorRunner] - val finishedDrivers = new HashMap[String, DriverRunner] + val finishedDrivers = new LinkedHashMap[String, DriverRunner] val appDirectories = new HashMap[String, Seq[String]] val finishedApps = new HashSet[String] + val retainedExecutors = conf.getInt("spark.worker.ui.retainedExecutors", + WorkerWebUI.DEFAULT_RETAINED_EXECUTORS) + val 
retainedDrivers = conf.getInt("spark.worker.ui.retainedDrivers", + WorkerWebUI.DEFAULT_RETAINED_DRIVERS) + // The shuffle service is not actually started unless configured. private val shuffleService = new ExternalShuffleService(conf, securityMgr) @@ -461,25 +466,7 @@ private[worker] class Worker( } case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => - sendToMaster(executorStateChanged) - val fullId = appId + "/" + execId - if (ExecutorState.isFinished(state)) { - executors.get(fullId) match { - case Some(executor) => - logInfo("Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - executors -= fullId - finishedExecutors(fullId) = executor - coresUsed -= executor.cores - memoryUsed -= executor.memory - case None => - logInfo("Unknown Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - } - maybeCleanupApplication(appId) - } + handleExecutorStateChanged(executorStateChanged) case KillExecutor(masterUrl, appId, execId) => if (masterUrl != activeMasterUrl) { @@ -523,24 +510,8 @@ private[worker] class Worker( } } - case driverStageChanged @ DriverStateChanged(driverId, state, exception) => { - state match { - case DriverState.ERROR => - logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") - case DriverState.FAILED => - logWarning(s"Driver $driverId exited with failure") - case DriverState.FINISHED => - logInfo(s"Driver $driverId exited successfully") - case DriverState.KILLED => - logInfo(s"Driver $driverId was killed by user") - case _ => - logDebug(s"Driver $driverId changed state to $state") - } - sendToMaster(driverStageChanged) - val driver = drivers.remove(driverId).get - finishedDrivers(driverId) = driver - memoryUsed -= driver.driverDesc.mem - coresUsed -= driver.driverDesc.cores + case driverStateChanged @ DriverStateChanged(driverId, state, exception) => { + handleDriverStateChanged(driverStateChanged) } case ReregisterWithMaster => @@ -614,6 +585,78 @@ private[worker] class Worker( webUi.stop() metricsSystem.stop() } + + private def trimFinishedExecutorsIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedExecutors.size > retainedExecutors) { + finishedExecutors.take(math.max(finishedExecutors.size / 10, 1)).foreach { + case (executorId, _) => finishedExecutors.remove(executorId) + } + } + } + + private def trimFinishedDriversIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedDrivers.size > retainedDrivers) { + finishedDrivers.take(math.max(finishedDrivers.size / 10, 1)).foreach { + case (driverId, _) => finishedDrivers.remove(driverId) + } + } + } + + private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { + val driverId = driverStateChanged.driverId + val exception = driverStateChanged.exception + val state = driverStateChanged.state + state match { + case DriverState.ERROR => + logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") + case DriverState.FAILED => + logWarning(s"Driver $driverId exited with failure") + case DriverState.FINISHED => + logInfo(s"Driver $driverId exited successfully") + case 
DriverState.KILLED => + logInfo(s"Driver $driverId was killed by user") + case _ => + logDebug(s"Driver $driverId changed state to $state") + } + sendToMaster(driverStateChanged) + val driver = drivers.remove(driverId).get + finishedDrivers(driverId) = driver + trimFinishedDriversIfNecessary() + memoryUsed -= driver.driverDesc.mem + coresUsed -= driver.driverDesc.cores + } + + private[worker] def handleExecutorStateChanged(executorStateChanged: ExecutorStateChanged): + Unit = { + sendToMaster(executorStateChanged) + val state = executorStateChanged.state + if (ExecutorState.isFinished(state)) { + val appId = executorStateChanged.appId + val fullId = appId + "/" + executorStateChanged.execId + val message = executorStateChanged.message + val exitStatus = executorStateChanged.exitStatus + executors.get(fullId) match { + case Some(executor) => + logInfo("Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + executors -= fullId + finishedExecutors(fullId) = executor + trimFinishedExecutorsIfNecessary() + coresUsed -= executor.cores + memoryUsed -= executor.memory + case None => + logInfo("Unknown Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + } + maybeCleanupApplication(appId) + } + } } private[deploy] object Worker extends Logging { @@ -669,5 +712,4 @@ private[deploy] object Worker extends Logging { cmd } } - } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 334a5b10142aa..709a27233598c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -53,6 +53,8 @@ class WorkerWebUI( } } -private[ui] object WorkerWebUI { +private[worker] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR + val DEFAULT_RETAINED_DRIVERS = 1000 + val DEFAULT_RETAINED_EXECUTORS = 1000 } diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala new file mode 100644 index 0000000000000..967aa0976f0ce --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.io.File +import java.util.Date + +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.{SecurityManager, SparkConf} + +private[deploy] object DeployTestUtils { + def createAppDesc(): ApplicationDescription = { + val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) + new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") + } + + def createAppInfo() : ApplicationInfo = { + val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, + "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) + appInfo.endTime = JsonConstants.currTimeInMillis + appInfo + } + + def createDriverCommand(): Command = new Command( + "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), + Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") + ) + + def createDriverDesc(): DriverDescription = + new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) + + def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", + createDriverDesc(), new Date()) + + def createWorkerInfo(): WorkerInfo = { + val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") + workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis + workerInfo + } + + def createExecutorRunner(execId: Int): ExecutorRunner = { + new ExecutorRunner( + "appId", + execId, + createAppDesc(), + 4, + 1234, + null, + "workerId", + "host", + 123, + "publicAddress", + new File("sparkHome"), + new File("workDir"), + "akka://worker", + new SparkConf, + Seq("localDir"), + ExecutorState.RUNNING) + } + + def createDriverRunner(driverId: String): DriverRunner = { + val conf = new SparkConf() + new DriverRunner( + conf, + driverId, + new File("workDir"), + new File("sparkHome"), + createDriverDesc(), + null, + "akka://worker", + new SecurityManager(conf)) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 08529e0ef2806..0a9f128a3a6b6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy -import java.io.File import java.util.Date import com.fasterxml.jackson.core.JsonParseException @@ -25,12 +24,14 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} -import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState} +import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.{JsonTestUtils, SparkFunSuite} class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { + import org.apache.spark.deploy.DeployTestUtils._ + test("writeApplicationInfo") { val output = JsonProtocol.writeApplicationInfo(createAppInfo()) assertValidJson(output) @@ -50,7 +51,7 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { } test("writeExecutorRunner") { - val output = 
JsonProtocol.writeExecutorRunner(createExecutorRunner()) + val output = JsonProtocol.writeExecutorRunner(createExecutorRunner(123)) assertValidJson(output) assertValidDataInJson(output, JsonMethods.parse(JsonConstants.executorRunnerJsonStr)) } @@ -77,9 +78,10 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { test("writeWorkerState") { val executors = List[ExecutorRunner]() - val finishedExecutors = List[ExecutorRunner](createExecutorRunner(), createExecutorRunner()) - val drivers = List(createDriverRunner()) - val finishedDrivers = List(createDriverRunner(), createDriverRunner()) + val finishedExecutors = List[ExecutorRunner](createExecutorRunner(123), + createExecutorRunner(123)) + val drivers = List(createDriverRunner("driverId")) + val finishedDrivers = List(createDriverRunner("driverId"), createDriverRunner("driverId")) val stateResponse = new WorkerStateResponse("host", 8080, "workerId", executors, finishedExecutors, drivers, finishedDrivers, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl") val output = JsonProtocol.writeWorkerState(stateResponse) @@ -87,47 +89,6 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { assertValidDataInJson(output, JsonMethods.parse(JsonConstants.workerStateJsonStr)) } - def createAppDesc(): ApplicationDescription = { - val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) - new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") - } - - def createAppInfo() : ApplicationInfo = { - val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, - "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) - appInfo.endTime = JsonConstants.currTimeInMillis - appInfo - } - - def createDriverCommand(): Command = new Command( - "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), - Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") - ) - - def createDriverDesc(): DriverDescription = - new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) - - def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", - createDriverDesc(), new Date()) - - def createWorkerInfo(): WorkerInfo = { - val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") - workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis - workerInfo - } - - def createExecutorRunner(): ExecutorRunner = { - new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", 123, - "publicAddress", new File("sparkHome"), new File("workDir"), "akka://worker", - new SparkConf, Seq("localDir"), ExecutorState.RUNNING) - } - - def createDriverRunner(): DriverRunner = { - val conf = new SparkConf() - new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), - createDriverDesc(), null, "akka://worker", new SecurityManager(conf)) - } - def assertValidJson(json: JValue) { try { JsonMethods.parse(JsonMethods.compact(json)) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 0f4d3b28d09df..faed4bdc68447 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,13 +17,18 @@ package org.apache.spark.deploy.worker -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.Command - import org.scalatest.Matchers +import 
org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} +import org.apache.spark.deploy.master.DriverState +import org.apache.spark.deploy.{Command, ExecutorState} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} + class WorkerSuite extends SparkFunSuite with Matchers { + import org.apache.spark.deploy.DeployTestUtils._ + def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } @@ -56,4 +61,126 @@ class WorkerSuite extends SparkFunSuite with Matchers { "-Dspark.ssl.useNodeLocalConf=true", "-Dspark.ssl.opt1=y", "-Dspark.ssl.opt2=z") } + + test("test clearing of finishedExecutors (small number of executors)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedExecutors", 2.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 5) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + // initialize ExecutorStateChanged Message + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 1) + assert(worker.executors.size === 4) + for (i <- 1 until 5) { + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", i, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 2) + if (i > 1) { + assert(!worker.finishedExecutors.contains(s"app1/${i - 2}")) + } + assert(worker.executors.size === 4 - i) + } + } + + test("test clearing of finishedExecutors (more executors)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedExecutors", 30.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 50) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + // initialize ExecutorStateChanged Message + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 1) + assert(worker.executors.size === 49) + for (i <- 1 until 50) { + val expectedValue = { + if (worker.finishedExecutors.size < 30) { + worker.finishedExecutors.size + 1 + } else { + 28 + } + } + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", i, ExecutorState.EXITED, None, None)) + if (expectedValue == 28) { + for (j <- i - 30 until i - 27) { + assert(!worker.finishedExecutors.contains(s"app1/$j")) + } + } + assert(worker.executors.size === 49 - i) + assert(worker.finishedExecutors.size === expectedValue) + } + } + + test("test clearing of finishedDrivers (small number of drivers)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedDrivers", 2.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 5) { + val driverId = s"driverId-$i" + 
worker.drivers += driverId -> createDriverRunner(driverId) + } + // initialize DriverStateChanged Message + worker.handleDriverStateChanged(DriverStateChanged("driverId-0", DriverState.FINISHED, None)) + assert(worker.drivers.size === 4) + assert(worker.finishedDrivers.size === 1) + for (i <- 1 until 5) { + val driverId = s"driverId-$i" + worker.handleDriverStateChanged(DriverStateChanged(driverId, DriverState.FINISHED, None)) + if (i > 1) { + assert(!worker.finishedDrivers.contains(s"driverId-${i - 2}")) + } + assert(worker.drivers.size === 4 - i) + assert(worker.finishedDrivers.size === 2) + } + } + + test("test clearing of finishedDrivers (more drivers)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedDrivers", 30.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 50) { + val driverId = s"driverId-$i" + worker.drivers += driverId -> createDriverRunner(driverId) + } + // initialize DriverStateChanged Message + worker.handleDriverStateChanged(DriverStateChanged("driverId-0", DriverState.FINISHED, None)) + assert(worker.finishedDrivers.size === 1) + assert(worker.drivers.size === 49) + for (i <- 1 until 50) { + val expectedValue = { + if (worker.finishedDrivers.size < 30) { + worker.finishedDrivers.size + 1 + } else { + 28 + } + } + val driverId = s"driverId-$i" + worker.handleDriverStateChanged(DriverStateChanged(driverId, DriverState.FINISHED, None)) + if (expectedValue == 28) { + for (j <- i - 30 until i - 27) { + assert(!worker.finishedDrivers.contains(s"driverId-$j")) + } + } + assert(worker.drivers.size === 49 - i) + assert(worker.finishedDrivers.size === expectedValue) + } + } } diff --git a/docs/configuration.md b/docs/configuration.md index fd236137cb96e..24b606356a149 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -557,6 +557,20 @@ Apart from these, the following properties are also available, and may be useful collecting. + + spark.worker.ui.retainedExecutors + 1000 + + How many finished executors the Spark UI and status APIs remember before garbage collecting. + + + + spark.worker.ui.retainedDrivers + 1000 + + How many finished drivers the Spark UI and status APIs remember before garbage collecting. + + #### Compression and Serialization From e91df52bc811e1ca63f9a9b4beff773ec2c83566 Mon Sep 17 00:00:00 2001 From: mcheah Date: Fri, 31 Jul 2015 12:53:36 -0700 Subject: [PATCH 043/340] Clean up ExternalLists more eagerly. Uses a Cleaner thread similar to ContextCleaner that can also be running on the executors. Uses WeakReference to determine if a list can be cleaned up or not. ExternalList objects register themselves for cleanup upon construction or deserialization. 
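A rough, self-contained sketch of the weak-reference cleanup pattern this message describes (the names FileCleaner, CleanupTask, and TaskWeakReference are illustrative placeholders, not the classes added by this patch):

import java.lang.ref.{ReferenceQueue, WeakReference}
import scala.collection.mutable.ArrayBuffer

object FileCleaner {
  // What to do once the tracked object has been garbage collected.
  final case class CleanupTask(paths: Seq[String])

  // Pairs a weak reference to the tracked object with its cleanup task. When the referent
  // becomes only weakly reachable, the JVM enqueues this reference on the supplied queue.
  private class TaskWeakReference(val task: CleanupTask, referent: AnyRef,
      queue: ReferenceQueue[AnyRef]) extends WeakReference[AnyRef](referent, queue)

  private val referenceQueue = new ReferenceQueue[AnyRef]
  // Keep strong references to the wrappers themselves so they survive until processed.
  private val referenceBuffer = new ArrayBuffer[TaskWeakReference]
  @volatile private var stopped = false

  /** Register an object; its cleanup task runs after the object is garbage collected. */
  def register(trackedObject: AnyRef, task: CleanupTask): Unit = synchronized {
    referenceBuffer += new TaskWeakReference(task, trackedObject, referenceQueue)
  }

  private val cleaningThread = new Thread("file-cleaner") {
    setDaemon(true)
    override def run(): Unit = {
      while (!stopped) {
        // Poll with a timeout so that stop() can eventually terminate the loop.
        val ref = Option(referenceQueue.remove(100)).map(_.asInstanceOf[TaskWeakReference])
        ref.foreach { r =>
          FileCleaner.synchronized { referenceBuffer -= r }
          // Stand-in for deleting the spill files backing the collected object.
          r.task.paths.foreach(p => println(s"would delete $p"))
        }
      }
    }
  }

  def start(): Unit = cleaningThread.start()
  def stop(): Unit = { stopped = true; cleaningThread.join() }
}

In this sketch, an on-disk collection would call FileCleaner.register(this, CleanupTask(itsSpillFilePaths)) when it is constructed or deserialized, mirroring how the patch has ExternalList instances register with the executor-side cleaner.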
--- .../org/apache/spark/ContextCleaner.scala | 114 ++++-------------- .../org/apache/spark/ExecutorCleaner.scala | 49 ++++++++ .../scala/org/apache/spark/SparkEnv.scala | 5 + .../apache/spark/WeakReferenceCleaner.scala | 91 ++++++++++++++ .../spark/serializer/KryoSerializer.scala | 2 +- .../spark/util/cleanup/CleanupTasks.scala | 42 +++++++ .../spark/util/collection/ExternalList.scala | 80 ++++++++++-- .../util/collection/SpillableCollection.scala | 3 +- .../util/collection/ExternalListSuite.scala | 95 ++++++++++++--- 9 files changed, 355 insertions(+), 126 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/ExecutorCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 37198d887b07b..d72d2c17592c3 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -17,35 +17,13 @@ package org.apache.spark -import java.lang.ref.{ReferenceQueue, WeakReference} +import org.apache.spark.util.Utils import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDDCheckpointData, RDD} -import org.apache.spark.util.Utils - -/** - * Classes that represent cleaning tasks. - */ -private sealed trait CleanupTask -private case class CleanRDD(rddId: Int) extends CleanupTask -private case class CleanShuffle(shuffleId: Int) extends CleanupTask -private case class CleanBroadcast(broadcastId: Long) extends CleanupTask -private case class CleanAccum(accId: Long) extends CleanupTask -private case class CleanCheckpoint(rddId: Int) extends CleanupTask - -/** - * A WeakReference associated with a CleanupTask. - * - * When the referent object becomes only weakly reachable, the corresponding - * CleanupTaskWeakReference is automatically added to the given reference queue. - */ -private class CleanupTaskWeakReference( - val task: CleanupTask, - referent: AnyRef, - referenceQueue: ReferenceQueue[AnyRef]) - extends WeakReference(referent, referenceQueue) +import org.apache.spark.util.cleanup._ /** * An asynchronous cleaner for RDD, shuffle, and broadcast state. @@ -54,18 +32,11 @@ private class CleanupTaskWeakReference( * to be processed when the associated object goes out of scope of the application. Actual * cleanup is performed in a separate daemon thread. */ -private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - - private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] - with SynchronizedBuffer[CleanupTaskWeakReference] - - private val referenceQueue = new ReferenceQueue[AnyRef] +private[spark] class ContextCleaner(sc: SparkContext) extends WeakReferenceCleaner { private val listeners = new ArrayBuffer[CleanerListener] with SynchronizedBuffer[CleanerListener] - private val cleaningThread = new Thread() { override def run() { keepCleaning() }} - /** * Whether the cleaning thread will block on cleanup tasks (other than shuffle, which * is controlled by the `spark.cleaner.referenceTracking.blocking.shuffle` parameter). 
@@ -92,35 +63,11 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val blockOnShuffleCleanupTasks = sc.conf.getBoolean( "spark.cleaner.referenceTracking.blocking.shuffle", false) - @volatile private var stopped = false - /** Attach a listener object to get information of when objects are cleaned. */ def attachListener(listener: CleanerListener): Unit = { listeners += listener } - /** Start the cleaner. */ - def start(): Unit = { - cleaningThread.setDaemon(true) - cleaningThread.setName("Spark Context Cleaner") - cleaningThread.start() - } - - /** - * Stop the cleaning thread and wait until the thread has finished running its current task. - */ - def stop(): Unit = { - stopped = true - // Interrupt the cleaning thread, but wait until the current task has finished before - // doing so. This guards against the race condition where a cleaning thread may - // potentially clean similarly named variables created by a different SparkContext, - // resulting in otherwise inexplicable block-not-found exceptions (SPARK-6132). - synchronized { - cleaningThread.interrupt() - } - cleaningThread.join() - } - /** Register a RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) @@ -145,43 +92,30 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { registerForCleanup(rdd, CleanCheckpoint(parentId)) } - /** Register an object for cleanup. */ - private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { - referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) + /** Keep cleaning RDD, shuffle, and broadcast state. */ + override protected def keepCleaning(): Unit = Utils.tryOrStopSparkContext(sc) { + super.keepCleaning() } - /** Keep cleaning RDD, shuffle, and broadcast state. 
*/ - private def keepCleaning(): Unit = Utils.tryOrStopSparkContext(sc) { - while (!stopped) { - try { - val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT)) - .map(_.asInstanceOf[CleanupTaskWeakReference]) - // Synchronize here to avoid being interrupted on stop() - synchronized { - reference.map(_.task).foreach { task => - logDebug("Got cleaning task " + task) - referenceBuffer -= reference.get - task match { - case CleanRDD(rddId) => - doCleanupRDD(rddId, blocking = blockOnCleanupTasks) - case CleanShuffle(shuffleId) => - doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) - case CleanBroadcast(broadcastId) => - doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) - case CleanAccum(accId) => - doCleanupAccum(accId, blocking = blockOnCleanupTasks) - case CleanCheckpoint(rddId) => - doCleanCheckpoint(rddId) - } - } - } - } catch { - case ie: InterruptedException if stopped => // ignore - case e: Exception => logError("Error in cleaning thread", e) - } + protected def handleCleanupForSpecificTask(task: CleanupTask): Unit = { + task match { + case CleanRDD(rddId) => + doCleanupRDD(rddId, blocking = blockOnCleanupTasks) + case CleanShuffle(shuffleId) => + doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) + case CleanBroadcast(broadcastId) => + doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + case CleanAccum(accId) => + doCleanupAccum(accId, blocking = blockOnCleanupTasks) + case CleanCheckpoint(rddId) => + doCleanCheckpoint(rddId) + case unknown => + logWarning(s"Got a cleanup task $unknown that cannot be handled by ContextCleaner,") } } + protected def cleanupThreadName(): String = "Context Cleaner" + /** Perform RDD cleanup. */ def doCleanupRDD(rddId: Int, blocking: Boolean): Unit = { try { @@ -249,10 +183,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] } -private object ContextCleaner { - private val REF_QUEUE_POLL_TIMEOUT = 100 -} - /** * Listener class used for testing when any item has been cleaned by the Cleaner class. */ diff --git a/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala b/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala new file mode 100644 index 0000000000000..29641572b31dc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark + +import java.io.File + +import org.apache.spark.util.cleanup.{CleanupTask, CleanExternalList} +import org.apache.spark.util.collection.ExternalList + +/** + * Asynchronous cleaner for objects created on the Executor. 
So far + * only supports cleaning up ExternalList objects. Equivalent to ContextCleaner + * but for objects on the Executor heap. + */ +private[spark] class ExecutorCleaner extends WeakReferenceCleaner { + + def registerExternalListForCleanup(list: ExternalList[_]): Unit = { + registerForCleanup(list, CleanExternalList(list.getBackingFileLocations())) + } + + def doCleanExternalList(paths: Iterable[String]): Unit = { + paths.map(path => new File(path)).foreach(f => { + if (f.exists()) f.delete() + }) + } + + override protected def handleCleanupForSpecificTask(task: CleanupTask): Unit = { + task match { + case CleanExternalList(paths) => doCleanExternalList(paths) + case unknown => logWarning(s"Got cleanup task that cannot be handled by ExecutorCleaner: $unknown") + } + } + + override protected def cleanupThreadName(): String = "Executor Cleaner" +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index adfece4d6e7c0..9fbc8743c2df8 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -73,6 +73,7 @@ class SparkEnv ( val shuffleMemoryManager: ShuffleMemoryManager, val executorMemoryManager: ExecutorMemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, + val executorCleaner: ExecutorCleaner, val conf: SparkConf) extends Logging { // TODO Remove actorSystem @@ -101,6 +102,7 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() + executorCleaner.stop() rpcEnv.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut @@ -398,6 +400,8 @@ object SparkEnv extends Logging { } new ExecutorMemoryManager(allocator) } + val executorCleaner = new ExecutorCleaner + executorCleaner.start() val envInstance = new SparkEnv( executorId, @@ -417,6 +421,7 @@ object SparkEnv extends Logging { shuffleMemoryManager, executorMemoryManager, outputCommitCoordinator, + executorCleaner, conf) // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is diff --git a/core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala b/core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala new file mode 100644 index 0000000000000..0dd6d4773dcb6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark + +import java.lang.ref.ReferenceQueue + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} + +import org.apache.spark.util.cleanup.{CleanupTask, CleanupTaskWeakReference} + +/** + * Utility trait that keeps a long running thread for cleaning up weak references + * after they are GCed. Currently implemented by ContextCleaner and ExecutorCleaner + * only. + */ +private[spark] trait WeakReferenceCleaner extends Logging { + + private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] + with SynchronizedBuffer[CleanupTaskWeakReference] + + private val referenceQueue = new ReferenceQueue[AnyRef] + + private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + + private var stopped = false + + /** Start the cleaner. */ + def start(): Unit = { + cleaningThread.setDaemon(true) + cleaningThread.setName(cleanupThreadName()) + cleaningThread.start() + } + + def stop(): Unit = { + stopped = true + synchronized { + // Interrupt the cleaning thread, but wait until the current task has finished before + // doing so. This guards against the race condition where a cleaning thread may + // potentially clean similarly named variables created by a different SparkContext, + // resulting in otherwise inexplicable block-not-found exceptions (SPARK-6132). + cleaningThread.interrupt() + } + cleaningThread.join() + } + + protected def keepCleaning(): Unit = { + while (!stopped) { + try { + val reference = Option(referenceQueue.remove(WeakReferenceCleaner.REF_QUEUE_POLL_TIMEOUT)) + .map(_.asInstanceOf[CleanupTaskWeakReference]) + // Synchronize here to avoid being interrupted on stop() + synchronized { + reference.map(_.task).foreach { task => + logDebug("Got cleaning task " + task) + referenceBuffer -= reference.get + handleCleanupForSpecificTask(task) + } + } + } catch { + case ie: InterruptedException if stopped => // ignore + case e: Exception => logError("Error in cleaning thread", e) + } + } + } + + /** Register an object for cleanup. 
*/ + protected def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { + referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) + } + + protected def handleCleanupForSpecificTask(task: CleanupTask) + protected def cleanupThreadName(): String +} + +private object WeakReferenceCleaner { + private val REF_QUEUE_POLL_TIMEOUT = 100 +} diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 7fd107fde60e9..1acd994cd9d86 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -103,7 +103,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) - kryo.register(classOf[ExternalList[Any]], new ExternalListSerializer[Any]()) + kryo.register(classOf[ExternalList[Any]], new ExternalList.ExternalListSerializer[Any]()) kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) diff --git a/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala b/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala new file mode 100644 index 0000000000000..e0fb9e131de33 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util.cleanup + +import java.lang.ref.{ReferenceQueue, WeakReference} + +/** + * Classes that represent cleaning tasks. + */ +private[spark] sealed trait CleanupTask +private[spark] case class CleanRDD(rddId: Int) extends CleanupTask +private[spark] case class CleanShuffle(shuffleId: Int) extends CleanupTask +private[spark] case class CleanBroadcast(broadcastId: Long) extends CleanupTask +private[spark] case class CleanAccum(accId: Long) extends CleanupTask +private[spark] case class CleanCheckpoint(rddId: Int) extends CleanupTask +private[spark] case class CleanExternalList(pathsToClean: Iterable[String]) extends CleanupTask + +/** + * A WeakReference associated with a CleanupTask. + * + * When the referent object becomes only weakly reachable, the corresponding + * CleanupTaskWeakReference is automatically added to the given reference queue. 
+ */ +private[spark] class CleanupTaskWeakReference( + val task: CleanupTask, + referent: AnyRef, + referenceQueue: ReferenceQueue[AnyRef]) + extends WeakReference(referent, referenceQueue) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala index 110b304cc0f4d..28467b3d87d62 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala @@ -18,6 +18,9 @@ package org.apache.spark.util.collection import java.io._ +import org.apache.spark.util.TaskCompletionListener +import org.apache.spark.{TaskContext, ExecutorCleaner, SparkEnv} + import scala.reflect.ClassTag import scala.collection.generic.Growable import scala.collection.mutable.ArrayBuffer @@ -25,6 +28,7 @@ import scala.collection.mutable.ArrayBuffer import com.esotericsoftware.kryo.io.{Output, Input} import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} +import org.apache.spark.util.collection.ExternalList._ import org.apache.spark.serializer.DeserializationStream import org.apache.spark.storage.{DiskBlockObjectWriter, BlockId} @@ -34,37 +38,58 @@ import org.apache.spark.storage.{DiskBlockObjectWriter, BlockId} * Implementation is based heavily on `org.apache.spark.util.collection.ExternalAppendOnlyMap}` */ @SerialVersionUID(1L) -private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) +private[spark] class ExternalList[T](implicit var tag: ClassTag[T]) extends Growable[T] with Iterable[T] with SpillableCollection[T, SizeTrackingCompactBuffer[T]] with Serializable { - // Lazy vals so that this isn't created multiple times but still can be re-instantiated properly - // after serialization - private lazy val spilledLists = new ArrayBuffer[Iterable[T]] - - private var list = new SizeTrackingCompactBuffer[T]() + // Var to allow rebuilding it during Java serialization + private var spilledLists = new ArrayBuffer[DiskListIterable] + private var currentInMemoryList = new SizeTrackingCompactBuffer[T]() private var numItems = 0 + // We don't know up front what files will need to be cleaned up from this list. + // So check after the task is completed, after which this ExternalList will be + // completely built. 
+ private var context = TaskContext.get + if (context != null) { + context.addTaskCompletionListener(new ScheduleCleanExternalList(this)) + } + override def size() = numItems override def +=(value: T) = { - list += value - if (maybeSpill(list, list.estimateSize())) { - list = new SizeTrackingCompactBuffer + currentInMemoryList += value + if (maybeSpill(currentInMemoryList, currentInMemoryList.estimateSize())) { + currentInMemoryList = new SizeTrackingCompactBuffer } numItems += 1 this } override def clear(): Unit = { + spilledLists.foreach(_.deleteBackingFile()) spilledLists.clear() - list = new SizeTrackingCompactBuffer[T]() + currentInMemoryList = new SizeTrackingCompactBuffer[T]() + } + + def getBackingFileLocations(): Iterable[String] = { + val locations = new ArrayBuffer[String] + for (diskList <- spilledLists) { + locations.append(diskList.backingFilePath()) + } + return locations + } + + def registerForCleanup(): Unit = { + if (spilledLists.size > 0) { + executorCleaner.registerExternalListForCleanup(this) + } } override def iterator: Iterator[T] = { - val myIt = list.iterator + val myIt = currentInMemoryList.iterator val allIts = spilledLists.map(_.iterator) ++ Seq(myIt) allIts.foldLeft(Iterator[T]())(_ ++ _) } @@ -74,6 +99,12 @@ private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) override def iterator: Iterator[T] = { new DiskListIterator(file, blockId, batchSizes) } + def deleteBackingFile(): Unit = { + if (file.exists()) { + file.delete() + } + } + def backingFilePath(): String = file.getAbsolutePath() } private class DiskListIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) @@ -101,14 +132,23 @@ private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) private def readObject(stream: ObjectInputStream): Unit = { tag = stream.readObject().asInstanceOf[ClassTag[T]] val listSize = stream.readInt() - list = new SizeTrackingCompactBuffer[T] + spilledLists = new ArrayBuffer[DiskListIterable] + currentInMemoryList = new SizeTrackingCompactBuffer[T] for(i <- 0L until listSize) { val newItem = stream.readObject().asInstanceOf[T] this.+=(newItem) } + // Upon deserialization, the context might have changed. So we can't just hold a single context, + // but we must retrieve the current context every time. + // Notice that in Kryo serialization this object is constructed from scratch + // and thus will look for the current TaskContext that way. 
+ context = TaskContext.get() + if (context != null) { + context.addTaskCompletionListener(new ScheduleCleanExternalList(this)) + } } - override protected def getIteratorForCurrentSpillable(): Iterator[T] = list.iterator + override protected def getIteratorForCurrentSpillable(): Iterator[T] = currentInMemoryList.iterator override protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]): Unit = { spilledLists += new DiskListIterable(file, blockId, batchSizes) } @@ -122,6 +162,18 @@ private[spark] class ExternalList[T](implicit private var tag: ClassTag[T]) * Java-serializing */ private[spark] object ExternalList { + + private class ScheduleCleanExternalList(private var list: ExternalList[_]) + extends TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = { + if (list != null) { + executorCleaner.registerExternalListForCleanup(list) + // Release reference to allow GC to clean it up + list = null + } + } + } + def apply[T: ClassTag](): ExternalList[T] = new ExternalList[T] def apply[T: ClassTag](value: T): ExternalList[T] = { @@ -129,6 +181,8 @@ private[spark] object ExternalList { buf += value buf } + + private val executorCleaner: ExecutorCleaner = SparkEnv.get.executorCleaner } private[spark] class ExternalListSerializer[T: ClassTag] extends KSerializer[ExternalList[T]] { diff --git a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala index 75adef126e7cb..8a9c5caf7387c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala @@ -229,8 +229,7 @@ private[spark] trait SpillableCollection[C, T <: Iterable[C]] extends Spillable[ } } -// Visible and modifiable only for testing -private[collection] object SpillableCollection { +private object SpillableCollection { private def sparkConf(): SparkConf = SparkEnv.get.conf private def blockManager(): BlockManager = SparkEnv.get.blockManager private def diskBlockManager(): DiskBlockManager = blockManager.diskBlockManager diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala index 025f1bbd9c45a..d0334a30945d9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala @@ -16,11 +16,19 @@ */ package org.apache.spark.util.collection +import java.io.File +import java.lang.ref.WeakReference + +import scala.language.existentials import scala.reflect.ClassTag -import org.apache.spark.{SparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkEnv, SparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.util.collection.ExternalListSuite._ import org.apache.spark.serializer.{KryoSerializer, JavaSerializer, SerializerInstance} -import org.junit.Assert.assertEquals + +import org.junit.Assert.{assertEquals, assertTrue, assertFalse} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ class ExternalListSuite extends SparkFunSuite with Serializable { @@ -49,40 +57,91 @@ class ExternalListSuite extends SparkFunSuite with Serializable { list = new ExternalList[Int] // Test smaller list for Java serialization since serializing with Java is // really slow, and we already test serialization causing 
spilling in the Kryo case - for (i <- 0 to 1000) { + for (i <- 0 to 1000000) { list += i } testSerialization(serializer, list) } - test("Group by key with spilling list") { - val totalRddSize = 7200000 - val numBuckets = 5 - val rawLargeRdd = sparkContext.parallelize(1 to totalRddSize) + val totalRddSize = 7200000 + val numBuckets = 5 + val rawLargeRdd = sparkContext.parallelize(1 to totalRddSize) + test("Lists that are cached should be accessible twice, but when unpersisted are cleaned up.") { val groupedRdd = rawLargeRdd.map(x => (x % numBuckets, x)).groupByKey - def validateList(kv: (Int, Iterable[Int])): Unit = { - var numItems = 0 - for (valsInBucket <- kv._2) { - numItems += 1 - // Can't use scala assertions because including assert statements makes closures not serializable. - assertEquals(s"Value $valsInBucket should not be in bucket ${kv._1}", kv._1, valsInBucket % numBuckets) - } - assertEquals(s"Number of items in bucket ${kv._1} is incorrect.", totalRddSize / numBuckets, numItems) + val cachedRdd = groupedRdd.cache() + cachedRdd.foreach(validateList(totalRddSize, numBuckets, _)) + runGC() + // GC on the Cached RDD shouldn't trigger the cleanup + cachedRdd.foreach(validateList(totalRddSize, numBuckets, _)) + val filePaths = cachedRdd.map(_._2.asInstanceOf[ExternalList[Int]].getBackingFileLocations()).collect + filePaths.foreach(paths => { + paths.foreach(f => assertTrue(new File(f).exists())) + }) + cachedRdd.unpersist(true) + runGC() + checkFilesEventuallyRemoved(filePaths) + cachedRdd.foreach(validateList(totalRddSize, numBuckets, _)) + } + + test("List that is created in a task and released immediately should eventually clean up") { + val filePaths = rawLargeRdd + .map(x => (x % numBuckets, x)) + .groupByKey + .map(x => x._2.asInstanceOf[ExternalList[Int]].getBackingFileLocations()).collect + runGC() + checkFilesEventuallyRemoved(filePaths) + } + + private def checkFilesEventuallyRemoved(filePaths: Array[Iterable[String]]) { + eventually(timeout(15000 millis), interval(100 millis)) { + filePaths.foreach(paths => { + paths.foreach(f => assertFalse(new File(f).exists())) + }) + } + } + + /** Run GC and make sure it actually has run */ + private def runGC() { + val weakRef = new WeakReference(new Object()) + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. 
+ // Wait until a weak reference object has been GCed + while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + Thread.sleep(200) } - groupedRdd.foreach(validateList(_)) } private def testSerialization[T: ClassTag]( serializer: SerializerInstance, list: ExternalList[T]): Unit = { val bytes = serializer.serialize(list) - val readList = serializer.deserialize(bytes).asInstanceOf[ExternalList[Int]] + var readList = serializer.deserialize(bytes).asInstanceOf[ExternalList[Int]] val originalIt = list.iterator - val readIt = readList.iterator + var readIt = readList.iterator while (originalIt.hasNext) { assert (originalIt.next == readIt.next) } assert (!readIt.hasNext) + val filePaths = readList.getBackingFileLocations() + readList = null + readIt = null + runGC() + eventually(timeout(15000 millis), interval(100 millis)) { + filePaths.foreach(path => assertFalse(new File(path).exists())) + } } } + +object ExternalListSuite { + def validateList(totalRddSize: Int, numBuckets: Int, kv: (Int, Iterable[Int])): Unit = { + var numItems = 0 + for (valsInBucket <- kv._2) { + numItems += 1 + // Can't use scala assertions because including assert statements makes closures not serializable. + assertEquals(s"Value $valsInBucket should not be in bucket ${kv._1}", kv._1, valsInBucket % numBuckets) + } + assertEquals(s"Number of items in bucket ${kv._1} is incorrect.", totalRddSize / numBuckets, numItems) + } +} From 3c0d2e55210735e0df2f8febb5f63c224af230e3 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Fri, 31 Jul 2015 13:01:10 -0700 Subject: [PATCH 044/340] [SPARK-9246] [MLLIB] DistributedLDAModel predict top docs per topic Add topDocumentsPerTopic to DistributedLDAModel. Add ScalaDoc and unit tests. Author: Meihua Wu Closes #7769 from rotationsymmetry/SPARK-9246 and squashes the following commits: 1029e79c [Meihua Wu] clean up code comments a023b82 [Meihua Wu] Update tests to use Long for doc index. 91e5998 [Meihua Wu] Use Long for doc index. b9f70cf [Meihua Wu] Revise topDocumentsPerTopic 26ff3f6 [Meihua Wu] Add topDocumentsPerTopic, scala doc and unit tests --- .../spark/mllib/clustering/LDAModel.scala | 37 +++++++++++++++++++ .../spark/mllib/clustering/LDASuite.scala | 22 +++++++++++ 2 files changed, 59 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index ff7035d2246c2..0cdac84eeb591 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -516,6 +516,43 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Return the top documents for each topic + * + * This is approximate; it may not return exactly the top-weighted documents for each topic. + * To get a more precise set of top documents, increase maxDocumentsPerTopic. + * + * @param maxDocumentsPerTopic Maximum number of documents to collect for each topic. + * @return Array over topics. Each element represent as a pair of matching arrays: + * (IDs for the documents, weights of the topic in these documents). + * For each topic, documents are sorted in order of decreasing topic weights. 
+ */ + def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = { + val numTopics = k + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] = + topicDistributions.mapPartitions { docVertices => + // For this partition, collect the most common docs for each topic in queues: + // queues(topic) = queue of (doc topic, doc ID). + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Long)](maxDocumentsPerTopic)) + for ((docId, docTopics) <- docVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (docTopics(topic) -> docId) + topic += 1 + } + } + Iterator(queues) + }.treeReduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b } + q1 + } + topicsInQueues.map { q => + val (docTopics, docs) = q.toArray.sortBy(-_._1).unzip + (docs.toArray, docTopics.toArray) + } + } + // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 79d2a1cafd1fa..f2b94707fd0ff 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -122,6 +122,28 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // Check: log probabilities assert(model.logLikelihood < 0.0) assert(model.logPrior < 0.0) + + // Check: topDocumentsPerTopic + // Compare it with top documents per topic derived from topicDistributions + val topDocsByTopicDistributions = { n: Int => + Range(0, k).map { topic => + val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip + (doc.toArray, docWeights.map(_(topic)).toArray) + }.toArray + } + + // Top 3 documents per topic + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } + + // All documents per topic + val q = tinyCorpus.length + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } } test("vertex indexing") { From 060c79aab58efd4ce7353a1b00534de0d9e1de0b Mon Sep 17 00:00:00 2001 From: Sameer Abhyankar Date: Fri, 31 Jul 2015 13:08:55 -0700 Subject: [PATCH 045/340] [SPARK-9056] [STREAMING] Rename configuration `spark.streaming.minRememberDuration` to `spark.streaming.fileStream.minRememberDuration` Rename configuration `spark.streaming.minRememberDuration` to `spark.streaming.fileStream.minRememberDuration` Author: Sameer Abhyankar Author: Sameer Abhyankar Closes #7740 from sabhyankar/spark_branch_9056 and squashes the following commits: d5b2f1f [Sameer Abhyankar] Correct deprecated version to 1.5 1268133 [Sameer Abhyankar] Add {} and indentation ddf9844 [Sameer Abhyankar] Change 4 space indentation to 2 space indentation 1819b5f [Sameer Abhyankar] Use spark.streaming.fileStream.minRememberDuration property in lieu of spark.streaming.minRememberDuration --- core/src/main/scala/org/apache/spark/SparkConf.scala | 4 +++- .../apache/spark/streaming/dstream/FileInputDStream.scala | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 4161792976c7b..08bab4bf2739f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ 
b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -548,7 +548,9 @@ private[spark] object SparkConf extends Logging { "spark.rpc.askTimeout" -> Seq( AlternateConfig("spark.akka.askTimeout", "1.4")), "spark.rpc.lookupTimeout" -> Seq( - AlternateConfig("spark.akka.lookupTimeout", "1.4")) + AlternateConfig("spark.akka.lookupTimeout", "1.4")), + "spark.streaming.fileStream.minRememberDuration" -> Seq( + AlternateConfig("spark.streaming.minRememberDuration", "1.5")) ) /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index dd4da9d9ca6a2..c358f5b5bd70b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -86,8 +86,10 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( * Files with mod times older than this "window" of remembering will be ignored. So if new * files are visible within this window, then the file will get selected in the next batch. */ - private val minRememberDurationS = - Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.minRememberDuration", "60s")) + private val minRememberDurationS = { + Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.fileStream.minRememberDuration", + ssc.conf.get("spark.streaming.minRememberDuration", "60s"))) + } // This is a def so that it works during checkpoint recovery: private def clock = ssc.scheduler.clock From fbef566a107b47e5fddde0ea65b8587d5039062d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 31 Jul 2015 13:11:42 -0700 Subject: [PATCH 046/340] [SPARK-9308] [ML] ml.NaiveBayesModel support predicting class probabilities Make NaiveBayesModel support predicting class probabilities, inherit from ProbabilisticClassificationModel. Author: Yanbo Liang Closes #7672 from yanboliang/spark-9308 and squashes the following commits: 25e224c [Yanbo Liang] raw2probabilityInPlace should operate in-place 3ee56d6 [Yanbo Liang] change predictRaw and raw2probabilityInPlace c07e7a2 [Yanbo Liang] ml.NaiveBayesModel support predicting class probabilities --- .../spark/ml/classification/NaiveBayes.scala | 65 ++++++++++++++----- .../ml/classification/NaiveBayesSuite.scala | 54 ++++++++++++++- 2 files changed, 101 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 5be35fe209291..b46b676204e0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -69,7 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * The input feature values must be nonnegative. 
*/ class NaiveBayes(override val uid: String) - extends Predictor[Vector, NaiveBayes, NaiveBayesModel] + extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams { def this() = this(Identifiable.randomUID("nb")) @@ -106,7 +106,7 @@ class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, val theta: Matrix) - extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams { + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -129,29 +129,62 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } - override protected def predict(features: Vector): Double = { + override val numClasses: Int = pi.size + + private def multinomialCalculation(features: Vector) = { + val prob = theta.multiply(features) + BLAS.axpy(1.0, pi, prob) + prob + } + + private def bernoulliCalculation(features: Vector) = { + features.foreachActive((_, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") + } + ) + val prob = thetaMinusNegTheta.get.multiply(features) + BLAS.axpy(1.0, pi, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + prob + } + + override protected def predictRaw(features: Vector): Vector = { $(modelType) match { case Multinomial => - val prob = theta.multiply(features) - BLAS.axpy(1.0, pi, prob) - prob.argmax + multinomialCalculation(features) case Bernoulli => - features.foreachActive{ (index, value) => - if (value != 0.0 && value != 1.0) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features") - } - } - val prob = thetaMinusNegTheta.get.multiply(features) - BLAS.axpy(1.0, pi, prob) - BLAS.axpy(1.0, negThetaSum.get, prob) - prob.argmax + bernoulliCalculation(features) case _ => // This should never happen. 
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } } + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val maxLog = dv.values.max + while (i < size) { + dv.values(i) = math.exp(dv.values(i) - maxLog) + i += 1 + } + val probSum = dv.values.sum + i = 0 + while (i < size) { + dv.values(i) = dv.values(i) / probSum + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in NaiveBayesModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + override def copy(extra: ParamMap): NaiveBayesModel = { copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 264bde3703c5f..aea3d9b694490 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.ml.classification +import breeze.linalg.{Vector => BV} + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.classification.NaiveBayes import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -28,6 +31,8 @@ import org.apache.spark.sql.Row class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + import NaiveBayes.{Multinomial, Bernoulli} + def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = predictionAndLabels.collect().count { case Row(prediction: Double, label: Double) => @@ -46,6 +51,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch") } + def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v))) + val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v)) + val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def validateProbabilities( + featureAndProbabilities: DataFrame, + model: NaiveBayesModel, + modelType: String): Unit = { + featureAndProbabilities.collect().foreach { + case Row(features: Vector, probability: Vector) => { + assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) + val expected = modelType match { + case Multinomial => + expectedMultinomialProbabilities(model, features) + case Bernoulli => + expectedBernoulliProbabilities(model, features) + case _ => + throw new UnknownError(s"Invalid modelType: $modelType.") + } + assert(probability ~== expected relTol 1.0e-10) + } + } + } + 
test("params") { ParamsSuite.checkParams(new NaiveBayes) val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)), @@ -83,9 +125,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 17, "multinomial")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "multinomial") } test("Naive Bayes Bernoulli") { @@ -109,8 +155,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 20, "bernoulli")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "bernoulli") } } From 815c8245f47e61226a04e2e02f508457b5e9e536 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 31 Jul 2015 13:45:12 -0700 Subject: [PATCH 047/340] [SPARK-9466] [SQL] Increate two timeouts in CliSuite. Hopefully this can resolve the flakiness of this suite. JIRA: https://issues.apache.org/jira/browse/SPARK-9466 Author: Yin Huai Closes #7777 from yhuai/SPARK-9466 and squashes the following commits: e0e3a86 [Yin Huai] Increate the timeout. 
--- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 13b0c5951dddc..df80d04b40801 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -137,7 +137,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { } test("Single command with --database") { - runCliWithin(1.minute)( + runCliWithin(2.minute)( "CREATE DATABASE hive_test_db;" -> "OK", "USE hive_test_db;" @@ -148,7 +148,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { -> "Time taken: " ) - runCliWithin(1.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( + runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( "" -> "OK", "" From 873ab0f9692d8ea6220abdb8d9200041068372a8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 31 Jul 2015 13:45:28 -0700 Subject: [PATCH 048/340] [SPARK-9490] [DOCS] [MLLIB] MLlib evaluation metrics guide example python code uses deprecated print statement Use print(x) not print x for Python 3 in eval examples CC sethah mengxr -- just wanted to close this out before 1.5 Author: Sean Owen Closes #7822 from srowen/SPARK-9490 and squashes the following commits: 01abeba [Sean Owen] Change "print x" to "print(x)" in the rest of the docs too bd7f7fb [Sean Owen] Use print(x) not print x for Python 3 in eval examples --- docs/ml-guide.md | 2 +- docs/mllib-evaluation-metrics.md | 66 ++++++++++++++--------------- docs/mllib-feature-extraction.md | 2 +- docs/mllib-statistics.md | 20 ++++----- docs/quick-start.md | 2 +- docs/sql-programming-guide.md | 6 +-- docs/streaming-programming-guide.md | 2 +- 7 files changed, 50 insertions(+), 50 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 8c46adf256a9a..b6ca50e98db02 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -561,7 +561,7 @@ test = sc.parallelize([(4L, "spark i j k"), prediction = model.transform(test) selected = prediction.select("id", "text", "prediction") for row in selected.collect(): - print row + print(row) sc.stop() {% endhighlight %} diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 4ca0bb06b26a6..7066d5c97418c 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -302,10 +302,10 @@ predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp metrics = BinaryClassificationMetrics(predictionAndLabels) # Area under precision-recall curve -print "Area under PR = %s" % metrics.areaUnderPR +print("Area under PR = %s" % metrics.areaUnderPR) # Area under ROC curve -print "Area under ROC = %s" % metrics.areaUnderROC +print("Area under ROC = %s" % metrics.areaUnderROC) {% endhighlight %} @@ -606,24 +606,24 @@ metrics = MulticlassMetrics(predictionAndLabels) precision = metrics.precision() recall = metrics.recall() f1Score = metrics.fMeasure() -print "Summary Stats" -print "Precision = %s" % precision -print "Recall = %s" % recall -print "F1 Score = %s" % f1Score +print("Summary Stats") +print("Precision = %s" % precision) +print("Recall = %s" % recall) +print("F1 Score = %s" % f1Score) # Statistics by class labels = data.map(lambda lp: 
lp.label).distinct().collect() for label in sorted(labels): - print "Class %s precision = %s" % (label, metrics.precision(label)) - print "Class %s recall = %s" % (label, metrics.recall(label)) - print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)) + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) # Weighted stats -print "Weighted recall = %s" % metrics.weightedRecall -print "Weighted precision = %s" % metrics.weightedPrecision -print "Weighted F(1) Score = %s" % metrics.weightedFMeasure() -print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5) -print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate +print("Weighted recall = %s" % metrics.weightedRecall) +print("Weighted precision = %s" % metrics.weightedPrecision) +print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) +print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) +print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) {% endhighlight %} @@ -881,28 +881,28 @@ scoreAndLabels = sc.parallelize([ metrics = MultilabelMetrics(scoreAndLabels) # Summary stats -print "Recall = %s" % metrics.recall() -print "Precision = %s" % metrics.precision() -print "F1 measure = %s" % metrics.f1Measure() -print "Accuracy = %s" % metrics.accuracy +print("Recall = %s" % metrics.recall()) +print("Precision = %s" % metrics.precision()) +print("F1 measure = %s" % metrics.f1Measure()) +print("Accuracy = %s" % metrics.accuracy) # Individual label stats labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() for label in labels: - print "Class %s precision = %s" % (label, metrics.precision(label)) - print "Class %s recall = %s" % (label, metrics.recall(label)) - print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label)) + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) # Micro stats -print "Micro precision = %s" % metrics.microPrecision -print "Micro recall = %s" % metrics.microRecall -print "Micro F1 measure = %s" % metrics.microF1Measure +print("Micro precision = %s" % metrics.microPrecision) +print("Micro recall = %s" % metrics.microRecall) +print("Micro F1 measure = %s" % metrics.microF1Measure) # Hamming loss -print "Hamming loss = %s" % metrics.hammingLoss +print("Hamming loss = %s" % metrics.hammingLoss) # Subset accuracy -print "Subset accuracy = %s" % metrics.subsetAccuracy +print("Subset accuracy = %s" % metrics.subsetAccuracy) {% endhighlight %} @@ -1283,10 +1283,10 @@ scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) metrics = RegressionMetrics(scoreAndLabels) # Root mean sqaured error -print "RMSE = %s" % metrics.rootMeanSquaredError +print("RMSE = %s" % metrics.rootMeanSquaredError) # R-squared -print "R-squared = %s" % metrics.r2 +print("R-squared = %s" % metrics.r2) {% endhighlight %} @@ -1479,17 +1479,17 @@ valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.l metrics = RegressionMetrics(valuesAndPreds) # Squared Error -print "MSE = %s" % metrics.meanSquaredError -print "RMSE = %s" % metrics.rootMeanSquaredError +print("MSE = %s" % metrics.meanSquaredError) +print("RMSE = %s" % metrics.rootMeanSquaredError) # R-squared -print "R-squared = 
%s" % metrics.r2 +print("R-squared = %s" % metrics.r2) # Mean absolute error -print "MAE = %s" % metrics.meanAbsoluteError +print("MAE = %s" % metrics.meanAbsoluteError) # Explained variance -print "Explained variance = %s" % metrics.explainedVariance +print("Explained variance = %s" % metrics.explainedVariance) {% endhighlight %} diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index a69e41e2a1936..de86aba2ae627 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -221,7 +221,7 @@ model = word2vec.fit(inp) synonyms = model.findSynonyms('china', 40) for word, cosine_distance in synonyms: - print "{}: {}".format(word, cosine_distance) + print("{}: {}".format(word, cosine_distance)) {% endhighlight %} diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index de5d6485f9b5f..be04d0b4b53a8 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -95,9 +95,9 @@ mat = ... # an RDD of Vectors # Compute column summary statistics. summary = Statistics.colStats(mat) -print summary.mean() -print summary.variance() -print summary.numNonzeros() +print(summary.mean()) +print(summary.variance()) +print(summary.numNonzeros()) {% endhighlight %} @@ -183,12 +183,12 @@ seriesY = ... # must have the same number of partitions and cardinality as serie # Compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a # method is not specified, Pearson's method will be used by default. -print Statistics.corr(seriesX, seriesY, method="pearson") +print(Statistics.corr(seriesX, seriesY, method="pearson")) data = ... # an RDD of Vectors # calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. # If a method is not specified, Pearson's method will be used by default. -print Statistics.corr(data, method="pearson") +print(Statistics.corr(data, method="pearson")) {% endhighlight %} @@ -398,14 +398,14 @@ vec = Vectors.dense(...) # a vector composed of the frequencies of events # compute the goodness of fit. If a second vector to test against is not supplied as a parameter, # the test runs against a uniform distribution. goodnessOfFitTestResult = Statistics.chiSqTest(vec) -print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, - # test statistic, the method used, and the null hypothesis. +print(goodnessOfFitTestResult) # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. mat = Matrices.dense(...) # a contingency matrix # conduct Pearson's independence test on the input contingency matrix independenceTestResult = Statistics.chiSqTest(mat) -print independenceTestResult # summary of the test including the p-value, degrees of freedom... +print(independenceTestResult) # summary of the test including the p-value, degrees of freedom... obs = sc.parallelize(...) # LabeledPoint(feature, label) . @@ -415,8 +415,8 @@ obs = sc.parallelize(...) # LabeledPoint(feature, label) . 
featureTestResults = Statistics.chiSqTest(obs) for i, result in enumerate(featureTestResults): - print "Column $d:" % (i + 1) - print result + print("Column $d:" % (i + 1)) + print(result) {% endhighlight %} diff --git a/docs/quick-start.md b/docs/quick-start.md index bb39e4111f244..ce2cc9d2169cd 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -406,7 +406,7 @@ logData = sc.textFile(logFile).cache() numAs = logData.filter(lambda s: 'a' in s).count() numBs = logData.filter(lambda s: 'b' in s).count() -print "Lines with a: %i, lines with b: %i" % (numAs, numBs) +print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) {% endhighlight %} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 95945eb7fc8a0..d31baa080cbce 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -570,7 +570,7 @@ teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 1 # The results of SQL queries are RDDs and support all the normal RDD operations. teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): - print teenName + print(teenName) {% endhighlight %} @@ -752,7 +752,7 @@ results = sqlContext.sql("SELECT name FROM people") # The results of SQL queries are RDDs and support all the normal RDD operations. names = results.map(lambda p: "Name: " + p.name) for name in names.collect(): - print name + print(name) {% endhighlight %} @@ -1006,7 +1006,7 @@ parquetFile.registerTempTable("parquetFile"); teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): - print teenName + print(teenName) {% endhighlight %} diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 2f3013b533eb0..4663b3f14c527 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1525,7 +1525,7 @@ def getSqlContextInstance(sparkContext): words = ... # DStream of strings def process(time, rdd): - print "========= %s =========" % str(time) + print("========= %s =========" % str(time)) try: # Get the singleton instance of SQLContext sqlContext = getSqlContextInstance(rdd.context) From 6e5fd613ea4b9aa0ab485ba681277a51a4367168 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 31 Jul 2015 21:51:55 +0100 Subject: [PATCH 049/340] [SPARK-9507] [BUILD] Remove dependency reduced POM hack now that shade plugin is updated Update to shade plugin 2.4.1, which removes the need for the dependency-reduced-POM workaround and the 'release' profile. Fix management of shade plugin version so children inherit it; bump assembly plugin version while here See https://issues.apache.org/jira/browse/SPARK-8819 I verified that `mvn clean package -DskipTests` works with Maven 3.3.3. pwendell are you up for trying this for the 1.5.0 release? Author: Sean Owen Closes #7826 from srowen/SPARK-9507 and squashes the following commits: e0b0fd2 [Sean Owen] Update to shade plugin 2.4.1, which removes the need for the dependency-reduced-POM workaround and the 'release' profile. 
Fix management of shade plugin version so children inherit it; bump assembly plugin version while here --- dev/create-release/create-release.sh | 4 ++-- pom.xml | 33 +++++----------------------- 2 files changed, 8 insertions(+), 29 deletions(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 86a7a4068c40e..4311c8c9e4ca6 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -118,13 +118,13 @@ if [[ ! "$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO - build/mvn -DskipTests -Pyarn -Phive -Prelease\ + build/mvn -DskipTests -Pyarn -Phive \ -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-scala-version.sh 2.11 - build/mvn -DskipTests -Pyarn -Phive -Prelease\ + build/mvn -DskipTests -Pyarn -Phive \ -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install diff --git a/pom.xml b/pom.xml index e351c7c19df96..1371a1b6bd9f1 100644 --- a/pom.xml +++ b/pom.xml @@ -160,9 +160,6 @@ 2.4.4 1.1.1.7 1.1.2 - - false - ${java.home} - ${create.dependency.reduced.pom} @@ -1836,26 +1835,6 @@ - - - release - - - true - - - [Review on Reviewable](https://reviewable.io/reviews/apache/spark/7773) Author: Yin Huai Author: Josh Rosen Closes #7773 from JoshRosen/multi-way-join-planning-improvements and squashes the following commits: 5c45924 [Josh Rosen] Merge remote-tracking branch 'origin/master' into multi-way-join-planning-improvements cd8269b [Josh Rosen] Refactor test to use SQLTestUtils 2963857 [Yin Huai] Revert unnecessary SqlConf change. 73913f7 [Yin Huai] Add comments and test. Also, revert the change in ShuffledHashOuterJoin for now. 4a99204 [Josh Rosen] Delete unrelated expression change 884ab95 [Josh Rosen] Carve out only SPARK-2205 changes. 247e5fa [Josh Rosen] Merge remote-tracking branch 'origin/master' into multi-way-join-planning-improvements c57a954 [Yin Huai] Bug fix. d3d2e64 [Yin Huai] First round of cleanup. f9516b0 [Yin Huai] Style c6667e7 [Yin Huai] Add PartitioningCollection. e616d3b [Yin Huai] wip 7c2d2d8 [Yin Huai] Bug fix and refactoring. 69bb072 [Yin Huai] Introduce NullSafeHashPartitioning and NullUnsafePartitioning. d5b84c3 [Yin Huai] Do not add unnessary filters. 2201129 [Yin Huai] Filter out rows that will not be joined in equal joins early. 
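To make the idea in the commit message above concrete, here is a minimal, self-contained sketch of the `guarantees` check that the patch below introduces. The types are simplified stand-ins for illustration only, not the actual catalyst classes:

// Simplified stand-ins for illustration; not the real catalyst classes.
object GuaranteesSketch {
  sealed trait Partitioning {
    def numPartitions: Int
    def guarantees(other: Partitioning): Boolean
  }

  // Rows that agree on `keys` are guaranteed to land in the same partition.
  final case class HashPartitioning(keys: Set[String], numPartitions: Int) extends Partitioning {
    def guarantees(other: Partitioning): Boolean = other match {
      case HashPartitioning(k, n) => keys == k && numPartitions == n
      case _ => false
    }
  }

  // A join can describe its output with the partitionings of both children.
  final case class PartitioningCollection(partitionings: Seq[Partitioning]) extends Partitioning {
    require(partitionings.map(_.numPartitions).distinct.size == 1)
    def numPartitions: Int = partitionings.head.numPartitions
    def guarantees(other: Partitioning): Boolean = partitionings.exists(_.guarantees(other))
  }

  def main(args: Array[String]): Unit = {
    // Output of A JOIN B ON A.key1 = B.key2, shuffled into 10 partitions.
    val joined = PartitioningCollection(Seq(
      HashPartitioning(Set("A.key1"), 10),
      HashPartitioning(Set("B.key2"), 10)))

    // A later join keyed on B.key2 is already satisfied, so no extra Exchange is needed...
    println(joined.guarantees(HashPartitioning(Set("B.key2"), 10))) // true
    // ...whereas a join keyed on an unrelated column still requires one.
    println(joined.guarantees(HashPartitioning(Set("C.key3"), 10))) // false
  }
}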
--- .../plans/physical/partitioning.scala | 87 ++++++++++++++++--- .../sql/catalyst/DistributionSuite.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../joins/BroadcastHashOuterJoin.scala | 4 +- .../sql/execution/joins/HashOuterJoin.scala | 9 -- .../execution/joins/LeftSemiJoinHash.scala | 6 +- .../execution/joins/ShuffledHashJoin.scala | 7 +- .../joins/ShuffledHashOuterJoin.scala | 10 ++- .../sql/execution/joins/SortMergeJoin.scala | 3 +- .../spark/sql/execution/PlannerSuite.scala | 49 ++++++++++- 10 files changed, 148 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index f4d1dbaf28efe..ec659ce789c27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -60,8 +60,9 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi /** * Represents data where tuples have been ordered according to the `ordering` * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for - * the ordering expressions are contiguous and will never be split across partitions. + * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the + * same value for the ordering expressions are contiguous and will never be split across + * partitions. */ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { require( @@ -86,8 +87,12 @@ sealed trait Partitioning { */ def satisfies(required: Distribution): Boolean - /** Returns the expressions that are used to key the partitioning. */ - def keyExpressions: Seq[Expression] + /** + * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] + * guarantees the same partitioning scheme described by `other`. + */ + // TODO: Add an example once we have the `nullSafe` concept. 
+ def guarantees(other: Partitioning): Boolean } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -96,7 +101,7 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } - override def keyExpressions: Seq[Expression] = Nil + override def guarantees(other: Partitioning): Boolean = false } case object SinglePartition extends Partitioning { @@ -104,7 +109,10 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def keyExpressions: Seq[Expression] = Nil + override def guarantees(other: Partitioning): Boolean = other match { + case SinglePartition => true + case _ => false + } } case object BroadcastPartitioning extends Partitioning { @@ -112,7 +120,10 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def keyExpressions: Seq[Expression] = Nil + override def guarantees(other: Partitioning): Boolean = other match { + case BroadcastPartitioning => true + case _ => false + } } /** @@ -127,7 +138,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = expressions.toSet + lazy val clusteringSet = expressions.toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true @@ -136,7 +147,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def keyExpressions: Seq[Expression] = expressions + override def guarantees(other: Partitioning): Boolean = other match { + case o: HashPartitioning => + this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case _ => false + } } /** @@ -170,5 +185,57 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def keyExpressions: Seq[Expression] = ordering.map(_.child) + override def guarantees(other: Partitioning): Boolean = other match { + case o: RangePartitioning => this == o + case _ => false + } +} + +/** + * A collection of [[Partitioning]]s that can be used to describe the partitioning + * scheme of the output of a physical operator. It is usually used for an operator + * that has multiple children. In this case, a [[Partitioning]] in this collection + * describes how this operator's output is partitioned based on expressions from + * a child. For example, for a Join operator on two tables `A` and `B` + * with a join condition `A.key1 = B.key2`, assuming we use HashPartitioning schema, + * there are two [[Partitioning]]s can be used to describe how the output of + * this Join operator is partitioned, which are `HashPartitioning(A.key1)` and + * `HashPartitioning(B.key2)`. It is also worth noting that `partitionings` + * in this collection do not need to be equivalent, which is useful for + * Outer Join operators. 
+ */ +case class PartitioningCollection(partitionings: Seq[Partitioning]) + extends Expression with Partitioning with Unevaluable { + + require( + partitionings.map(_.numPartitions).distinct.length == 1, + s"PartitioningCollection requires all of its partitionings have the same numPartitions.") + + override def children: Seq[Expression] = partitionings.collect { + case expr: Expression => expr + } + + override def nullable: Boolean = false + + override def dataType: DataType = IntegerType + + override val numPartitions = partitionings.map(_.numPartitions).distinct.head + + /** + * Returns true if any `partitioning` of this collection satisfies the given + * [[Distribution]]. + */ + override def satisfies(required: Distribution): Boolean = + partitionings.exists(_.satisfies(required)) + + /** + * Returns true if any `partitioning` of this collection guarantees + * the given [[Partitioning]]. + */ + override def guarantees(other: Partitioning): Boolean = + partitionings.exists(_.guarantees(other)) + + override def toString: String = { + partitionings.map(_.toString).mkString("(", " or ", ")") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index c046dbf4dc2c9..827f7ce692712 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning is the output partitioning") { + test("HashPartitioning (with nullSafe = true) is the output partitioning") { // Cases which do not need an exchange between two data properties. 
checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6bd57f010a990..05b009d1935bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -209,7 +209,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child: SparkPlan): SparkPlan = { def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { - if (child.outputPartitioning != partitioning) { + if (!child.outputPartitioning.guarantees(partitioning)) { Exchange(partitioning, child) } else { child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 77e7fe71009b7..309716a0efcc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils @@ -57,6 +57,8 @@ case class BroadcastHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + @transient private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 7e671e7914f1a..a323aea4ea2c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -22,7 +22,6 @@ import java.util.{HashMap => JavaHashMap} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -38,14 +37,6 @@ trait HashOuterJoin { val left: SparkPlan val right: SparkPlan - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - override def output: Seq[Attribute] = { joinType match { case LeftOuter => diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 26a664104d6fb..68ccd34d8ed9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -37,7 +37,9 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { - override def requiredChildDistribution: Seq[ClusteredDistribution] = + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 5439e10a60b2a..fc6efe87bceb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -38,9 +38,10 @@ case class ShuffledHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) - override def requiredChildDistribution: Seq[ClusteredDistribution] = + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index d29b593207c4d..eee8ad800f98e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, 
RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -44,6 +44,14 @@ case class ShuffledHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") + } + protected override def doExecute(): RDD[InternalRow] = { val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index bb18b5403f8e8..41be78afd37e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -40,7 +40,8 @@ case class SortMergeJoin( override def output: Seq[Attribute] = left.output ++ right.output - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 845ce669f0b33..18b0e54dc7c53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -23,14 +23,18 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Row, SQLConf, execution} +import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution} -class PlannerSuite extends SparkFunSuite { +class PlannerSuite extends SparkFunSuite with SQLTestUtils { + + override def sqlContext: SQLContext = TestSQLContext + private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = @@ -157,4 +161,45 @@ class PlannerSuite extends SparkFunSuite { val planned = planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) } + + test("PartitioningCollection") { + withTempTable("normal", "small", "tiny") { + testData.registerTempTable("normal") + testData.limit(10).registerTempTable("small") + testData.limit(3).registerTempTable("tiny") + + // Disable broadcast join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + { + val numExchanges = sql( + """ + |SELECT * + |FROM + | normal JOIN small ON (normal.key = 
small.key) + | JOIN tiny ON (small.key = tiny.key) + """.stripMargin + ).queryExecution.executedPlan.collect { + case exchange: Exchange => exchange + }.length + assert(numExchanges === 3) + } + + { + // This second query joins on different keys: + val numExchanges = sql( + """ + |SELECT * + |FROM + | normal JOIN small ON (normal.key = small.key) + | JOIN tiny ON (normal.key = tiny.key) + """.stripMargin + ).queryExecution.executedPlan.collect { + case exchange: Exchange => exchange + }.length + assert(numExchanges === 3) + } + + } + } + } } From 4cdd8ecd66769316e8593da7790b84cd867968cd Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 2 Aug 2015 22:19:27 -0700 Subject: [PATCH 090/340] [SPARK-9536] [SPARK-9537] [SPARK-9538] [ML] [PYSPARK] ml.classification support raw and probability prediction for PySpark Make the following ml.classification class support raw and probability prediction for PySpark: ```scala NaiveBayesModel DecisionTreeClassifierModel LogisticRegressionModel ``` Author: Yanbo Liang Closes #7866 from yanboliang/spark-9536-9537 and squashes the following commits: 2934dab [Yanbo Liang] ml.NaiveBayes, ml.DecisionTreeClassifier and ml.LogisticRegression support probability prediction --- python/pyspark/ml/classification.py | 61 ++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 93ffcd40949b3..b5814f76de000 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -31,7 +31,7 @@ @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol, HasProbabilityCol): + HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): """ Logistic regression. @@ -42,13 +42,18 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() >>> lr = LogisticRegression(maxIter=5, regParam=0.01) >>> model = lr.fit(df) - >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() - >>> model.transform(test0).head().prediction - 0.0 >>> model.weights DenseVector([5.5...]) >>> model.intercept -2.68... 
+ >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() + >>> result = model.transform(test0).head() + >>> result.prediction + 0.0 + >>> result.probability + DenseVector([0.99..., 0.00...]) + >>> result.rawPrediction + DenseVector([8.22..., -8.22...]) >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -70,11 +75,11 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, probabilityCol="probability"): + threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, probabilityCol="probability") + threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction") """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -98,11 +103,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, probabilityCol="probability"): + threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, probabilityCol="probability") + threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction") Sets params for logistic regression. """ kwargs = self.setParams._input_kwargs @@ -187,7 +192,8 @@ class GBTParams(object): @inherit_doc class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - DecisionTreeParams, HasCheckpointInterval): + HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, + HasCheckpointInterval): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for classification. 
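The PySpark changes in this file surface the probability and rawPrediction columns that the underlying Scala spark.ml estimators already compute. For reference, a minimal Scala sketch of the same columns, using hypothetical data and assuming a SQLContext named `sqlContext` is in scope (as in spark-shell):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.Vectors

// Tiny hand-made training set (hypothetical data).
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1)),
  (0.0, Vectors.dense(2.0, 1.0)),
  (1.0, Vectors.dense(0.1, 1.2))
)).toDF("label", "features")

val model = new LogisticRegression()
  .setMaxIter(5)
  .setRegParam(0.01)
  .fit(training)

// The transformed DataFrame carries rawPrediction, probability and prediction columns,
// which is what the PySpark doctests in this file now assert as well.
model.transform(training)
  .select("features", "rawPrediction", "probability", "prediction")
  .show()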
@@ -209,8 +215,13 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> model.depth 1 >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) - >>> model.transform(test0).head().prediction + >>> result = model.transform(test0).head() + >>> result.prediction 0.0 + >>> result.probability + DenseVector([1.0, 0.0]) + >>> result.rawPrediction + DenseVector([1.0, 0.0]) >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 @@ -223,10 +234,12 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") """ @@ -246,11 +259,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") Sets params for the DecisionTreeClassifier. @@ -578,7 +593,8 @@ class GBTClassificationModel(TreeEnsembleModels): @inherit_doc -class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): +class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, + HasRawPredictionCol): """ Naive Bayes Classifiers. 
@@ -595,8 +611,13 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): >>> model.theta DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1) >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF() - >>> model.transform(test0).head().prediction + >>> result = model.transform(test0).head() + >>> result.prediction 1.0 + >>> result.probability + DenseVector([0.42..., 0.57...]) + >>> result.rawPrediction + DenseVector([-1.60..., -1.32...]) >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -610,10 +631,12 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - smoothing=1.0, modelType="multinomial"): + probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, + modelType="multinomial"): """ - __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - smoothing=1.0, modelType="multinomial") + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ + modelType="multinomial") """ super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( @@ -631,10 +654,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - smoothing=1.0, modelType="multinomial"): + probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, + modelType="multinomial"): """ - setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - smoothing=1.0, modelType="multinomial") + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ + modelType="multinomial") Sets params for Naive Bayes. """ kwargs = self.setParams._input_kwargs From 687c8c37150f4c93f8e57d86bb56321a4891286b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 2 Aug 2015 23:32:09 -0700 Subject: [PATCH 091/340] [SPARK-9372] [SQL] Filter nulls in join keys This PR adds an optimization rule, `FilterNullsInJoinKey`, to add `Filter` before join operators to filter out rows having null values for join keys. This optimization is guarded by a new SQL conf, `spark.sql.advancedOptimization`. The code in this PR was authored by yhuai; I'm opening this PR to factor out this change from #7685, a larger pull request which contains two other optimizations. Author: Yin Huai Author: Josh Rosen Closes #7768 from JoshRosen/filter-nulls-in-join-key and squashes the following commits: c02fc3f [Yin Huai] Address Josh's comments. 0a8e096 [Yin Huai] Update comments. ea7d5a6 [Yin Huai] Make sure we do not keep adding filters. be88760 [Yin Huai] Make it clear that FilterNullsInJoinKeySuite.scala is used to test FilterNullsInJoinKey. 8bb39ad [Yin Huai] Fix non-deterministic tests. 303236b [Josh Rosen] Revert changes that are unrelated to null join key filtering 40eeece [Josh Rosen] Merge remote-tracking branch 'origin/master' into filter-nulls-in-join-key c57a954 [Yin Huai] Bug fix. d3d2e64 [Yin Huai] First round of cleanup. f9516b0 [Yin Huai] Style c6667e7 [Yin Huai] Add PartitioningCollection. e616d3b [Yin Huai] wip 7c2d2d8 [Yin Huai] Bug fix and refactoring. 
69bb072 [Yin Huai] Introduce NullSafeHashPartitioning and NullUnsafePartitioning. d5b84c3 [Yin Huai] Do not add unnessary filters. 2201129 [Yin Huai] Filter out rows that will not be joined in equal joins early. --- .../catalyst/expressions/nullFunctions.scala | 48 +++- .../sql/catalyst/optimizer/Optimizer.scala | 64 +++-- .../plans/logical/basicOperators.scala | 32 ++- .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/MathFunctionsSuite.scala | 3 +- .../expressions/NullFunctionsSuite.scala | 49 +++- .../spark/sql/DataFrameNaFunctions.scala | 2 +- .../scala/org/apache/spark/sql/SQLConf.scala | 6 + .../org/apache/spark/sql/SQLContext.scala | 5 +- .../extendedOperatorOptimizations.scala | 160 ++++++++++++ .../optimizer/FilterNullsInJoinKeySuite.scala | 236 ++++++++++++++++++ 11 files changed, 572 insertions(+), 37 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 287718fab7f0d..d58c4756938c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -210,14 +210,58 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } +/** + * A predicate that is evaluated to be true if there are at least `n` null values. + */ +case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate { + override def nullable: Boolean = false + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" + + private[this] val childrenArray = children.toArray + + override def eval(input: InternalRow): Boolean = { + var numNulls = 0 + var i = 0 + while (i < childrenArray.length && numNulls < n) { + val evalC = childrenArray(i).eval(input) + if (evalC == null) { + numNulls += 1 + } + i += 1 + } + numNulls >= n + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val numNulls = ctx.freshName("numNulls") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if ($numNulls < $n) { + ${eval.code} + if (${eval.isNull}) { + $numNulls += 1; + } + } + """ + }.mkString("\n") + s""" + int $numNulls = 0; + $code + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $numNulls >= $n; + """ + } +} /** * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. 
*/ -case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { +case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" + override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 29d706dcb39a7..e4b6294dc7b8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -31,8 +31,14 @@ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] -object DefaultOptimizer extends Optimizer { - val batches = +class DefaultOptimizer extends Optimizer { + + /** + * Override to provide additional rules for the "Operator Optimizations" batch. + */ + val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + + lazy val batches = // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: @@ -41,26 +47,27 @@ object DefaultOptimizer extends Optimizer { RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - SetOperationPushDown, - SamplePushDown, - PushPredicateThroughJoin, - PushPredicateThroughProject, - PushPredicateThroughGenerate, - ColumnPruning, + SetOperationPushDown :: + SamplePushDown :: + PushPredicateThroughJoin :: + PushPredicateThroughProject :: + PushPredicateThroughGenerate :: + ColumnPruning :: // Operator combine - ProjectCollapsing, - CombineFilters, - CombineLimits, + ProjectCollapsing :: + CombineFilters :: + CombineLimits :: // Constant folding - NullPropagation, - OptimizeIn, - ConstantFolding, - LikeSimplification, - BooleanSimplification, - RemovePositive, - SimplifyFilters, - SimplifyCasts, - SimplifyCaseConversionExpressions) :: + NullPropagation :: + OptimizeIn :: + ConstantFolding :: + LikeSimplification :: + BooleanSimplification :: + RemovePositive :: + SimplifyFilters :: + SimplifyCasts :: + SimplifyCaseConversionExpressions :: + extendedOperatorOptimizationRules.toList : _*) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -222,12 +229,18 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = { if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(allReferences.filter(c.outputSet.contains).toSeq, c) + // We need to preserve the nullability of c's output. + // So, we first create a outputMap and if a reference is from the output of + // c, we use that output attribute from c. 
+ val outputMap = AttributeMap(c.output.map(attr => (attr, attr))) + val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq + Project(projectList, c) } else { c } + } } /** @@ -517,6 +530,13 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { */ object CombineFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) => + // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and + // Not(AtLeastNNulls(1, e2)) + // (this is used to make sure there is no null in the result of e1 and e2 and + // they are added by FilterNullsInJoinKey optimziation rule), we can + // just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)). + Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild) case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index aacfc86ab0e49..54b5f49772664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -86,7 +86,37 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + /** + * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children + * have at least one null value and atLeastNNulls.children are all attributes. + */ + private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = { + val expressions = atLeastNNulls.children + val n = atLeastNNulls.n + if (n != 1) { + // AtLeastNNulls is not used to check if atLeastNNulls.children have + // at least one null value. + false + } else { + // AtLeastNNulls is used to check if atLeastNNulls.children have + // at least one null value. We need to make sure all atLeastNNulls.children + // are attributes. + expressions.forall(_.isInstanceOf[Attribute]) + } + } + + override def output: Seq[Attribute] = condition match { + case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) => + // The condition is used to make sure that there is no null value in + // a.children. 
+ val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]]) + child.output.map { + case attr if nonNullableAttributes.contains(attr) => + attr.withNullability(false) + case attr => attr + } + case _ => child.output + } } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a41185b4d8754..3e55151298741 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} trait ExpressionEvalHelper { self: SparkFunSuite => + protected val defaultOptimizer = new DefaultOptimizer + protected def create_row(values: Any*): InternalRow = { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } @@ -186,7 +188,7 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = defaultOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 9fcb548af6bbb..649a5b44dc036 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -149,7 +148,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = defaultOptimizer.execute(plan) checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index ace6c15dc8418..bf197124d8dbc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("AtLeastNNonNulls") { + test("AtLeastNNonNullNans") { val mix = Seq(Literal("x"), Literal.create(null, StringType), Literal.create(null, DoubleType), @@ -96,11 +96,46 @@ class 
NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(Float.MaxValue), Literal(false)) - checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow) + } + + test("AtLeastNNull") { + val mix = Seq(Literal("x"), + Literal.create(null, StringType), + Literal.create(null, DoubleType), + Literal(Double.NaN), + Literal(5f)) + + val nanOnly = Seq(Literal("x"), + Literal(10.0), + Literal(Float.NaN), + Literal(math.log(-2)), + Literal(Double.MaxValue)) + + val nullOnly = Seq(Literal("x"), + Literal.create(null, DoubleType), + Literal.create(null, DecimalType.USER_DEFAULT), + Literal(Float.MaxValue), + Literal(false)) + + checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index a4fd4cf3b330b..ea85f0657a726 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. 
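The renamed predicate keeps the existing `na.drop` semantics: a row survives only if it has at least `minNonNulls` values that are neither null nor NaN. A short usage sketch, with hypothetical data and assuming a `sqlContext` whose implicits are imported:

import sqlContext.implicits._

// Hypothetical data: Option models nullable columns.
val df = Seq(
  (Option(1.0), Option(2.0)),
  (Option(1.0), Option.empty[Double]),
  (Option.empty[Double], Option.empty[Double])
).toDF("x", "y")

// Keep rows with at least 2 non-null, non-NaN values among x and y;
// this builds the AtLeastNNonNullNans(2, Seq(x, y)) predicate shown here.
df.na.drop(2, Seq("x", "y")).show()   // only the first row remains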
- val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) + val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name))) df.filter(Column(predicate)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 6644e85d4a037..387960c4b482b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -413,6 +413,10 @@ private[spark] object SQLConf { "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) + val ADVANCED_SQL_OPTIMIZATION = booleanConf( + "spark.sql.advancedOptimization", + defaultValue = Some(true), isPublic = false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -484,6 +488,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) + private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION) + private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dbb2a09846548..31e2b508d485e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.optimizer.FilterNullsInJoinKey import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -156,7 +157,9 @@ class SQLContext(@transient val sparkContext: SparkContext) } @transient - protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer { + override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil + } @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala new file mode 100644 index 0000000000000..5a4dde5756964 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.optimizer + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * An optimization rule used to insert Filters to filter out rows whose equal join keys + * have at least one null values. For this kind of rows, they will not contribute to + * the join results of equal joins because a null does not equal another null. We can + * filter them out before shuffling join input rows. For example, we have two tables + * + * table1(key String, value Int) + * "str1"|1 + * null |2 + * + * table2(key String, value Int) + * "str1"|3 + * null |4 + * + * For a inner equal join, the result will be + * "str1"|1|"str1"|3 + * + * those two rows having null as the value of key will not contribute to the result. + * So, we can filter them out early. + * + * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false. + * + */ +case class FilterNullsInJoinKey( + sqlContext: SQLContext) + extends Rule[LogicalPlan] { + + /** + * Checks if we need to add a Filter operator. We will add a Filter when + * there is any attribute in `keys` whose corresponding attribute of `keys` + * in `plan.output` is still nullable (`nullable` field is `true`). + */ + private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = { + val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute])) + plan.output.filter(keyAttributeSet.contains).exists(_.nullable) + } + + /** + * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable. + */ + private def addFilterIfNecessary( + keys: Seq[Expression], + child: LogicalPlan): LogicalPlan = { + // We get all attributes from keys. + val attributes = keys.filter(_.isInstanceOf[Attribute]) + + // Then, we create a Filter to make sure these attributes are non-nullable. + val filter = + if (attributes.nonEmpty) { + Filter(Not(AtLeastNNulls(1, attributes)), child) + } else { + child + } + + filter + } + + /** + * We reconstruct the join condition. + */ + private def reconstructJoinCondition( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + otherPredicate: Option[Expression]): Expression = { + // First, we rewrite the equal condition part. When we extract those keys, + // we use splitConjunctivePredicates. So, it is safe to use .reduce(And). + val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map { + case (l, r) => EqualTo(l, r) + }.reduce(And) + + // Then, we add otherPredicate. When we extract those equal condition part, + // we use splitConjunctivePredicates. So, it is safe to use + // And(rewrittenEqualJoinCondition, c). + val rewrittenJoinCondition = otherPredicate + .map(c => And(rewrittenEqualJoinCondition, c)) + .getOrElse(rewrittenEqualJoinCondition) + + rewrittenJoinCondition + } + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!sqlContext.conf.advancedSqlOptimizations) { + plan + } else { + plan transform { + case join: Join => join match { + // For a inner join having equal join condition part, we can add filters + // to both sides of the join operator. 
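Before the per-join-type cases that follow, a self-contained sketch of why rows with null join keys can be dropped before an inner equi-join without changing its result. It uses plain Scala collections with hypothetical data, and `Option` stands in for nullable keys:

// SQL equality is never true when either side is null; the `isDefined` guard models that.
type Row = (Option[String], Int)
def innerJoin(l: Seq[Row], r: Seq[Row]): Seq[(String, Int, Int)] =
  for {
    (lk, lv) <- l
    (rk, rv) <- r
    if lk.isDefined && lk == rk
  } yield (lk.get, lv, rv)

val table1: Seq[Row] = Seq(Some("str1") -> 1, None -> 2)
val table2: Seq[Row] = Seq(Some("str1") -> 3, None -> 4)

val full     = innerJoin(table1, table2)
val filtered = innerJoin(table1.filter(_._1.isDefined), table2.filter(_._1.isDefined))

println(full)             // List((str1,1,3))
println(full == filtered) // true: null-keyed rows never contribute to the result,
                          // so they can be filtered out before the shuffle.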
+ case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val withRightFilter = addFilterIfNecessary(rightKeys, right) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition)) + + // For a left outer join having equal join condition part, we can add a filter + // to the right side of the join operator. + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + if needsFilter(rightKeys, right) => + val withRightFilter = addFilterIfNecessary(rightKeys, right) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition)) + + // For a right outer join having equal join condition part, we can add a filter + // to the left side of the join operator. + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) => + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition)) + + // For a left semi join having equal join condition part, we can add filters + // to both sides of the join operator. + case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val withRightFilter = addFilterIfNecessary(rightKeys, right) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition)) + + case other => other + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala new file mode 100644 index 0000000000000..f98e4acafbf2c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls} +import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.test.TestSQLContext + +/** This is the test suite for FilterNullsInJoinKey optimization rule. */ +class FilterNullsInJoinKeySuite extends PlanTest { + + // We add predicate pushdown rules at here to make sure we do not + // create redundant Filter operators. Also, because the attribute ordering of + // the Project operator added by ColumnPruning may be not deterministic + // (the ordering may depend on the testing environment), + // we first construct the plan with expected Filter operators and then + // run the optimizer to add the the Project for column pruning. + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Operator Optimizations", FixedPoint(100), + FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite. + CombineFilters, + PushPredicateThroughProject, + BooleanSimplification, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + ColumnPruning, + ProjectCollapsing) :: Nil + } + + val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int) + + val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int) + + test("inner join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For an inner join, FilterNullsInJoinKey add filter to both side. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + correctLeft + .join(correctRight, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) + } + + test("make sure we do not keep adding filters") { + val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int) + val joinedPlan = + leftRelation + .join(rightRelation, Inner, Some('a === 'e)) + .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j)) + + val optimized = Optimize.execute(joinedPlan.analyze) + val conditions = optimized.collect { + case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs + } + + // Make sure that we have three Not(AtLeastNNulls(1, exprs)) for those three tables. + assert(conditions.length === 3) + + // Make sure attribtues are indeed a, b, e, i, and j. + assert( + conditions.flatMap(exprs => exprs).toSet === + joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet) + } + + test("inner join (partially optimized)") { + val joinCondition = + ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // We cannot extract attribute from the left join key. 
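    // Both left keys ('a + 2 and 'b + 1) are arithmetic expressions rather than bare
    // Attributes, so addFilterIfNecessary has nothing to filter on the left side;
    // only the right side, whose keys 'e and 'f are plain attributes, receives the
    // null filter below.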
+ val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + leftRelation + .join(correctRight, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) + } + + test("inner join (not optimized)") { + val nonOptimizedJoinConditions = + Some('c - 100 + 'd === 'g + 1 - 'h) :: + Some('d > 'h || 'c === 'g) :: + Some('d + 'g + 'c > 'd - 'h) :: Nil + + nonOptimizedJoinConditions.foreach { joinCondition => + val joinedPlan = + leftRelation + .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition) + .select('a, 'c, 'f, 'd, 'h, 'g) + + val optimized = Optimize.execute(joinedPlan.analyze) + + comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) + } + } + + test("left outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, LeftOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a left outer join, FilterNullsInJoinKey add filter to the right side. + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + leftRelation + .join(correctRight, LeftOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) + } + + test("right outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, RightOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a right outer join, FilterNullsInJoinKey add filter to the left side. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + + val correctAnswer = + correctLeft + .join(rightRelation, RightOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) + } + + test("full outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, FullOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + // FilterNullsInJoinKey does not fire for a full outer join. + val optimized = Optimize.execute(joinedPlan.analyze) + + comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) + } + + test("left semi join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, LeftSemi, Some(joinCondition)) + .select('a, 'd) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a left semi join, FilterNullsInJoinKey add filter to both side. 
+ val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + correctLeft + .join(correctRight, LeftSemi, Some(joinCondition)) + .select('a, 'd) + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) + } +} From 608353c8e8e50461fafff91a2c885dca8af3aaa8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 2 Aug 2015 23:41:16 -0700 Subject: [PATCH 092/340] [SPARK-9404][SPARK-9542][SQL] unsafe array data and map data This PR adds a UnsafeArrayData, current we encode it in this way: first 4 bytes is the # elements then each 4 byte is the start offset of the element, unless it is negative, in which case the element is null. followed by the elements themselves an example: [10, 11, 12, 13, null, 14] will be encoded as: 5, 28, 32, 36, 40, -44, 44, 10, 11, 12, 13, 14 Note that, when we read a UnsafeArrayData from bytes, we can read the first 4 bytes as numElements and take the rest(first 4 bytes skipped) as value region. unsafe map data just use 2 unsafe array data, first 4 bytes is # of elements, second 4 bytes is numBytes of key array, the follows key array data and value array data. Author: Wenchen Fan Closes #7752 from cloud-fan/unsafe-array and squashes the following commits: 3269bd7 [Wenchen Fan] fix a bug 6445289 [Wenchen Fan] add unit tests 49adf26 [Wenchen Fan] add unsafe map 20d1039 [Wenchen Fan] add comments and unsafe converter 821b8db [Wenchen Fan] add unsafe array --- .../catalyst/expressions/UnsafeArrayData.java | 333 ++++++++++++++++++ .../catalyst/expressions/UnsafeMapData.java | 66 ++++ .../catalyst/expressions/UnsafeReaders.java | 48 +++ .../sql/catalyst/expressions/UnsafeRow.java | 34 +- .../expressions/UnsafeRowWriters.java | 71 ++++ .../catalyst/expressions/UnsafeWriters.java | 208 +++++++++++ .../sql/catalyst/expressions/FromUnsafe.scala | 67 ++++ .../sql/catalyst/expressions/Projection.scala | 10 +- .../expressions/codegen/CodeGenerator.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 327 ++++++++++++++++- .../spark/sql/types/ArrayBasedMapData.scala | 15 +- .../apache/spark/sql/types/ArrayData.scala | 14 +- .../spark/sql/types/GenericArrayData.scala | 10 +- .../org/apache/spark/sql/types/MapData.scala | 2 + .../expressions/UnsafeRowConverterSuite.scala | 114 +++++- .../apache/spark/unsafe/types/UTF8String.java | 3 + 16 files changed, 1295 insertions(+), 31 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java new file mode 100644 index 0000000000000..0374846d71674 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import java.math.BigDecimal; +import java.math.BigInteger; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * An Unsafe implementation of Array which is backed by raw memory instead of Java objects. + * + * Each tuple has two parts: [offsets] [values] + * + * In the `offsets` region, we store 4 bytes per element, represents the start address of this + * element in `values` region. We can get the length of this element by subtracting next offset. + * Note that offset can by negative which means this element is null. + * + * In the `values` region, we store the content of elements. As we can get length info, so elements + * can be variable-length. + * + * Note that when we write out this array, we should write out the `numElements` at first 4 bytes, + * then follows content. When we read in an array, we should read first 4 bytes as `numElements` + * and take the rest as content. + * + * Instances of `UnsafeArrayData` act as pointers to row data stored in this format. + */ +// todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData. +public class UnsafeArrayData extends ArrayData { + + private Object baseObject; + private long baseOffset; + + // The number of elements in this array + private int numElements; + + // The size of this array's backing data, in bytes + private int sizeInBytes; + + private int getElementOffset(int ordinal) { + return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + ordinal * 4L); + } + + private int getElementSize(int offset, int ordinal) { + if (ordinal == numElements - 1) { + return sizeInBytes - offset; + } else { + return Math.abs(getElementOffset(ordinal + 1)) - offset; + } + } + + private void assertIndexIsValid(int ordinal) { + assert ordinal >= 0 : "ordinal (" + ordinal + ") should >= 0"; + assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements; + } + + /** + * Construct a new UnsafeArrayData. The resulting UnsafeArrayData won't be usable until + * `pointTo()` has been called, since the value returned by this constructor is equivalent + * to a null pointer. + */ + public UnsafeArrayData() { } + + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } + public int getSizeInBytes() { return sizeInBytes; } + + @Override + public int numElements() { return numElements; } + + /** + * Update this UnsafeArrayData to point to different backing data. 
+ * + * @param baseObject the base object + * @param baseOffset the offset within the base object + * @param sizeInBytes the size of this row's backing data, in bytes + */ + public void pointTo(Object baseObject, long baseOffset, int numElements, int sizeInBytes) { + assert numElements >= 0 : "numElements (" + numElements + ") should >= 0"; + this.numElements = numElements; + this.baseObject = baseObject; + this.baseOffset = baseOffset; + this.sizeInBytes = sizeInBytes; + } + + @Override + public boolean isNullAt(int ordinal) { + assertIndexIsValid(ordinal); + return getElementOffset(ordinal) < 0; + } + + @Override + public Object get(int ordinal, DataType dataType) { + if (isNullAt(ordinal) || dataType instanceof NullType) { + return null; + } else if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType dt = (DecimalType) dataType; + return getDecimal(ordinal, dt.precision(), dt.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof CalendarIntervalType) { + return getInterval(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType) dataType).size()); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else { + throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); + } + } + + @Override + public boolean getBoolean(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return false; + return PlatformDependent.UNSAFE.getBoolean(baseObject, baseOffset + offset); + } + + @Override + public byte getByte(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return 0; + return PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + offset); + } + + @Override + public short getShort(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return 0; + return PlatformDependent.UNSAFE.getShort(baseObject, baseOffset + offset); + } + + @Override + public int getInt(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return 0; + return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + } + + @Override + public long getLong(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return 0; + return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + } + + @Override + public float getFloat(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return 0; + return 
PlatformDependent.UNSAFE.getFloat(baseObject, baseOffset + offset); + } + + @Override + public double getDouble(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return 0; + return PlatformDependent.UNSAFE.getDouble(baseObject, baseOffset + offset); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + + if (precision <= Decimal.MAX_LONG_DIGITS()) { + final long value = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + return Decimal.apply(value, precision, scale); + } else { + final byte[] bytes = getBinary(ordinal); + final BigInteger bigInteger = new BigInteger(bytes); + final BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); + } + } + + @Override + public UTF8String getUTF8String(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int size = getElementSize(offset, ordinal); + return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + } + + @Override + public byte[] getBinary(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int size = getElementSize(offset, ordinal); + final byte[] bytes = new byte[size]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offset, + bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + size); + return bytes; + } + + @Override + public CalendarInterval getInterval(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + final long microseconds = + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + return new CalendarInterval(months, microseconds); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int size = getElementSize(offset, ordinal); + final UnsafeRow row = new UnsafeRow(); + row.pointTo(baseObject, baseOffset + offset, numFields, size); + return row; + } + + @Override + public ArrayData getArray(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int size = getElementSize(offset, ordinal); + return UnsafeReaders.readArray(baseObject, baseOffset + offset, size); + } + + @Override + public MapData getMap(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int size = getElementSize(offset, ordinal); + return UnsafeReaders.readMap(baseObject, baseOffset + offset, size); + } + + @Override + public int hashCode() { + return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UnsafeArrayData) { + UnsafeArrayData o = (UnsafeArrayData) other; + return (sizeInBytes == o.sizeInBytes) && + ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, + sizeInBytes); + } + return false; + } + + public void writeToMemory(Object target, long targetOffset) { + 
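    // writeToMemory copies this array's raw [offsets][values] bytes into the target
    // buffer. Writers emit a 4-byte numElements first and then call this method at
    // targetOffset + 4 (see UnsafeRowWriters.ArrayWriter in this patch), which is the
    // inverse of UnsafeReaders.readArray: read numElements from the first 4 bytes and
    // point an UnsafeArrayData at the remaining bytes. A null element is marked by a
    // negative value in the offsets region, so its value bytes are never read back.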
PlatformDependent.copyMemory( + baseObject, + baseOffset, + target, + targetOffset, + sizeInBytes + ); + } + + @Override + public UnsafeArrayData copy() { + UnsafeArrayData arrayCopy = new UnsafeArrayData(); + final byte[] arrayDataCopy = new byte[sizeInBytes]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + arrayDataCopy, + PlatformDependent.BYTE_ARRAY_OFFSET, + sizeInBytes + ); + arrayCopy.pointTo(arrayDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); + return arrayCopy; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java new file mode 100644 index 0000000000000..46216054ab38b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.sql.types.ArrayData; +import org.apache.spark.sql.types.MapData; + +/** + * An Unsafe implementation of Map which is backed by raw memory instead of Java objects. + * + * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData. + */ +public class UnsafeMapData extends MapData { + + public final UnsafeArrayData keys; + public final UnsafeArrayData values; + // The number of elements in this array + private int numElements; + // The size of this array's backing data, in bytes + private int sizeInBytes; + + public int getSizeInBytes() { return sizeInBytes; } + + public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) { + assert keys.numElements() == values.numElements(); + this.sizeInBytes = keys.getSizeInBytes() + values.getSizeInBytes(); + this.numElements = keys.numElements(); + this.keys = keys; + this.values = values; + } + + @Override + public int numElements() { + return numElements; + } + + @Override + public ArrayData keyArray() { + return keys; + } + + @Override + public ArrayData valueArray() { + return values; + } + + @Override + public UnsafeMapData copy() { + return new UnsafeMapData(keys.copy(), values.copy()); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java new file mode 100644 index 0000000000000..b521b703389d3 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.unsafe.PlatformDependent; + +public class UnsafeReaders { + + public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) { + // Read the number of elements from first 4 bytes. + final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); + final UnsafeArrayData array = new UnsafeArrayData(); + // Skip the first 4 bytes. + array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4); + return array; + } + + public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) { + // Read the number of elements from first 4 bytes. + final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); + // Read the numBytes of key array in second 4 bytes. + final int keyArraySize = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + 4); + final int valueArraySize = numBytes - 8 - keyArraySize; + + final UnsafeArrayData keyArray = new UnsafeArrayData(); + keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize); + + final UnsafeArrayData valueArray = new UnsafeArrayData(); + valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, valueArraySize); + + return new UnsafeMapData(keyArray, valueArray); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index b4fc0b7b705ec..c5d42d73a43a4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -291,6 +291,10 @@ public Object get(int ordinal, DataType dataType) { return getInterval(ordinal); } else if (dataType instanceof StructType) { return getStruct(ordinal, ((StructType) dataType).size()); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof MapType) { + return getMap(ordinal); } else { throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } @@ -346,7 +350,6 @@ public double getDouble(int ordinal) { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - assertIndexIsValid(ordinal); if (isNullAt(ordinal)) { return null; } @@ -362,7 +365,6 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { @Override public UTF8String getUTF8String(int ordinal) { - assertIndexIsValid(ordinal); if (isNullAt(ordinal)) return null; final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); @@ -372,7 +374,6 @@ public UTF8String getUTF8String(int ordinal) { @Override public byte[] getBinary(int ordinal) { - assertIndexIsValid(ordinal); if (isNullAt(ordinal)) { return null; } else { @@ -410,7 +411,6 @@ public UnsafeRow getStruct(int ordinal, int 
numFields) { if (isNullAt(ordinal)) { return null; } else { - assertIndexIsValid(ordinal); final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); @@ -420,11 +420,33 @@ public UnsafeRow getStruct(int ordinal, int numFields) { } } + @Override + public ArrayData getArray(int ordinal) { + if (isNullAt(ordinal)) { + return null; + } else { + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + return UnsafeReaders.readArray(baseObject, baseOffset + offset, size); + } + } + + @Override + public MapData getMap(int ordinal) { + if (isNullAt(ordinal)) { + return null; + } else { + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + return UnsafeReaders.readMap(baseObject, baseOffset + offset, size); + } + } + /** * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal * byte array rather than referencing data stored in a data page. - *

- * This method is only supported on UnsafeRows that do not use ObjectPools. */ @Override public UnsafeRow copy() { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index f43a285cd6cad..31928731545da 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -19,6 +19,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.MapData; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; @@ -185,4 +186,74 @@ public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInter return 16; } } + + public static class ArrayWriter { + + public static int getSize(UnsafeArrayData input) { + // we need extra 4 bytes the store the number of elements in this array. + return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.getSizeInBytes() + 4); + } + + public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayData input) { + final int numBytes = input.getSizeInBytes() + 4; + final long offset = target.getBaseOffset() + cursor; + + // write the number of elements into first 4 bytes. + PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); + + // zero-out the padding bytes + if ((numBytes & 0x07) > 0) { + PlatformDependent.UNSAFE.putLong( + target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + } + + // Write the bytes to the variable length portion. + input.writeToMemory(target.getBaseObject(), offset + 4); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + } + + public static class MapWriter { + + public static int getSize(UnsafeMapData input) { + // we need extra 8 bytes to store number of elements and numBytes of key array. + final int sizeInBytes = 4 + 4 + input.getSizeInBytes(); + return ByteArrayMethods.roundNumberOfBytesToNearestWord(sizeInBytes); + } + + public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) { + final long offset = target.getBaseOffset() + cursor; + final UnsafeArrayData keyArray = input.keys; + final UnsafeArrayData valueArray = input.values; + final int keysNumBytes = keyArray.getSizeInBytes(); + final int valuesNumBytes = valueArray.getSizeInBytes(); + final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; + + // write the number of elements into first 4 bytes. + PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); + // write the numBytes of key array into second 4 bytes. + PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset + 4, keysNumBytes); + + // zero-out the padding bytes + if ((numBytes & 0x07) > 0) { + PlatformDependent.UNSAFE.putLong( + target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + } + + // Write the bytes of key array to the variable length portion. + keyArray.writeToMemory(target.getBaseObject(), offset + 8); + + // Write the bytes of value array to the variable length portion. + valueArray.writeToMemory(target.getBaseObject(), offset + 8 + keysNumBytes); + + // Set the fixed length portion. 
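      // The fixed-length slot packs the field's starting cursor in the high 32 bits
      // and its total byte size in the low 32 bits; UnsafeRow.getMap/getArray unpack
      // it as
      //   offset = (int) (offsetAndSize >> 32);
      //   size   = (int) (offsetAndSize & ((1L << 32) - 1));
      // For example, a map written at cursor 64 and spanning 24 bytes is stored as
      // (64L << 32) | 24L.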
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java new file mode 100644 index 0000000000000..0e8e405d055de --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A set of helper methods to write data into the variable length portion. + */ +public class UnsafeWriters { + public static void writeToMemory( + Object inputObject, + long inputOffset, + Object targetObject, + long targetOffset, + int numBytes) { + + // zero-out the padding bytes +// if ((numBytes & 0x07) > 0) { +// PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); +// } + + // Write the UnsafeData to the target memory. + PlatformDependent.copyMemory( + inputObject, + inputOffset, + targetObject, + targetOffset, + numBytes + ); + } + + public static int getRoundedSize(int size) { + //return ByteArrayMethods.roundNumberOfBytesToNearestWord(size); + // todo: do word alignment + return size; + } + + /** Writer for Decimal with precision larger than 18. */ + public static class DecimalWriter { + + public static int getSize(Decimal input) { + return 16; + } + + public static int write(Object targetObject, long targetOffset, Decimal input) { + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + final int numBytes = bytes.length; + assert(numBytes <= 16); + + // zero-out the bytes + PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, 0L); + PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, 0L); + + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + targetObject, + targetOffset, + numBytes); + + return 16; + } + } + + /** Writer for UTF8String. */ + public static class UTF8StringWriter { + + public static int getSize(UTF8String input) { + return getRoundedSize(input.numBytes()); + } + + public static int write(Object targetObject, long targetOffset, UTF8String input) { + final int numBytes = input.numBytes(); + + // Write the bytes to the variable length portion. 
+ writeToMemory(input.getBaseObject(), input.getBaseOffset(), + targetObject, targetOffset, numBytes); + + return getRoundedSize(numBytes); + } + } + + /** Writer for binary (byte array) type. */ + public static class BinaryWriter { + + public static int getSize(byte[] input) { + return getRoundedSize(input.length); + } + + public static int write(Object targetObject, long targetOffset, byte[] input) { + final int numBytes = input.length; + + // Write the bytes to the variable length portion. + writeToMemory(input, PlatformDependent.BYTE_ARRAY_OFFSET, + targetObject, targetOffset, numBytes); + + return getRoundedSize(numBytes); + } + } + + /** Writer for UnsafeRow. */ + public static class StructWriter { + + public static int getSize(UnsafeRow input) { + return getRoundedSize(input.getSizeInBytes()); + } + + public static int write(Object targetObject, long targetOffset, UnsafeRow input) { + final int numBytes = input.getSizeInBytes(); + + // Write the bytes to the variable length portion. + writeToMemory(input.getBaseObject(), input.getBaseOffset(), + targetObject, targetOffset, numBytes); + + return getRoundedSize(numBytes); + } + } + + /** Writer for interval type. */ + public static class IntervalWriter { + + public static int getSize(UnsafeRow input) { + return 16; + } + + public static int write(Object targetObject, long targetOffset, CalendarInterval input) { + + // Write the months and microseconds fields of Interval to the variable length portion. + PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, input.months); + PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, input.microseconds); + + return 16; + } + } + + /** Writer for UnsafeArrayData. */ + public static class ArrayWriter { + + public static int getSize(UnsafeArrayData input) { + // we need extra 4 bytes the store the number of elements in this array. + return getRoundedSize(input.getSizeInBytes() + 4); + } + + public static int write(Object targetObject, long targetOffset, UnsafeArrayData input) { + final int numBytes = input.getSizeInBytes(); + + // write the number of elements into first 4 bytes. + PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); + + // Write the bytes to the variable length portion. + writeToMemory(input.getBaseObject(), input.getBaseOffset(), + targetObject, targetOffset + 4, numBytes); + + return getRoundedSize(numBytes + 4); + } + } + + public static class MapWriter { + + public static int getSize(UnsafeMapData input) { + // we need extra 8 bytes to store number of elements and numBytes of key array. + return getRoundedSize(4 + 4 + input.getSizeInBytes()); + } + + public static int write(Object targetObject, long targetOffset, UnsafeMapData input) { + final UnsafeArrayData keyArray = input.keys; + final UnsafeArrayData valueArray = input.values; + final int keysNumBytes = keyArray.getSizeInBytes(); + final int valuesNumBytes = valueArray.getSizeInBytes(); + final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; + + // write the number of elements into first 4 bytes. + PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); + // write the numBytes of key array into second 4 bytes. + PlatformDependent.UNSAFE.putInt(targetObject, targetOffset + 4, keysNumBytes); + + // Write the bytes of key array to the variable length portion. + writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(), + targetObject, targetOffset + 8, keysNumBytes); + + // Write the bytes of value array to the variable length portion. 
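      // After this step the map occupies
      //   [numElements][numBytes of key array][key array bytes][value array bytes]
      // where each array's bytes are its own [offsets][values] regions (the shared
      // numElements is not repeated inside them). UnsafeReaders.readMap reverses this
      // layout exactly: numElements at offset 0, the key array size at offset 4, and
      // the two arrays at offsets 8 and 8 + keysNumBytes.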
+ writeToMemory(valueArray.getBaseObject(), valueArray.getBaseOffset(), + targetObject, targetOffset + 8 + keysNumBytes, valuesNumBytes); + + return getRoundedSize(numBytes); + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala new file mode 100644 index 0000000000000..3caf0fb3410c4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types._ + +case class FromUnsafe(child: Expression) extends UnaryExpression + with ExpectsInputTypes with CodegenFallback { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(ArrayType, StructType, MapType)) + + override def dataType: DataType = child.dataType + + private def convert(value: Any, dt: DataType): Any = dt match { + case StructType(fields) => + val row = value.asInstanceOf[UnsafeRow] + val result = new Array[Any](fields.length) + fields.map(_.dataType).zipWithIndex.foreach { case (dt, i) => + if (!row.isNullAt(i)) { + result(i) = convert(row.get(i, dt), dt) + } + } + new GenericInternalRow(result) + + case ArrayType(elementType, _) => + val array = value.asInstanceOf[UnsafeArrayData] + val length = array.numElements() + val result = new Array[Any](length) + var i = 0 + while (i < length) { + if (!array.isNullAt(i)) { + result(i) = convert(array.get(i, elementType), elementType) + } + i += 1 + } + new GenericArrayData(result) + + case MapType(kt, vt, _) => + val map = value.asInstanceOf[UnsafeMapData] + val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData] + val safeValueArray = convert(map.values, ArrayType(vt)).asInstanceOf[GenericArrayData] + new ArrayBasedMapData(safeKeyArray, safeValueArray) + + case _ => value + } + + override def nullSafeEval(input: Any): Any = { + convert(input, dataType) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 83129dc12dff6..79649741025a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -151,7 +151,15 @@ object FromUnsafeProjection { * Returns an UnsafeProjection for given Array of DataTypes. 
*/ def apply(fields: Seq[DataType]): Projection = { - create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))) + create(fields.zipWithIndex.map(x => { + val b = new BoundReference(x._2, x._1, true) + // todo: this is quite slow, maybe remove this whole projection after remove generic getter of + // InternalRow? + b.dataType match { + case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b) + case _ => b + } + })) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 03ec4b4b4ec55..7b41c9a3f3b8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -336,7 +336,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[Decimal].getName, classOf[CalendarInterval].getName, classOf[ArrayData].getName, - classOf[MapData].getName + classOf[UnsafeArrayData].getName, + classOf[MapData].getName, + classOf[UnsafeMapData].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 934ec3f75c63f..fc3ecf5451426 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent /** * Generates a [[Projection]] that returns an [[UnsafeRow]]. @@ -37,14 +38,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName + private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName + private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName + + private val PlatformDependent = classOf[PlatformDependent].getName /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { - case t: AtomicType if !t.isInstanceOf[DecimalType] => true + case t: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case NullType => true - case t: DecimalType => true + case t: ArrayType if canSupport(t.elementType) => true + case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true case _ => false } @@ -59,6 +65,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s" + (${ev.isNull} ? 0 : 16)" case _: StructType => s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + case _: ArrayType => + s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))" + case _: MapType => + s" + (${ev.isNull} ? 
0 : $MapWriter.getSize(${ev.primitive}))" case _ => "" } @@ -95,8 +105,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" case CalendarIntervalType => s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" - case t: StructType => + case _: StructType => s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case _: ArrayType => + s"$cursor += $ArrayWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case _: MapType => + s"$cursor += $MapWriter.write($primitive, $index, $cursor, ${ev.primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") @@ -148,7 +162,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $ret.pointTo( $buffer, - org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + $PlatformDependent.BYTE_ARRAY_OFFSET, ${expressions.size}, $numBytes); int $cursor = $fixedSize; @@ -237,7 +251,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | | $primitive.pointTo( | $buffer, - | org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + | $PlatformDependent.BYTE_ARRAY_OFFSET, | ${exprs.size}, | $numBytes); | int $cursor = $fixedSize; @@ -250,6 +264,303 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro GeneratedExpressionCode(code, isNull, primitive) } + /** + * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. + * + * @param ctx code generation context + * @param inputs could be the codes for expressions or input struct fields. + * @param inputTypes types of the inputs + */ + private def createCodeForStruct2( + ctx: CodeGenContext, + inputs: Seq[GeneratedExpressionCode], + inputTypes: Seq[DataType]): GeneratedExpressionCode = { + + val output = ctx.freshName("convertedStruct") + ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val numBytes = ctx.freshName("numBytes") + val cursor = ctx.freshName("cursor") + + val convertedFields = inputTypes.zip(inputs).map { case (dt, input) => + createConvertCode(ctx, input, dt) + } + + val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) + val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) => + genAdditionalSize(dt, ev) + }.mkString("") + + val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) => + val update = genFieldWriter(ctx, dt, ev, output, i, cursor) + s""" + if (${ev.isNull}) { + $output.setNullAt($i); + } else { + $update; + } + """ + }.mkString("\n") + + val code = s""" + ${convertedFields.map(_.code).mkString("\n")} + + final int $numBytes = $fixedSize $additionalSize; + if ($numBytes > $buffer.length) { + $buffer = new byte[$numBytes]; + } + + $output.pointTo( + $buffer, + $PlatformDependent.BYTE_ARRAY_OFFSET, + ${inputTypes.length}, + $numBytes); + + int $cursor = $fixedSize; + + $fieldWriters + """ + GeneratedExpressionCode(code, "false", output) + } + + private def getWriter(dt: DataType) = dt match { + case StringType => classOf[UnsafeWriters.UTF8StringWriter].getName + case BinaryType => classOf[UnsafeWriters.BinaryWriter].getName + case CalendarIntervalType => classOf[UnsafeWriters.IntervalWriter].getName + case _: 
StructType => classOf[UnsafeWriters.StructWriter].getName + case _: ArrayType => classOf[UnsafeWriters.ArrayWriter].getName + case _: MapType => classOf[UnsafeWriters.MapWriter].getName + case _: DecimalType => classOf[UnsafeWriters.DecimalWriter].getName + } + + private def createCodeForArray( + ctx: CodeGenContext, + input: GeneratedExpressionCode, + elementType: DataType): GeneratedExpressionCode = { + val output = ctx.freshName("convertedArray") + ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val outputIsNull = ctx.freshName("isNull") + val tmp = ctx.freshName("tmp") + val numElements = ctx.freshName("numElements") + val fixedSize = ctx.freshName("fixedSize") + val numBytes = ctx.freshName("numBytes") + val elements = ctx.freshName("elements") + val cursor = ctx.freshName("cursor") + val index = ctx.freshName("index") + + val element = GeneratedExpressionCode( + code = "", + isNull = s"$tmp.isNullAt($index)", + primitive = s"${ctx.getValue(tmp, elementType, index)}" + ) + val convertedElement: GeneratedExpressionCode = createConvertCode(ctx, element, elementType) + + // go through the input array to calculate how many bytes we need. + val calculateNumBytes = elementType match { + case _ if (ctx.isPrimitiveType(elementType)) => + // Should we do word align? + val elementSize = elementType.defaultSize + s""" + $numBytes += $elementSize * $numElements; + """ + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => + s""" + $numBytes += 8 * $numElements; + """ + case _ => + val writer = getWriter(elementType) + val elementSize = s"$writer.getSize($elements[$index])" + val unsafeType = elementType match { + case _: StructType => "UnsafeRow" + case _: ArrayType => "UnsafeArrayData" + case _: MapType => "UnsafeMapData" + case _ => ctx.javaType(elementType) + } + val copy = elementType match { + // We reuse the buffer during conversion, need copy it before process next element. + case _: StructType | _: ArrayType | _: MapType => ".copy()" + case _ => "" + } + + s""" + final $unsafeType[] $elements = new $unsafeType[$numElements]; + for (int $index = 0; $index < $numElements; $index++) { + ${convertedElement.code} + if (!${convertedElement.isNull}) { + $elements[$index] = ${convertedElement.primitive}$copy; + $numBytes += $elementSize; + } + } + """ + } + + val writeElement = elementType match { + case _ if (ctx.isPrimitiveType(elementType)) => + // Should we do word align? 
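        // For primitive element types the size is known statically, so the generated
        // code below writes each value straight at the cursor (for example putInt for
        // IntegerType, advancing the cursor by 4 bytes); other element types go
        // through the per-element size/copy pass built in calculateNumBytes above.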
+ val elementSize = elementType.defaultSize + s""" + $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( + $buffer, + $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + ${convertedElement.primitive}); + $cursor += $elementSize; + """ + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => + s""" + $PlatformDependent.UNSAFE.putLong( + $buffer, + $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + ${convertedElement.primitive}.toUnscaledLong()); + $cursor += 8; + """ + case _ => + val writer = getWriter(elementType) + s""" + $cursor += $writer.write( + $buffer, + $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + $elements[$index]); + """ + } + + val checkNull = elementType match { + case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}" + case t: DecimalType => s"$elements[$index] == null" + + s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})" + case _ => s"$elements[$index] == null" + } + + val code = s""" + ${input.code} + final boolean $outputIsNull = ${input.isNull}; + if (!$outputIsNull) { + final ArrayData $tmp = ${input.primitive}; + if ($tmp instanceof UnsafeArrayData) { + $output = (UnsafeArrayData) $tmp; + } else { + final int $numElements = $tmp.numElements(); + final int $fixedSize = 4 * $numElements; + int $numBytes = $fixedSize; + + $calculateNumBytes + + if ($numBytes > $buffer.length) { + $buffer = new byte[$numBytes]; + } + + int $cursor = $fixedSize; + for (int $index = 0; $index < $numElements; $index++) { + if ($checkNull) { + // If element is null, write the negative value address into offset region. + $PlatformDependent.UNSAFE.putInt( + $buffer, + $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, + -$cursor); + } else { + $PlatformDependent.UNSAFE.putInt( + $buffer, + $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, + $cursor); + + $writeElement + } + } + + $output.pointTo( + $buffer, + $PlatformDependent.BYTE_ARRAY_OFFSET, + $numElements, + $numBytes); + } + } + """ + GeneratedExpressionCode(code, outputIsNull, output) + } + + private def createCodeForMap( + ctx: CodeGenContext, + input: GeneratedExpressionCode, + keyType: DataType, + valueType: DataType): GeneratedExpressionCode = { + val output = ctx.freshName("convertedMap") + val outputIsNull = ctx.freshName("isNull") + val tmp = ctx.freshName("tmp") + + val keyArray = GeneratedExpressionCode( + code = "", + isNull = "false", + primitive = s"$tmp.keyArray()" + ) + val valueArray = GeneratedExpressionCode( + code = "", + isNull = "false", + primitive = s"$tmp.valueArray()" + ) + val convertedKeys: GeneratedExpressionCode = createCodeForArray(ctx, keyArray, keyType) + val convertedValues: GeneratedExpressionCode = createCodeForArray(ctx, valueArray, valueType) + + val code = s""" + ${input.code} + final boolean $outputIsNull = ${input.isNull}; + UnsafeMapData $output = null; + if (!$outputIsNull) { + final MapData $tmp = ${input.primitive}; + if ($tmp instanceof UnsafeMapData) { + $output = (UnsafeMapData) $tmp; + } else { + ${convertedKeys.code} + ${convertedValues.code} + $output = new UnsafeMapData(${convertedKeys.primitive}, ${convertedValues.primitive}); + } + } + """ + GeneratedExpressionCode(code, outputIsNull, output) + } + + /** + * Generates the java code to convert a data to its unsafe version. 
+ */ + private def createConvertCode( + ctx: CodeGenContext, + input: GeneratedExpressionCode, + dataType: DataType): GeneratedExpressionCode = dataType match { + case t: StructType => + val output = ctx.freshName("convertedStruct") + val outputIsNull = ctx.freshName("isNull") + val tmp = ctx.freshName("tmp") + val fieldTypes = t.fields.map(_.dataType) + val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => + val getFieldCode = ctx.getValue(tmp, dt, i.toString) + val fieldIsNull = s"$tmp.isNullAt($i)" + GeneratedExpressionCode("", fieldIsNull, getFieldCode) + } + val converter = createCodeForStruct2(ctx, fieldEvals, fieldTypes) + val code = s""" + ${input.code} + UnsafeRow $output = null; + final boolean $outputIsNull = ${input.isNull}; + if (!$outputIsNull) { + final InternalRow $tmp = ${input.primitive}; + if ($tmp instanceof UnsafeRow) { + $output = (UnsafeRow) $tmp; + } else { + ${converter.code} + $output = ${converter.primitive}; + } + } + """ + GeneratedExpressionCode(code, outputIsNull, output) + + case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) + + case MapType(kt, vt, _) => createCodeForMap(ctx, input, kt, vt) + + case _ => input + } + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -259,10 +570,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro protected def create(expressions: Seq[Expression]): UnsafeProjection = { val ctx = newCodeGenContext() - val isNull = ctx.freshName("retIsNull") - val primitive = ctx.freshName("retValue") - val eval = GeneratedExpressionCode("", isNull, primitive) - eval.code = createCode(ctx, eval, expressions) + val exprEvals = expressions.map(e => e.gen(ctx)) + val eval = createCodeForStruct2(ctx, exprEvals, expressions.map(_.dataType)) val code = s""" public Object generate($exprType[] exprs) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala index db4876355daec..f6fa021adee95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala @@ -22,6 +22,9 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte override def numElements(): Int = keyArray.numElements() + override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy()) + + // We need to check equality of map type in tests. 
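  // Comparing through ArrayBasedMapData.toScalaMap makes equality depend on the
  // key/value pairs rather than on the identity or ordering of the backing arrays;
  // as an illustrative example,
  //   ArrayBasedMapData(Array[Any](1, 2), Array[Any]("a", "b"))
  //   ArrayBasedMapData(Array[Any](2, 1), Array[Any]("b", "a"))
  // both reduce to Map(1 -> "a", 2 -> "b") and therefore compare equal (and hash
  // identically).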
override def equals(o: Any): Boolean = { if (!o.isInstanceOf[ArrayBasedMapData]) { return false @@ -32,15 +35,15 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte return false } - this.keyArray == other.keyArray && this.valueArray == other.valueArray + ArrayBasedMapData.toScalaMap(this) == ArrayBasedMapData.toScalaMap(other) } override def hashCode: Int = { - keyArray.hashCode() * 37 + valueArray.hashCode() + ArrayBasedMapData.toScalaMap(this).hashCode() } override def toString(): String = { - s"keys: $keyArray\nvalues: $valueArray" + s"keys: $keyArray, values: $valueArray" } } @@ -48,4 +51,10 @@ object ArrayBasedMapData { def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = { new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) } + + def toScalaMap(map: ArrayBasedMapData): Map[Any, Any] = { + val keys = map.keyArray.asInstanceOf[GenericArrayData].array + val values = map.valueArray.asInstanceOf[GenericArrayData].array + keys.zip(values).toMap + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala index c99fc233255e5..642c56f12ded1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.types +import scala.reflect.ClassTag + import org.apache.spark.sql.catalyst.expressions.SpecializedGetters abstract class ArrayData extends SpecializedGetters with Serializable { def numElements(): Int + def copy(): ArrayData + def toBooleanArray(): Array[Boolean] = { val size = numElements() val values = new Array[Boolean](size) @@ -99,19 +103,19 @@ abstract class ArrayData extends SpecializedGetters with Serializable { values } - def toArray[T](elementType: DataType): Array[T] = { + def toArray[T: ClassTag](elementType: DataType): Array[T] = { val size = numElements() - val values = new Array[Any](size) + val values = new Array[T](size) var i = 0 while (i < size) { if (isNullAt(i)) { - values(i) = null + values(i) = null.asInstanceOf[T] } else { - values(i) = get(i, elementType) + values(i) = get(i, elementType).asInstanceOf[T] } i += 1 } - values.asInstanceOf[Array[T]] + values } // todo: specialize this. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala index b3e75f8bad502..b314acdfe3644 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -17,13 +17,19 @@ package org.apache.spark.sql.types +import scala.reflect.ClassTag + import org.apache.spark.sql.catalyst.expressions.GenericSpecializedGetters -class GenericArrayData(array: Array[Any]) extends ArrayData with GenericSpecializedGetters { +class GenericArrayData(private[sql] val array: Array[Any]) + extends ArrayData with GenericSpecializedGetters { override def genericGet(ordinal: Int): Any = array(ordinal) - override def toArray[T](elementType: DataType): Array[T] = array.asInstanceOf[Array[T]] + override def copy(): ArrayData = new GenericArrayData(array.clone()) + + // todo: Array is invariant in scala, maybe use toSeq instead? 
+ override def toArray[T: ClassTag](elementType: DataType): Array[T] = array.map(_.asInstanceOf[T]) override def numElements(): Int = array.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala index 5514c3cd8546a..f50969f0f0b79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala @@ -25,6 +25,8 @@ abstract class MapData extends Serializable { def valueArray(): ArrayData + def copy(): MapData + def foreach(keyType: DataType, valueType: DataType, f: (Any, Any) => Unit): Unit = { val length = numElements() val keys = keyArray() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 44f845620a109..59491c5ba160e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -31,6 +31,8 @@ import org.apache.spark.unsafe.types.UTF8String class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { + private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) + test("basic conversion with only primitive types") { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) val converter = UnsafeProjection.create(fieldTypes) @@ -73,8 +75,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val unsafeRow: UnsafeRow = converter.apply(row) assert(unsafeRow.getSizeInBytes === 8 + (8 * 3) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) + roundedSize("Hello".getBytes.length) + + roundedSize("World".getBytes.length)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") @@ -92,8 +94,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) val unsafeRow: UnsafeRow = converter.apply(row) - assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) + assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) + roundedSize("Hello".getBytes.length)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") @@ -172,6 +173,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r } + // todo: we reuse the UnsafeRow in projection, so these tests are meaningless. 
val setToNullAfterCreation = converter.apply(rowWithNoNullColumns) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) @@ -235,4 +237,108 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val converter = UnsafeProjection.create(fieldTypes) assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } + + test("basic conversion with array type") { + val fieldTypes: Array[DataType] = Array( + ArrayType(LongType), + ArrayType(ArrayType(LongType)) + ) + val converter = UnsafeProjection.create(fieldTypes) + + val array1 = new GenericArrayData(Array[Any](1L, 2L)) + val array2 = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](3L, 4L)))) + val row = new GenericMutableRow(fieldTypes.length) + row.update(0, array1) + row.update(1, array2) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.numFields() == 2) + + val unsafeArray1 = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData] + assert(unsafeArray1.getSizeInBytes == 4 * 2 + 8 * 2) + assert(unsafeArray1.numElements() == 2) + assert(unsafeArray1.getLong(0) == 1L) + assert(unsafeArray1.getLong(1) == 2L) + + val unsafeArray2 = unsafeRow.getArray(1).asInstanceOf[UnsafeArrayData] + assert(unsafeArray2.numElements() == 1) + + val nestedArray = unsafeArray2.getArray(0).asInstanceOf[UnsafeArrayData] + assert(nestedArray.getSizeInBytes == 4 * 2 + 8 * 2) + assert(nestedArray.numElements() == 2) + assert(nestedArray.getLong(0) == 3L) + assert(nestedArray.getLong(1) == 4L) + + assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes) + + val array1Size = roundedSize(4 + unsafeArray1.getSizeInBytes) + val array2Size = roundedSize(4 + unsafeArray2.getSizeInBytes) + assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) + } + + test("basic conversion with map type") { + def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) + + def testIntLongMap(map: UnsafeMapData, keys: Array[Int], values: Array[Long]): Unit = { + val numElements = keys.length + assert(map.numElements() == numElements) + + val keyArray = map.keys + assert(keyArray.getSizeInBytes == 4 * numElements + 4 * numElements) + assert(keyArray.numElements() == numElements) + keys.zipWithIndex.foreach { case (key, i) => + assert(keyArray.getInt(i) == key) + } + + val valueArray = map.values + assert(valueArray.getSizeInBytes == 4 * numElements + 8 * numElements) + assert(valueArray.numElements() == numElements) + values.zipWithIndex.foreach { case (value, i) => + assert(valueArray.getLong(i) == value) + } + + assert(map.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes) + } + + val fieldTypes: Array[DataType] = Array( + MapType(IntegerType, LongType), + MapType(IntegerType, MapType(IntegerType, LongType)) + ) + val converter = UnsafeProjection.create(fieldTypes) + + val map1 = new ArrayBasedMapData(createArray(1, 2), createArray(3L, 4L)) + + val innerMap = new ArrayBasedMapData(createArray(5, 6), createArray(7L, 8L)) + val map2 = new ArrayBasedMapData(createArray(9), createArray(innerMap)) + + val row = new GenericMutableRow(fieldTypes.length) + row.update(0, map1) + row.update(1, map2) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.numFields() == 2) + + val unsafeMap1 = unsafeRow.getMap(0).asInstanceOf[UnsafeMapData] + testIntLongMap(unsafeMap1, Array(1, 2), Array(3L, 4L)) + + val unsafeMap2 = 
unsafeRow.getMap(1).asInstanceOf[UnsafeMapData] + assert(unsafeMap2.numElements() == 1) + + val keyArray = unsafeMap2.keys + assert(keyArray.getSizeInBytes == 4 + 4) + assert(keyArray.numElements() == 1) + assert(keyArray.getInt(0) == 9) + + val valueArray = unsafeMap2.values + assert(valueArray.numElements() == 1) + val nestedMap = valueArray.getMap(0).asInstanceOf[UnsafeMapData] + testIntLongMap(nestedMap, Array(5, 6), Array(7L, 8L)) + assert(valueArray.getSizeInBytes == 4 + 8 + nestedMap.getSizeInBytes) + + assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes) + + val map1Size = roundedSize(8 + unsafeMap1.getSizeInBytes) + val map2Size = roundedSize(8 + unsafeMap2.getSizeInBytes) + assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 916825d007cc8..f6c9b87778f8f 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -43,6 +43,9 @@ public final class UTF8String implements Comparable, Serializable { private final long offset; private final int numBytes; + public Object getBaseObject() { return base; } + public long getBaseOffset() { return offset; } + private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, From 98d6d9c7a996f5456eb2653bb96985a1a05f4ce1 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Mon, 3 Aug 2015 00:15:24 -0700 Subject: [PATCH 093/340] [SPARK-9549][SQL] fix bugs in expressions JIRA: https://issues.apache.org/jira/browse/SPARK-9549 This PR fix the following bugs: 1. `UnaryMinus`'s codegen version would fail to compile when the input is `Long.MinValue` 2. `BinaryComparison` would fail to compile in codegen mode when comparing Boolean types. 3. `AddMonth` would fail if passed a huge negative month, which would lead accessing negative index of `monthDays` array. 4. `Nanvl` with different type operands. 
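As a standalone illustration of bug 3 (not part of this patch), the snippet below shows why a huge negative month count broke the old `dateAddMonths`: in Scala and Java, `%` keeps the sign of the dividend, so the computed month index can go negative and index out of bounds into `monthDays`. The concrete values here are hypothetical; the clamping mirrors the fix in the `DateTimeUtils` hunk further down.

    // Sketch only: Scala/Java `%` follows the dividend's sign.
    val monthDays = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31)
    val absoluteMonth = -7                      // e.g. after adding a large negative `months`
    val unpatchedIndex = absoluteMonth % 12     // -7: monthDays(-7) throws ArrayIndexOutOfBoundsException
    val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0
    val patchedIndex = nonNegativeMonth % 12    // 0: the clamped index the fixed code uses
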
Author: Yijie Shen Closes #7882 from yjshen/minor_bug_fix and squashes the following commits: 41bbd2c [Yijie Shen] fix bug in Nanvl type coercion 3dee204 [Yijie Shen] address comments 4fa5de0 [Yijie Shen] fix bugs in expressions --- .../catalyst/analysis/HiveTypeCoercion.scala | 5 ++ .../sql/catalyst/expressions/arithmetic.scala | 9 ++- .../sql/catalyst/expressions/predicates.scala | 1 + .../sql/catalyst/util/DateTimeUtils.scala | 7 ++- .../analysis/HiveTypeCoercionSuite.scala | 12 ++++ .../ArithmeticExpressionSuite.scala | 6 +- .../expressions/DateExpressionsSuite.scala | 2 + .../catalyst/expressions/PredicateSuite.scala | 62 +++++++++---------- .../spark/sql/ColumnExpressionSuite.scala | 18 +++--- 9 files changed, 79 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 603afc4032a37..422d423747026 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -562,6 +562,11 @@ object HiveTypeCoercion { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } + + case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType => + NaNvl(l, Cast(r, DoubleType)) + case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => + NaNvl(Cast(l, DoubleType), r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 6f8f4dd230f12..0891b55494710 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -36,7 +36,14 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") - case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") + case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { + val originValue = ctx.freshName("origin") + // codegen would fail to compile if we just write (-($c)) + // for example, we could not write --9223372036854775808L in code + s""" + ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval); + ${ev.primitive} = (${ctx.javaType(dt)})(-($originValue)); + """}) case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ab7d3afce8f2e..b69bbabee7e81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -227,6 +227,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isPrimitiveType(left.dataType) + && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType && left.dataType != DoubleType) { 
// faster version diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 6a98f4d9c54bc..f645eb5f7bb01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -614,8 +614,9 @@ object DateTimeUtils { */ def dateAddMonths(days: Int, months: Int): Int = { val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months - val currentMonthInYear = absoluteMonth % 12 - val currentYear = absoluteMonth / 12 + val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0 + val currentMonthInYear = nonNegativeMonth % 12 + val currentYear = nonNegativeMonth / 12 val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0 val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay @@ -626,7 +627,7 @@ object DateTimeUtils { } else { dayOfMonth } - firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1 + firstDayOfMonth(nonNegativeMonth) + currentDayInMonth - 1 } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 70608771dd110..cbdf453f600ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -251,6 +251,18 @@ class HiveTypeCoercionSuite extends PlanTest { :: Nil)) } + test("nanvl casts") { + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), + NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType))) + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType))) + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) + } + test("type coercion for If") { val rule = HiveTypeCoercion.IfCoercion ruleTest(rule, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index d03b0fbbfb2b2..0bae8fe2fd8aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types._ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -56,6 +56,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(input), convert(-1)) checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null) } + checkEvaluation(UnaryMinus(Literal(Long.MinValue)), 
Long.MinValue) + checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue) + checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue) + checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue) } test("- (Minus)") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 3bff8e012a763..e6e8790e90926 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -280,6 +280,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null) checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)), null) + checkEvaluation( + AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), -7293498) } test("months_between") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 0bc2812a5dc83..d7eb13c50b134 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -136,60 +136,60 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) } - private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_)) + private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_)) private val largeValues = - Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_)) + Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_)) private val equalValues1 = - Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_)) + Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_)) private val equalValues2 = - Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_)) + Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_)) - test("BinaryComparison: <") { + test("BinaryComparison: lessThan") { for (i <- 0 until smallValues.length) { - checkEvaluation(smallValues(i) < largeValues(i), true) - checkEvaluation(equalValues1(i) < equalValues2(i), false) - checkEvaluation(largeValues(i) < smallValues(i), false) + checkEvaluation(LessThan(smallValues(i), largeValues(i)), true) + checkEvaluation(LessThan(equalValues1(i), equalValues2(i)), false) + checkEvaluation(LessThan(largeValues(i), smallValues(i)), false) } } - test("BinaryComparison: <=") { + test("BinaryComparison: LessThanOrEqual") { for (i <- 0 until smallValues.length) { - checkEvaluation(smallValues(i) <= largeValues(i), true) - checkEvaluation(equalValues1(i) <= equalValues2(i), true) - checkEvaluation(largeValues(i) <= smallValues(i), false) + checkEvaluation(LessThanOrEqual(smallValues(i), largeValues(i)), true) + checkEvaluation(LessThanOrEqual(equalValues1(i), equalValues2(i)), true) + checkEvaluation(LessThanOrEqual(largeValues(i), smallValues(i)), false) } } - test("BinaryComparison: >") { + test("BinaryComparison: 
GreaterThan") { for (i <- 0 until smallValues.length) { - checkEvaluation(smallValues(i) > largeValues(i), false) - checkEvaluation(equalValues1(i) > equalValues2(i), false) - checkEvaluation(largeValues(i) > smallValues(i), true) + checkEvaluation(GreaterThan(smallValues(i), largeValues(i)), false) + checkEvaluation(GreaterThan(equalValues1(i), equalValues2(i)), false) + checkEvaluation(GreaterThan(largeValues(i), smallValues(i)), true) } } - test("BinaryComparison: >=") { + test("BinaryComparison: GreaterThanOrEqual") { for (i <- 0 until smallValues.length) { - checkEvaluation(smallValues(i) >= largeValues(i), false) - checkEvaluation(equalValues1(i) >= equalValues2(i), true) - checkEvaluation(largeValues(i) >= smallValues(i), true) + checkEvaluation(GreaterThanOrEqual(smallValues(i), largeValues(i)), false) + checkEvaluation(GreaterThanOrEqual(equalValues1(i), equalValues2(i)), true) + checkEvaluation(GreaterThanOrEqual(largeValues(i), smallValues(i)), true) } } - test("BinaryComparison: ===") { + test("BinaryComparison: EqualTo") { for (i <- 0 until smallValues.length) { - checkEvaluation(smallValues(i) === largeValues(i), false) - checkEvaluation(equalValues1(i) === equalValues2(i), true) - checkEvaluation(largeValues(i) === smallValues(i), false) + checkEvaluation(EqualTo(smallValues(i), largeValues(i)), false) + checkEvaluation(EqualTo(equalValues1(i), equalValues2(i)), true) + checkEvaluation(EqualTo(largeValues(i), smallValues(i)), false) } } - test("BinaryComparison: <=>") { + test("BinaryComparison: EqualNullSafe") { for (i <- 0 until smallValues.length) { - checkEvaluation(smallValues(i) <=> largeValues(i), false) - checkEvaluation(equalValues1(i) <=> equalValues2(i), true) - checkEvaluation(largeValues(i) <=> smallValues(i), false) + checkEvaluation(EqualNullSafe(smallValues(i), largeValues(i)), false) + checkEvaluation(EqualNullSafe(equalValues1(i), equalValues2(i)), true) + checkEvaluation(EqualNullSafe(largeValues(i), smallValues(i)), false) } } @@ -209,8 +209,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { nullTest(GreaterThanOrEqual) nullTest(EqualTo) - checkEvaluation(normalInt <=> nullInt, false) - checkEvaluation(nullInt <=> normalInt, false) - checkEvaluation(nullInt <=> nullInt, true) + checkEvaluation(EqualNullSafe(normalInt, nullInt), false) + checkEvaluation(EqualNullSafe(nullInt, normalInt), false) + checkEvaluation(EqualNullSafe(nullInt, nullInt), true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index eb64684ae0fd9..35ca0b4c7cc21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -227,20 +227,24 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { test("nanvl") { val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( - Row(null, 3.0, Double.NaN, Double.PositiveInfinity) :: Nil), + Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), - StructField("c", DoubleType), StructField("d", DoubleType)))) + StructField("c", DoubleType), StructField("d", DoubleType), + StructField("e", FloatType), StructField("f", IntegerType)))) checkAnswer( testData.select( - nanvl($"a", lit(5)), nanvl($"b", lit(10)), - nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10))), - Row(null, 3.0, null, 
Double.PositiveInfinity) + nanvl($"a", lit(5)), nanvl($"b", lit(10)), nanvl(lit(10), $"b"), + nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10)), + nanvl($"b", $"e"), nanvl($"e", $"f")), + Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) ) testData.registerTempTable("t") checkAnswer( - ctx.sql("select nanvl(a, 5), nanvl(b, 10), nanvl(c, null), nanvl(d, 10) from t"), - Row(null, 3.0, null, Double.PositiveInfinity) + ctx.sql( + "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " + + " nanvl(b, e), nanvl(e, f) from t"), + Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) ) } From 1ebd41b141a95ec264bd2dd50f0fe24cd459035d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 3 Aug 2015 00:23:08 -0700 Subject: [PATCH 094/340] [SPARK-9240] [SQL] Hybrid aggregate operator using unsafe row This PR adds a base aggregation iterator `AggregationIterator`, which is used to create `SortBasedAggregationIterator` (for sort-based aggregation) and `UnsafeHybridAggregationIterator` (first it tries hash-based aggregation and falls back to the sort-based aggregation (using external sorter) if we cannot allocate memory for the map). With these two iterators, we will not need existing iterators and I am removing those. Also, we can use a single physical `Aggregate` operator and it internally determines what iterators to used. https://issues.apache.org/jira/browse/SPARK-9240 Author: Yin Huai Closes #7813 from yhuai/AggregateOperator and squashes the following commits: e317e2b [Yin Huai] Remove unnecessary change. 74d93c5 [Yin Huai] Merge remote-tracking branch 'upstream/master' into AggregateOperator ba6afbc [Yin Huai] Add a little bit more comments. c9cf3b6 [Yin Huai] update 0f1b06f [Yin Huai] Remove unnecessary code. 21fd15f [Yin Huai] Remove unnecessary change. 964f88b [Yin Huai] Implement fallback strategy. b1ea5cf [Yin Huai] wip 7fcbd87 [Yin Huai] Add a flag to control what iterator to use. 533d5b2 [Yin Huai] Prepare for fallback! 33b7022 [Yin Huai] wip bd9282b [Yin Huai] UDAFs now supports UnsafeRow. f52ee53 [Yin Huai] wip 3171f44 [Yin Huai] wip d2c45a0 [Yin Huai] wip f60cc83 [Yin Huai] Also check input schema. af32210 [Yin Huai] Check iter.hasNext before we create an iterator because the constructor of the iterato will read at least one row from a non-empty input iter. 299008c [Yin Huai] First round cleanup. 3915bac [Yin Huai] Create a base iterator class for aggregation iterators and add the initial version of the hybrid iterator. 
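A compact, hypothetical sketch of the hash-first / sort-based-fallback strategy described above may help; it is not this patch's API. The object name, the toy `maxMapEntries` budget, and the single (key, sum) aggregate are all simplifications of what `UnsafeHybridAggregationIterator` actually does with `UnsafeFixedWidthAggregationMap` and the external sorter.

    import scala.collection.mutable

    object HybridAggregationSketch {
      // Sum values per key, preferring an in-memory hash map and falling back
      // to sort-based (group-adjacent) aggregation once the map exceeds a budget.
      def aggregate(rows: Iterator[(String, Long)], maxMapEntries: Int): Seq[(String, Long)] = {
        val map = mutable.HashMap.empty[String, Long]
        var spilled = false
        while (rows.hasNext && !spilled) {
          val (k, v) = rows.next()
          map.update(k, map.getOrElse(k, 0L) + v)
          if (map.size > maxMapEntries) spilled = true   // stand-in for "cannot allocate memory for the map"
        }
        if (!spilled) {
          map.toSeq
        } else {
          // Fall back: treat the partial map as extra input, sort everything by key,
          // then merge adjacent rows with equal keys, as a sort-based iterator would.
          (map.toSeq ++ rows).sortBy(_._1).foldLeft(List.empty[(String, Long)]) {
            case ((k0, s) :: tail, (k, v)) if k0 == k => (k0, s + v) :: tail
            case (acc, (k, v)) => (k, v) :: acc
          }.reverse
        }
      }
    }
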
--- .../expressions/aggregate/interfaces.scala | 19 +- .../sql/execution/aggregate/Aggregate.scala | 182 +++++ .../aggregate/AggregationIterator.scala | 490 +++++++++++++ .../SortBasedAggregationIterator.scala | 236 +++++++ .../UnsafeHybridAggregationIterator.scala | 398 +++++++++++ .../aggregate/aggregateOperators.scala | 175 ----- .../aggregate/sortBasedIterators.scala | 664 ------------------ .../spark/sql/execution/aggregate/udaf.scala | 269 ++++++- .../spark/sql/execution/aggregate/utils.scala | 99 +-- .../spark/sql/execution/basicOperators.scala | 1 - .../org/apache/spark/sql/SQLQuerySuite.scala | 10 +- .../execution/SparkSqlSerializer2Suite.scala | 9 +- .../execution/AggregationQuerySuite.scala | 118 ++-- 13 files changed, 1697 insertions(+), 973 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index d08f553cefe8c..4abfdfe87d5e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -110,7 +110,11 @@ abstract class AggregateFunction2 * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)` * will be 2. */ - var mutableBufferOffset: Int = 0 + protected var mutableBufferOffset: Int = 0 + + def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = { + mutableBufferOffset = newMutableBufferOffset + } /** * The offset of this function's start buffer value in the @@ -126,7 +130,11 @@ abstract class AggregateFunction2 * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` * will be 3 (position 0 is used for the value of key`). */ - var inputBufferOffset: Int = 0 + protected var inputBufferOffset: Int = 0 + + def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = { + inputBufferOffset = newInputBufferOffset + } /** The schema of the aggregation buffer. 
*/ def bufferSchema: StructType @@ -195,11 +203,8 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) override def initialize(buffer: MutableRow): Unit = { - var i = 0 - while (i < bufferAttributes.size) { - buffer(i + mutableBufferOffset) = initialValues(i).eval() - i += 1 - } + throw new UnsupportedOperationException( + "AlgebraicAggregate's initialize should not be called directly") } override final def update(buffer: MutableRow, input: InternalRow): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala new file mode 100644 index 0000000000000..cf568dc048674 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.types.StructType + +/** + * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types + * of the grouping expressions and aggregate functions, it determines if it uses + * sort-based aggregation and hybrid (hash-based with sort-based as the fallback) to + * process input rows. + */ +case class Aggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + private[this] val allAggregateExpressions = + nonCompleteAggregateExpressions ++ completeAggregateExpressions + + private[this] val hasNonAlgebricAggregateFunctions = + !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) + + // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of + // grouping key and aggregation buffer is supported; and (3) all + // aggregate functions are algebraic. 
+ private[this] val supportsHybridIterator: Boolean = { + val aggregationBufferSchema: StructType = + StructType.fromAttributes( + allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + val groupKeySchema: StructType = + StructType.fromAttributes(groupingExpressions.map(_.toAttribute)) + + val schemaSupportsUnsafe: Boolean = + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeProjection.canSupport(groupKeySchema) + + // TODO: Use the hybrid iterator for non-algebric aggregate functions. + sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions + } + + // We need to use sorted input if we have grouping expressions, and + // we cannot use the hybrid iterator or the hybrid is disabled. + private[this] val requiresSortedInput: Boolean = { + groupingExpressions.nonEmpty && !supportsHybridIterator + } + + override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions + + // If result expressions' data types are all fixed length, we generate unsafe rows + // (We have this requirement instead of check the result of UnsafeProjection.canSupport + // is because we use a mutable projection to generate the result). + override def outputsUnsafeRows: Boolean = { + // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength) + // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix + // any issue we get. + false + } + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + if (requiresSortedInput) { + // TODO: We should not sort the input rows if they are just in reversed order. + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } else { + Seq.fill(children.size)(Nil) + } + } + + override def outputOrdering: Seq[SortOrder] = { + if (requiresSortedInput) { + // It is possible that the child.outputOrdering starts with the required + // ordering expressions (e.g. we require [a] as the sort expression and the + // child's outputOrdering is [a, b]). We can only guarantee the output rows + // are sorted by values of groupingExpressions. + groupingExpressions.map(SortOrder(_, Ascending)) + } else { + Nil + } + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + // Because the constructor of an aggregation iterator will read at least the first row, + // we need to get the value of iter.hasNext first. + val hasInput = iter.hasNext + val useHybridIterator = + hasInput && + supportsHybridIterator && + groupingExpressions.nonEmpty + if (useHybridIterator) { + UnsafeHybridAggregationIterator.createFromInputIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection _, + child.output, + iter, + outputsUnsafeRows) + } else { + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. 
+ Iterator[InternalRow]() + } else { + val outputIter = SortBasedAggregationIterator.createFromInputIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection _ , + newProjection _, + child.output, + iter, + outputsUnsafeRows) + if (!hasInput && groupingExpressions.isEmpty) { + // There is no input and there is no grouping expressions. + // We need to output a single row as the output. + Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + } else { + outputIter + } + } + } + } + } + + override def simpleString: String = { + val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) { + classOf[UnsafeHybridAggregationIterator].getSimpleName + } else { + classOf[SortBasedAggregationIterator].getSimpleName + } + + s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}""" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala new file mode 100644 index 0000000000000..abca373b0c4f9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.unsafe.KVIterator + +import scala.collection.mutable.ArrayBuffer + +/** + * The base class of [[SortBasedAggregationIterator]] and [[UnsafeHybridAggregationIterator]]. + * It mainly contains two parts: + * 1. It initializes aggregate functions. + * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of + * its aggregate functions. `processRow` is the function to handle an input. `generateOutput` + * is used to generate result. 
+ */ +abstract class AggregationIterator( + groupingKeyAttributes: Seq[Attribute], + valueAttributes: Seq[Attribute], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + outputsUnsafeRows: Boolean) + extends Iterator[InternalRow] with Logging { + + /////////////////////////////////////////////////////////////////////////// + // Initializing functions. + /////////////////////////////////////////////////////////////////////////// + + // An Seq of all AggregateExpressions. + // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final + // are at the beginning of the allAggregateExpressions. + protected val allAggregateExpressions = + nonCompleteAggregateExpressions ++ completeAggregateExpressions + + require( + allAggregateExpressions.map(_.mode).distinct.length <= 2, + s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.") + + /** + * The distinct modes of AggregateExpressions. Right now, we can handle the following mode: + * - Partial-only: all AggregateExpressions have the mode of Partial; + * - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge); + * - Final-only: all AggregateExpressions have the mode of Final; + * - Final-Complete: some AggregateExpressions have the mode of Final and + * others have the mode of Complete; + * - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions + * with mode Complete in completeAggregateExpressions; and + * - Grouping-only: there is no AggregateExpression. + */ + protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = + nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> + completeAggregateExpressions.map(_.mode).distinct.headOption + + // Initialize all AggregateFunctions by binding references if necessary, + // and set inputBufferOffset and mutableBufferOffset. + protected val allAggregateFunctions: Array[AggregateFunction2] = { + var mutableBufferOffset = 0 + var inputBufferOffset: Int = initialInputBufferOffset + val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + var i = 0 + while (i < allAggregateExpressions.length) { + val func = allAggregateExpressions(i).aggregateFunction + val funcWithBoundReferences = allAggregateExpressions(i).mode match { + case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => + // We need to create BoundReferences if the function is not an + // AlgebraicAggregate (it does not support code-gen) and the mode of + // this function is Partial or Complete because we will call eval of this + // function's children in the update method of this aggregate function. + // Those eval calls require BoundReferences to work. + BindReferences.bindReference(func, valueAttributes) + case _ => + // We only need to set inputBufferOffset for aggregate functions with mode + // PartialMerge and Final. + func.withNewInputBufferOffset(inputBufferOffset) + inputBufferOffset += func.bufferSchema.length + func + } + // Set mutableBufferOffset for this function. 
It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. + funcWithBoundReferences.withNewMutableBufferOffset(mutableBufferOffset) + mutableBufferOffset += funcWithBoundReferences.bufferSchema.length + functions(i) = funcWithBoundReferences + i += 1 + } + functions + } + + // Positions of those non-algebraic aggregate functions in allAggregateFunctions. + // For example, we have func1, func2, func3, func4 in aggregateFunctions, and + // func2 and func3 are non-algebraic aggregate functions. + // nonAlgebraicAggregateFunctionPositions will be [1, 2]. + private[this] val allNonAlgebraicAggregateFunctionPositions: Array[Int] = { + val positions = new ArrayBuffer[Int]() + var i = 0 + while (i < allAggregateFunctions.length) { + allAggregateFunctions(i) match { + case agg: AlgebraicAggregate => + case _ => positions += i + } + i += 1 + } + positions.toArray + } + + // All AggregateFunctions functions with mode Partial, PartialMerge, or Final. + private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + allAggregateFunctions.take(nonCompleteAggregateExpressions.length) + + // All non-algebraic aggregate functions with mode Partial, PartialMerge, or Final. + private[this] val nonCompleteNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = + nonCompleteAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + } + + // The projection used to initialize buffer values for all AlgebraicAggregates. + private[this] val algebraicInitialProjection = { + val initExpressions = allAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.initialValues + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(initExpressions, Nil)() + } + + // All non-Algebraic AggregateFunctions. + private[this] val allNonAlgebraicAggregateFunctions = + allNonAlgebraicAggregateFunctionPositions.map(allAggregateFunctions) + + /////////////////////////////////////////////////////////////////////////// + // Methods and fields used by sub-classes. + /////////////////////////////////////////////////////////////////////////// + + // Initializing functions used to process a row. + protected val processRow: (MutableRow, InternalRow) => Unit = { + val rowToBeProcessed = new JoinedRow + val aggregationBufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + aggregationMode match { + // Partial-only + case (Some(Partial), None) => + val updateExpressions = nonCompleteAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + val algebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + algebraicUpdateProjection.target(currentBuffer) + // Process all algebraic aggregate functions. + algebraicUpdateProjection(rowToBeProcessed(currentBuffer, row)) + // Process all non-algebraic aggregate functions. 
+ var i = 0 + while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { + nonCompleteNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } + } + + // PartialMerge-only or Final-only + case (Some(PartialMerge), None) | (Some(Final), None) => + val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) { + // If initialInputBufferOffset, the input value does not contain + // grouping keys. + // This part is pretty hacky. + allAggregateFunctions.flatMap(_.cloneBufferAttributes).toSeq + } else { + groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.cloneBufferAttributes) + } + // val inputAggregationBufferSchema = + // groupingKeyAttributes ++ + // allAggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = nonCompleteAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + // This projection is used to merge buffer values for all AlgebraicAggregates. + val algebraicMergeProjection = + newMutableProjection( + mergeExpressions, + aggregationBufferSchema ++ inputAggregationBufferSchema)() + + (currentBuffer: MutableRow, row: InternalRow) => { + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row)) + // Process all non-algebraic aggregate functions. + var i = 0 + while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { + nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row) + i += 1 + } + } + + // Final-Complete + case (Some(Final), Some(Complete)) => + val completeAggregateFunctions: Array[AggregateFunction2] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + // All non-algebraic aggregate functions with mode Complete. + val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = + completeAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + } + + // The first initialInputBufferOffset values of the input aggregation buffer is + // for grouping expressions and distinct columns. + val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset) + + val completeOffsetExpressions = + Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + // We do not touch buffer values of aggregate functions with the Final mode. 
+ val finalOffsetExpressions = + Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + + val mergeInputSchema = + aggregationBufferSchema ++ + groupingAttributesAndDistinctColumns ++ + nonCompleteAggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = + nonCompleteAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } ++ completeOffsetExpressions + val finalAlgebraicMergeProjection = + newMutableProjection(mergeExpressions, mergeInputSchema)() + + val updateExpressions = + finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + val input = rowToBeProcessed(currentBuffer, row) + // For all aggregate functions with mode Complete, update buffers. + completeAlgebraicUpdateProjection.target(currentBuffer)(input) + var i = 0 + while (i < completeNonAlgebraicAggregateFunctions.length) { + completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } + + // For all aggregate functions with mode Final, merge buffers. + finalAlgebraicMergeProjection.target(currentBuffer)(input) + i = 0 + while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { + nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row) + i += 1 + } + } + + // Complete-only + case (None, Some(Complete)) => + val completeAggregateFunctions: Array[AggregateFunction2] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + // All non-algebraic aggregate functions with mode Complete. + val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = + completeAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + } + + val updateExpressions = + completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + val input = rowToBeProcessed(currentBuffer, row) + // For all aggregate functions with mode Complete, update buffers. + completeAlgebraicUpdateProjection.target(currentBuffer)(input) + var i = 0 + while (i < completeNonAlgebraicAggregateFunctions.length) { + completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } + } + + // Grouping only. + case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {} + + case other => + sys.error( + s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + + s"support evaluate modes $other in this iterator.") + } + } + + // Initializing the function used to generate the output row. 
+ protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { + val rowToBeEvaluated = new JoinedRow + val safeOutoutRow = new GenericMutableRow(resultExpressions.length) + val mutableOutput = if (outputsUnsafeRows) { + UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutoutRow) + } else { + safeOutoutRow + } + + aggregationMode match { + // Partial-only or PartialMerge-only: every output row is basically the values of + // the grouping expressions and the corresponding aggregation buffer. + case (Some(Partial), None) | (Some(PartialMerge), None) => + // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not + // support generic getter), we create a mutable projection to output the + // JoinedRow(currentGroupingKey, currentBuffer) + val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.bufferAttributes) + val resultProjection = + newMutableProjection( + groupingKeyAttributes ++ bufferSchema, + groupingKeyAttributes ++ bufferSchema)() + resultProjection.target(mutableOutput) + + (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { + resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer)) + // rowToBeEvaluated(currentGroupingKey, currentBuffer) + } + + // Final-only, Complete-only and Final-Complete: every output row contains values representing + // resultExpressions. + case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => + val bufferSchemata = + allAggregateFunctions.flatMap(_.bufferAttributes) + val evalExpressions = allAggregateFunctions.map { + case ae: AlgebraicAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + val algebraicEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() + val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes + // TODO: Use unsafe row. + val aggregateResult = new GenericMutableRow(aggregateResultSchema.length) + val resultProjection = + newMutableProjection( + resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() + resultProjection.target(mutableOutput) + + (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { + // Generate results for all algebraic aggregate functions. + algebraicEvalProjection.target(aggregateResult)(currentBuffer) + // Generate results for all non-algebraic aggregate functions. + var i = 0 + while (i < allNonAlgebraicAggregateFunctions.length) { + aggregateResult.update( + allNonAlgebraicAggregateFunctionPositions(i), + allNonAlgebraicAggregateFunctions(i).eval(currentBuffer)) + i += 1 + } + resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult)) + } + + // Grouping-only: we only output values of grouping expressions. + case (None, None) => + val resultProjection = + newMutableProjection(resultExpressions, groupingKeyAttributes)() + resultProjection.target(mutableOutput) + + (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { + resultProjection(currentGroupingKey) + } + + case other => + sys.error( + s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + + s"support evaluate modes $other in this iterator.") + } + } + + /** Initializes buffer values for all aggregate functions. 
*/ + protected def initializeBuffer(buffer: MutableRow): Unit = { + algebraicInitialProjection.target(buffer)(EmptyRow) + var i = 0 + while (i < allNonAlgebraicAggregateFunctions.length) { + allNonAlgebraicAggregateFunctions(i).initialize(buffer) + i += 1 + } + } + + /** + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. + */ + protected def newBuffer: MutableRow +} + +object AggregationIterator { + def kvIterator( + groupingExpressions: Seq[NamedExpression], + newProjection: (Seq[Expression], Seq[Attribute]) => Projection, + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = { + new KVIterator[InternalRow, InternalRow] { + private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes) + + private[this] var groupingKey: InternalRow = _ + + private[this] var value: InternalRow = _ + + override def next(): Boolean = { + if (inputIter.hasNext) { + // Read the next input row. + val inputRow = inputIter.next() + // Get groupingKey based on groupingExpressions. + groupingKey = groupingKeyGenerator(inputRow) + // The value is the inputRow. + value = inputRow + true + } else { + false + } + } + + override def getKey(): InternalRow = { + groupingKey + } + + override def getValue(): InternalRow = { + value + } + + override def close(): Unit = { + // Do nothing + } + } + } + + def unsafeKVIterator( + groupingExpressions: Seq[NamedExpression], + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = { + new KVIterator[UnsafeRow, InternalRow] { + private[this] val groupingKeyGenerator = + UnsafeProjection.create(groupingExpressions, inputAttributes) + + private[this] var groupingKey: UnsafeRow = _ + + private[this] var value: InternalRow = _ + + override def next(): Boolean = { + if (inputIter.hasNext) { + // Read the next input row. + val inputRow = inputIter.next() + // Get groupingKey based on groupingExpressions. + groupingKey = groupingKeyGenerator.apply(inputRow) + // The value is the inputRow. + value = inputRow + true + } else { + false + } + } + + override def getKey(): UnsafeRow = { + groupingKey + } + + override def getValue(): InternalRow = { + value + } + + override def close(): Unit = { + // Do nothing + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala new file mode 100644 index 0000000000000..78bcee16c9d00 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
+import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.KVIterator
+
+/**
+ * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been
+ * sorted by values of [[groupingKeyAttributes]].
+ */
+class SortBasedAggregationIterator(
+    groupingKeyAttributes: Seq[Attribute],
+    valueAttributes: Seq[Attribute],
+    inputKVIterator: KVIterator[InternalRow, InternalRow],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    outputsUnsafeRows: Boolean)
+  extends AggregationIterator(
+    groupingKeyAttributes,
+    valueAttributes,
+    nonCompleteAggregateExpressions,
+    nonCompleteAggregateAttributes,
+    completeAggregateExpressions,
+    completeAggregateAttributes,
+    initialInputBufferOffset,
+    resultExpressions,
+    newMutableProjection,
+    outputsUnsafeRows) {
+
+  override protected def newBuffer: MutableRow = {
+    val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+    val bufferRowSize: Int = bufferSchema.length
+
+    val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
+    val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength)
+
+    val buffer = if (useUnsafeBuffer) {
+      val unsafeProjection =
+        UnsafeProjection.create(bufferSchema.map(_.dataType))
+      unsafeProjection.apply(genericMutableBuffer)
+    } else {
+      genericMutableBuffer
+    }
+    initializeBuffer(buffer)
+    buffer
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Mutable states for sort-based aggregation.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // The grouping key of the current group.
+  private[this] var currentGroupingKey: InternalRow = _
+
+  // The grouping key of the next group.
+  private[this] var nextGroupingKey: InternalRow = _
+
+  // The first row of the next group.
+  private[this] var firstRowInNextGroup: InternalRow = _
+
+  // Indicates whether we have a new group of rows from the sorted input iterator.
+  private[this] var sortedInputHasNewGroup: Boolean = false
+
+  // The aggregation buffer used by the sort-based aggregation.
+  private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer
+
+  /** Processes rows in the current group. It will stop when it finds a new group. */
+  protected def processCurrentSortedGroup(): Unit = {
+    currentGroupingKey = nextGroupingKey
+    // Now, we will start to find all rows belonging to this group.
+    // We create a variable to track if we see the next group.
+    var findNextPartition = false
+    // firstRowInNextGroup is the first row of this group. We first process it.
+    processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+
+    // The search will stop when we see the next group or there is no
+    // input row left in the iter.
+    var hasNext = inputKVIterator.next()
+    while (!findNextPartition && hasNext) {
+      // Get the grouping key.
+      val groupingKey = inputKVIterator.getKey
+      val currentRow = inputKVIterator.getValue
+
+      // Check if the current row belongs to the current group.
+      if (currentGroupingKey == groupingKey) {
+        processRow(sortBasedAggregationBuffer, currentRow)
+
+        hasNext = inputKVIterator.next()
+      } else {
+        // We have found the next group.
+        findNextPartition = true
+        nextGroupingKey = groupingKey.copy()
+        firstRowInNextGroup = currentRow.copy()
+      }
+    }
+    // We have not seen a new group. It means that there are no more rows in the input
+    // iterator. The current group is the last group of the iterator.
+    if (!findNextPartition) {
+      sortedInputHasNewGroup = false
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Iterator's public methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = sortedInputHasNewGroup
+
+  override final def next(): InternalRow = {
+    if (hasNext) {
+      // Process the current group.
+      processCurrentSortedGroup()
+      // Generate output row for the current group.
+      val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
+      // Initialize buffer values for the next group.
+      initializeBuffer(sortBasedAggregationBuffer)
+
+      outputRow
+    } else {
+      // no more result
+      throw new NoSuchElementException
+    }
+  }
+
+  protected def initialize(): Unit = {
+    if (inputKVIterator.next()) {
+      initializeBuffer(sortBasedAggregationBuffer)
+
+      nextGroupingKey = inputKVIterator.getKey().copy()
+      firstRowInNextGroup = inputKVIterator.getValue().copy()
+
+      sortedInputHasNewGroup = true
+    } else {
+      // This inputIter is empty.
+      sortedInputHasNewGroup = false
+    }
+  }
+
+  initialize()
+
+  def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
+    initializeBuffer(sortBasedAggregationBuffer)
+    generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
+  }
+}
+
+object SortBasedAggregationIterator {
+  // scalastyle:off
+  def createFromInputIterator(
+      groupingExprs: Seq[NamedExpression],
+      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+      nonCompleteAggregateAttributes: Seq[Attribute],
+      completeAggregateExpressions: Seq[AggregateExpression2],
+      completeAggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+      newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
+      inputAttributes: Seq[Attribute],
+      inputIter: Iterator[InternalRow],
+      outputsUnsafeRows: Boolean): SortBasedAggregationIterator = {
+    val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) {
+      AggregationIterator.unsafeKVIterator(
+        groupingExprs,
+        inputAttributes,
+        inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]]
+    } else {
+      AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter)
+    }
+
+    new SortBasedAggregationIterator(
+      groupingExprs.map(_.toAttribute),
+      inputAttributes,
+      kvIterator,
+      nonCompleteAggregateExpressions,
+      nonCompleteAggregateAttributes,
+      completeAggregateExpressions,
+      completeAggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      newMutableProjection,
+      outputsUnsafeRows)
+  }
+
+  def createFromKVIterator(
+      groupingKeyAttributes: Seq[Attribute],
+      valueAttributes: Seq[Attribute],
+      inputKVIterator: KVIterator[InternalRow, InternalRow],
+
nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { + new SortBasedAggregationIterator( + groupingKeyAttributes, + valueAttributes, + inputKVIterator, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + outputsUnsafeRows) + } + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala new file mode 100644 index 0000000000000..37d34eb7ccf09 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.execution.{UnsafeKeyValueSorter, UnsafeFixedWidthAggregationMap} +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.StructType + +/** + * An iterator used to evaluate [[AggregateFunction2]]. + * It first tries to use in-memory hash-based aggregation. If we cannot allocate more + * space for the hash map, we spill the sorted map entries, free the map, and then + * switch to sort-based aggregation. 
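+ * At least one grouping key attribute is required (see the `require` in the class body).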
+ */ +class UnsafeHybridAggregationIterator( + groupingKeyAttributes: Seq[Attribute], + valueAttributes: Seq[Attribute], + inputKVIterator: KVIterator[UnsafeRow, InternalRow], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + outputsUnsafeRows: Boolean) + extends AggregationIterator( + groupingKeyAttributes, + valueAttributes, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + outputsUnsafeRows) { + + require(groupingKeyAttributes.nonEmpty) + + /////////////////////////////////////////////////////////////////////////// + // Unsafe Aggregation buffers + /////////////////////////////////////////////////////////////////////////// + + // This is the Unsafe Aggregation Map used to store all buffers. + private[this] val buffers = new UnsafeFixedWidthAggregationMap( + newBuffer, + StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), + StructType.fromAttributes(groupingKeyAttributes), + TaskContext.get.taskMemoryManager(), + SparkEnv.get.shuffleMemoryManager, + 1024 * 16, // initial capacity + SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), + false // disable tracking of performance metrics + ) + + override protected def newBuffer: UnsafeRow = { + val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferRowSize: Int = bufferSchema.length + + val genericMutableBuffer = new GenericMutableRow(bufferRowSize) + val unsafeProjection = + UnsafeProjection.create(bufferSchema.map(_.dataType)) + val buffer = unsafeProjection.apply(genericMutableBuffer) + initializeBuffer(buffer) + buffer + } + + /////////////////////////////////////////////////////////////////////////// + // Methods and variables related to switching to sort-based aggregation + /////////////////////////////////////////////////////////////////////////// + private[this] var sortBased = false + + private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _ + + // The value part of the input KV iterator is used to store original input values of + // aggregate functions, we need to convert them to aggregation buffers. + private def processOriginalInput( + firstKey: UnsafeRow, + firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { + new KVIterator[UnsafeRow, UnsafeRow] { + private[this] var isFirstRow = true + + private[this] var groupingKey: UnsafeRow = _ + + private[this] val buffer: UnsafeRow = newBuffer + + override def next(): Boolean = { + initializeBuffer(buffer) + if (isFirstRow) { + isFirstRow = false + groupingKey = firstKey + processRow(buffer, firstValue) + + true + } else if (inputKVIterator.next()) { + groupingKey = inputKVIterator.getKey() + val value = inputKVIterator.getValue() + processRow(buffer, value) + + true + } else { + false + } + } + + override def getKey(): UnsafeRow = { + groupingKey + } + + override def getValue(): UnsafeRow = { + buffer + } + + override def close(): Unit = { + // Do nothing. + } + } + } + + // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer. 
+ // We need to project the aggregation buffer out. + private def projectInputBufferToUnsafe( + firstKey: UnsafeRow, + firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { + new KVIterator[UnsafeRow, UnsafeRow] { + private[this] var isFirstRow = true + + private[this] var groupingKey: UnsafeRow = _ + + private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + + private[this] val value: UnsafeRow = { + val genericMutableRow = new GenericMutableRow(bufferSchema.length) + UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow) + } + + private[this] val projectInputBuffer = { + newMutableProjection(bufferSchema, valueAttributes)().target(value) + } + + override def next(): Boolean = { + if (isFirstRow) { + isFirstRow = false + groupingKey = firstKey + projectInputBuffer(firstValue) + + true + } else if (inputKVIterator.next()) { + groupingKey = inputKVIterator.getKey() + projectInputBuffer(inputKVIterator.getValue()) + + true + } else { + false + } + } + + override def getKey(): UnsafeRow = { + groupingKey + } + + override def getValue(): UnsafeRow = { + value + } + + override def close(): Unit = { + // Do nothing. + } + } + } + + /** + * We need to fall back to sort based aggregation because we do not have enough memory + * for our in-memory hash map (i.e. `buffers`). + */ + private def switchToSortBasedAggregation( + currentGroupingKey: UnsafeRow, + currentRow: InternalRow): Unit = { + logInfo("falling back to sort based aggregation.") + + // Step 1: Get the ExternalSorter containing entries of the map. + val externalSorter = buffers.destructAndCreateExternalSorter() + + // Step 2: Free the memory used by the map. + buffers.free() + + // Step 3: If we have aggregate function with mode Partial or Complete, + // we need to process them to get aggregation buffer. + // So, later in the sort-based aggregation iterator, we can do merge. + // If aggregate functions are with mode Final and PartialMerge, + // we just need to project the aggregation buffer from the input. + val needsProcess = aggregationMode match { + case (Some(Partial), None) => true + case (None, Some(Complete)) => true + case (Some(Final), Some(Complete)) => true + case _ => false + } + + val processedIterator = if (needsProcess) { + processOriginalInput(currentGroupingKey, currentRow) + } else { + // The input value's format is groupingExprs + buffer. + // We need to project the buffer part out. + projectInputBufferToUnsafe(currentGroupingKey, currentRow) + } + + // Step 4: Redirect processedIterator to externalSorter. + while (processedIterator.next()) { + externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue()) + } + + // Step 5: Get the sorted iterator from the externalSorter. + val sortedKVIterator: KVIterator[UnsafeRow, UnsafeRow] = externalSorter.sortedIterator() + + // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator. + // For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator + // will be PartialMerge. For a aggregate function with mode Complete, + // its mode in the SortBasedAggregationIterator will be Final. 
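+    // For example, an expression that reached this iterator as sum(x) with mode Partial is
+    // rewritten below to sum(x) with mode PartialMerge, since its partial buffers have already
+    // been produced; sum(x) with mode Complete becomes sum(x) with mode Final.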
+    val newNonCompleteAggregateExpressions = allAggregateExpressions.map {
+      case AggregateExpression2(func, Partial, isDistinct) =>
+        AggregateExpression2(func, PartialMerge, isDistinct)
+      case AggregateExpression2(func, Complete, isDistinct) =>
+        AggregateExpression2(func, Final, isDistinct)
+      case other => other
+    }
+    val newNonCompleteAggregateAttributes =
+      nonCompleteAggregateAttributes ++ completeAggregateAttributes
+
+    val newValueAttributes =
+      allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+
+    sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator(
+      groupingKeyAttributes = groupingKeyAttributes,
+      valueAttributes = newValueAttributes,
+      inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]],
+      nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions,
+      nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes,
+      completeAggregateExpressions = Nil,
+      completeAggregateAttributes = Nil,
+      initialInputBufferOffset = 0,
+      resultExpressions = resultExpressions,
+      newMutableProjection = newMutableProjection,
+      outputsUnsafeRows = outputsUnsafeRows)
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Methods used to initialize this iterator.
+  ///////////////////////////////////////////////////////////////////////////
+
+  /** Starts to read input rows and falls back to sort-based aggregation if necessary. */
+  protected def initialize(): Unit = {
+    var hasNext = inputKVIterator.next()
+    while (!sortBased && hasNext) {
+      val groupingKey = inputKVIterator.getKey()
+      val currentRow = inputKVIterator.getValue()
+      val buffer = buffers.getAggregationBuffer(groupingKey)
+      if (buffer == null) {
+        // buffer == null means that we could not allocate more memory.
+        // Now, we need to spill the map and switch to sort-based aggregation.
+        switchToSortBasedAggregation(groupingKey, currentRow)
+        sortBased = true
+      } else {
+        processRow(buffer, currentRow)
+        hasNext = inputKVIterator.next()
+      }
+    }
+  }
+
+  // This is the starting point of this iterator.
+  initialize()
+
+  // Creates the iterator for the hash aggregation map after we have populated
+  // the contents of that map.
+  private[this] val aggregationBufferMapIterator = buffers.iterator()
+
+  private[this] var _mapIteratorHasNext = false
+
+  // Pre-load the first key-value pair from the map to make hasNext idempotent.
+  if (!sortBased) {
+    _mapIteratorHasNext = aggregationBufferMapIterator.next()
+    // If the map is empty, we just free it.
+    if (!_mapIteratorHasNext) {
+      buffers.free()
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Iterator's public methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = {
+    (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext)
+  }
+
+  override final def next(): InternalRow = {
+    if (hasNext) {
+      if (sortBased) {
+        sortBasedAggregationIterator.next()
+      } else {
+        // We did not fall back to sort-based aggregation.
+        val result =
+          generateOutput(
+            aggregationBufferMapIterator.getKey,
+            aggregationBufferMapIterator.getValue)
+        // Pre-load the next key-value pair from aggregationBufferMapIterator.
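+        // If this was the last pair, the code below copies the result before freeing the map,
+        // since the returned row may still point into the map's memory.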
+ _mapIteratorHasNext = aggregationBufferMapIterator.next() + + if (!_mapIteratorHasNext) { + val resultCopy = result.copy() + buffers.free() + resultCopy + } else { + result + } + } + } else { + // no more result + throw new NoSuchElementException + } + } +} + +object UnsafeHybridAggregationIterator { + // scalastyle:off + def createFromInputIterator( + groupingExprs: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow], + outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = { + new UnsafeHybridAggregationIterator( + groupingExprs.map(_.toAttribute), + inputAttributes, + AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter), + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + outputsUnsafeRows) + } + + def createFromKVIterator( + groupingKeyAttributes: Seq[Attribute], + valueAttributes: Seq[Attribute], + inputKVIterator: KVIterator[UnsafeRow, InternalRow], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = { + new UnsafeHybridAggregationIterator( + groupingKeyAttributes, + valueAttributes, + inputKVIterator, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + outputsUnsafeRows) + } + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala deleted file mode 100644 index 98538c462bc89..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} - -case class Aggregate2Sort( - requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateAttributes: Seq[Attribute], - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - override def canProcessUnsafeRows: Boolean = true - - override def references: AttributeSet = { - val referencesInResults = - AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes) - - AttributeSet( - groupingExpressions.flatMap(_.references) ++ - aggregateExpressions.flatMap(_.references) ++ - referencesInResults) - } - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - // TODO: We should not sort the input rows if they are just in reversed order. - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - } - - override def outputOrdering: Seq[SortOrder] = { - // It is possible that the child.outputOrdering starts with the required - // ordering expressions (e.g. we require [a] as the sort expression and the - // child's outputOrdering is [a, b]). We can only guarantee the output rows - // are sorted by values of groupingExpressions. 
- groupingExpressions.map(SortOrder(_, Ascending)) - } - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - if (aggregateExpressions.length == 0) { - new FinalSortAggregationIterator( - groupingExpressions, - Nil, - Nil, - resultExpressions, - newMutableProjection, - child.output, - iter) - } else { - val aggregationIterator: SortAggregationIterator = { - aggregateExpressions.map(_.mode).distinct.toList match { - case Partial :: Nil => - new PartialSortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - child.output, - iter) - case PartialMerge :: Nil => - new PartialMergeSortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - child.output, - iter) - case Final :: Nil => - new FinalSortAggregationIterator( - groupingExpressions, - aggregateExpressions, - aggregateAttributes, - resultExpressions, - newMutableProjection, - child.output, - iter) - case other => - sys.error( - s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " + - s"modes $other in this operator.") - } - } - - aggregationIterator - } - } - } -} - -case class FinalAndCompleteAggregate2Sort( - previousGroupingExpressions: Seq[NamedExpression], - groupingExpressions: Seq[NamedExpression], - finalAggregateExpressions: Seq[AggregateExpression2], - finalAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - override def references: AttributeSet = { - val referencesInResults = - AttributeSet(resultExpressions.flatMap(_.references)) -- - AttributeSet(finalAggregateExpressions) -- - AttributeSet(completeAggregateExpressions) - - AttributeSet( - groupingExpressions.flatMap(_.references) ++ - finalAggregateExpressions.flatMap(_.references) ++ - completeAggregateExpressions.flatMap(_.references) ++ - referencesInResults) - } - - override def requiredChildDistribution: List[Distribution] = { - if (groupingExpressions.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - - new FinalAndCompleteSortAggregationIterator( - previousGroupingExpressions.length, - groupingExpressions, - finalAggregateExpressions, - finalAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - resultExpressions, - newMutableProjection, - child.output, - iter) - } - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala deleted file mode 100644 index 2ca0cb82c1aab..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala +++ /dev/null @@ -1,664 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.types.NullType - -import scala.collection.mutable.ArrayBuffer - -/** - * An iterator used to evaluate aggregate functions. It assumes that input rows - * are already grouped by values of `groupingExpressions`. - */ -private[sql] abstract class SortAggregationIterator( - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends Iterator[InternalRow] { - - /////////////////////////////////////////////////////////////////////////// - // Static fields for this iterator - /////////////////////////////////////////////////////////////////////////// - - protected val aggregateFunctions: Array[AggregateFunction2] = { - var mutableBufferOffset = 0 - var inputBufferOffset: Int = initialInputBufferOffset - val functions = new Array[AggregateFunction2](aggregateExpressions.length) - var i = 0 - while (i < aggregateExpressions.length) { - val func = aggregateExpressions(i).aggregateFunction - val funcWithBoundReferences = aggregateExpressions(i).mode match { - case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => - // We need to create BoundReferences if the function is not an - // AlgebraicAggregate (it does not support code-gen) and the mode of - // this function is Partial or Complete because we will call eval of this - // function's children in the update method of this aggregate function. - // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, inputAttributes) - case _ => - // We only need to set inputBufferOffset for aggregate functions with mode - // PartialMerge and Final. - func.inputBufferOffset = inputBufferOffset - inputBufferOffset += func.bufferSchema.length - func - } - // Set mutableBufferOffset for this function. It is important that setting - // mutableBufferOffset happens after all potential bindReference operations - // because bindReference will create a new instance of the function. - funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset - mutableBufferOffset += funcWithBoundReferences.bufferSchema.length - functions(i) = funcWithBoundReferences - i += 1 - } - functions - } - - // Positions of those non-algebraic aggregate functions in aggregateFunctions. - // For example, we have func1, func2, func3, func4 in aggregateFunctions, and - // func2 and func3 are non-algebraic aggregate functions. - // nonAlgebraicAggregateFunctionPositions will be [1, 2]. 
- protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = { - val positions = new ArrayBuffer[Int]() - var i = 0 - while (i < aggregateFunctions.length) { - aggregateFunctions(i) match { - case agg: AlgebraicAggregate => - case _ => positions += i - } - i += 1 - } - positions.toArray - } - - // All non-algebraic aggregate functions. - protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions) - - // This is used to project expressions for the grouping expressions. - protected val groupGenerator = - newMutableProjection(groupingExpressions, inputAttributes)() - - // The underlying buffer shared by all aggregate functions. - protected val buffer: MutableRow = { - // The number of elements of the underlying buffer of this operator. - // All aggregate functions are sharing this underlying buffer and they find their - // buffer values through bufferOffset. - // var size = 0 - // var i = 0 - // while (i < aggregateFunctions.length) { - // size += aggregateFunctions(i).bufferSchema.length - // i += 1 - // } - new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum) - } - - protected val joinedRow = new JoinedRow - - // This projection is used to initialize buffer values for all AlgebraicAggregates. - protected val algebraicInitialProjection = { - val initExpressions = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.initialValues - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - - newMutableProjection(initExpressions, Nil)().target(buffer) - } - - /////////////////////////////////////////////////////////////////////////// - // Mutable states - /////////////////////////////////////////////////////////////////////////// - - // The partition key of the current partition. - protected var currentGroupingKey: InternalRow = _ - // The partition key of next partition. - protected var nextGroupingKey: InternalRow = _ - // The first row of next partition. - protected var firstRowInNextGroup: InternalRow = _ - // Indicates if we has new group of rows to process. - protected var hasNewGroup: Boolean = true - - /** Initializes buffer values for all aggregate functions. */ - protected def initializeBuffer(): Unit = { - algebraicInitialProjection(EmptyRow) - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - nonAlgebraicAggregateFunctions(i).initialize(buffer) - i += 1 - } - } - - protected def initialize(): Unit = { - if (inputIter.hasNext) { - initializeBuffer() - val currentRow = inputIter.next().copy() - // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, - // we are making a copy at here. - nextGroupingKey = groupGenerator(currentRow).copy() - firstRowInNextGroup = currentRow - } else { - // This iter is an empty one. - hasNewGroup = false - } - } - - /////////////////////////////////////////////////////////////////////////// - // Private methods - /////////////////////////////////////////////////////////////////////////// - - /** Processes rows in the current group. It will stop when it find a new group. */ - private def processCurrentGroup(): Unit = { - currentGroupingKey = nextGroupingKey - // Now, we will start to find all rows belonging to this group. - // We create a variable to track if we see the next group. - var findNextPartition = false - // firstRowInNextGroup is the first row of this group. We first process it. 
- processRow(firstRowInNextGroup) - // The search will stop when we see the next group or there is no - // input row left in the iter. - while (inputIter.hasNext && !findNextPartition) { - val currentRow = inputIter.next() - // Get the grouping key based on the grouping expressions. - // For the below compare method, we do not need to make a copy of groupingKey. - val groupingKey = groupGenerator(currentRow) - // Check if the current row belongs the current input row. - if (currentGroupingKey == groupingKey) { - processRow(currentRow) - } else { - // We find a new group. - findNextPartition = true - nextGroupingKey = groupingKey.copy() - firstRowInNextGroup = currentRow.copy() - } - } - // We have not seen a new group. It means that there is no new row in the input - // iter. The current group is the last group of the iter. - if (!findNextPartition) { - hasNewGroup = false - } - } - - /////////////////////////////////////////////////////////////////////////// - // Public methods - /////////////////////////////////////////////////////////////////////////// - - override final def hasNext: Boolean = hasNewGroup - - override final def next(): InternalRow = { - if (hasNext) { - // Process the current group. - processCurrentGroup() - // Generate output row for the current group. - val outputRow = generateOutput() - // Initilize buffer values for the next group. - initializeBuffer() - - outputRow - } else { - // no more result - throw new NoSuchElementException - } - } - - /////////////////////////////////////////////////////////////////////////// - // Methods that need to be implemented - /////////////////////////////////////////////////////////////////////////// - - /** The initial input buffer offset for `inputBufferOffset` of an [[AggregateFunction2]]. */ - protected def initialInputBufferOffset: Int - - /** The function used to process an input row. */ - protected def processRow(row: InternalRow): Unit - - /** The function used to generate the result row. */ - protected def generateOutput(): InternalRow - - /////////////////////////////////////////////////////////////////////////// - // Initialize this iterator - /////////////////////////////////////////////////////////////////////////// - - initialize() -} - -/** - * An iterator used to do partial aggregations (for those aggregate functions with mode Partial). - * It assumes that input rows are already grouped by values of `groupingExpressions`. - * The format of its output rows is: - * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| - */ -class PartialSortAggregationIterator( - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - inputAttributes, - inputIter) { - - // This projection is used to update buffer values for all AlgebraicAggregates. 
- private val algebraicUpdateProjection = { - val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes) - val updateExpressions = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) - } - - override protected def initialInputBufferOffset: Int = 0 - - override protected def processRow(row: InternalRow): Unit = { - // Process all algebraic aggregate functions. - algebraicUpdateProjection(joinedRow(buffer, row)) - // Process all non-algebraic aggregate functions. - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - nonAlgebraicAggregateFunctions(i).update(buffer, row) - i += 1 - } - } - - override protected def generateOutput(): InternalRow = { - // We just output the grouping expressions and the underlying buffer. - joinedRow(currentGroupingKey, buffer).copy() - } -} - -/** - * An iterator used to do partial merge aggregations (for those aggregate functions with mode - * PartialMerge). It assumes that input rows are already grouped by values of - * `groupingExpressions`. - * The format of its input rows is: - * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| - * - * The format of its internal buffer is: - * |aggregationBuffer1|...|aggregationBufferN| - * - * The format of its output rows is: - * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| - */ -class PartialMergeSortAggregationIterator( - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - inputAttributes, - inputIter) { - - // This projection is used to merge buffer values for all AlgebraicAggregates. - private val algebraicMergeProjection = { - val mergeInputSchema = - aggregateFunctions.flatMap(_.bufferAttributes) ++ - groupingExpressions.map(_.toAttribute) ++ - aggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - - newMutableProjection(mergeExpressions, mergeInputSchema)() - } - - override protected def initialInputBufferOffset: Int = groupingExpressions.length - - override protected def processRow(row: InternalRow): Unit = { - // Process all algebraic aggregate functions. - algebraicMergeProjection.target(buffer)(joinedRow(buffer, row)) - // Process all non-algebraic aggregate functions. - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - nonAlgebraicAggregateFunctions(i).merge(buffer, row) - i += 1 - } - } - - override protected def generateOutput(): InternalRow = { - // We output grouping expressions and aggregation buffers. - joinedRow(currentGroupingKey, buffer).copy() - } -} - -/** - * An iterator used to do final aggregations (for those aggregate functions with mode - * Final). It assumes that input rows are already grouped by values of - * `groupingExpressions`. 
- * The format of its input rows is: - * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| - * - * The format of its internal buffer is: - * |aggregationBuffer1|...|aggregationBufferN| - * - * The format of its output rows is represented by the schema of `resultExpressions`. - */ -class FinalSortAggregationIterator( - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateAttributes: Seq[Attribute], - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - aggregateExpressions, - newMutableProjection, - inputAttributes, - inputIter) { - - // The result of aggregate functions. - private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length) - - // The projection used to generate the output rows of this operator. - // This is only used when we are generating final results of aggregate functions. - private val resultProjection = - newMutableProjection( - resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)() - - // This projection is used to merge buffer values for all AlgebraicAggregates. - private val algebraicMergeProjection = { - val mergeInputSchema = - aggregateFunctions.flatMap(_.bufferAttributes) ++ - groupingExpressions.map(_.toAttribute) ++ - aggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - - newMutableProjection(mergeExpressions, mergeInputSchema)() - } - - // This projection is used to evaluate all AlgebraicAggregates. - private val algebraicEvalProjection = { - val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) - val evalExpressions = aggregateFunctions.map { - case ae: AlgebraicAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp - } - - newMutableProjection(evalExpressions, bufferSchemata)() - } - - override protected def initialInputBufferOffset: Int = groupingExpressions.length - - override def initialize(): Unit = { - if (inputIter.hasNext) { - initializeBuffer() - val currentRow = inputIter.next().copy() - // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, - // we are making a copy at here. - nextGroupingKey = groupGenerator(currentRow).copy() - firstRowInNextGroup = currentRow - } else { - if (groupingExpressions.isEmpty) { - // If there is no grouping expression, we need to generate a single row as the output. - initializeBuffer() - // Right now, the buffer only contains initial buffer values. Because - // merging two buffers with initial values will generate a row that - // still store initial values. We set the currentRow as the copy of the current buffer. - // Because input aggregation buffer has initialInputBufferOffset extra values at the - // beginning, we create a dummy row for this part. - val currentRow = - joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy() - nextGroupingKey = groupGenerator(currentRow).copy() - firstRowInNextGroup = currentRow - } else { - // This iter is an empty one. - hasNewGroup = false - } - } - } - - override protected def processRow(row: InternalRow): Unit = { - // Process all algebraic aggregate functions. 
- algebraicMergeProjection.target(buffer)(joinedRow(buffer, row)) - // Process all non-algebraic aggregate functions. - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - nonAlgebraicAggregateFunctions(i).merge(buffer, row) - i += 1 - } - } - - override protected def generateOutput(): InternalRow = { - // Generate results for all algebraic aggregate functions. - algebraicEvalProjection.target(aggregateResult)(buffer) - // Generate results for all non-algebraic aggregate functions. - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - aggregateResult.update( - nonAlgebraicAggregateFunctionPositions(i), - nonAlgebraicAggregateFunctions(i).eval(buffer)) - i += 1 - } - resultProjection(joinedRow(currentGroupingKey, aggregateResult)) - } -} - -/** - * An iterator used to do both final aggregations (for those aggregate functions with mode - * Final) and complete aggregations (for those aggregate functions with mode Complete). - * It assumes that input rows are already grouped by values of `groupingExpressions`. - * The format of its input rows is: - * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN| - * col1 to colM are columns used by aggregate functions with Complete mode. - * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with - * Final mode. - * - * The format of its internal buffer is: - * |aggregationBuffer1|...|aggregationBuffer(N+M)| - * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with - * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode - * Complete. - * - * The format of its output rows is represented by the schema of `resultExpressions`. - */ -class FinalAndCompleteSortAggregationIterator( - override protected val initialInputBufferOffset: Int, - groupingExpressions: Seq[NamedExpression], - finalAggregateExpressions: Seq[AggregateExpression2], - finalAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - // TODO: document the ordering - finalAggregateExpressions ++ completeAggregateExpressions, - newMutableProjection, - inputAttributes, - inputIter) { - - // The result of aggregate functions. - private val aggregateResult: MutableRow = - new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length) - - // The projection used to generate the output rows of this operator. - // This is only used when we are generating final results of aggregate functions. - private val resultProjection = { - val inputSchema = - groupingExpressions.map(_.toAttribute) ++ - finalAggregateAttributes ++ - completeAggregateAttributes - newMutableProjection(resultExpressions, inputSchema)() - } - - // All aggregate functions with mode Final. - private val finalAggregateFunctions: Array[AggregateFunction2] = { - val functions = new Array[AggregateFunction2](finalAggregateExpressions.length) - var i = 0 - while (i < finalAggregateExpressions.length) { - functions(i) = aggregateFunctions(i) - i += 1 - } - functions - } - - // All non-algebraic aggregate functions with mode Final. 
- private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - finalAggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - } - - // All aggregate functions with mode Complete. - private val completeAggregateFunctions: Array[AggregateFunction2] = { - val functions = new Array[AggregateFunction2](completeAggregateExpressions.length) - var i = 0 - while (i < completeAggregateExpressions.length) { - functions(i) = aggregateFunctions(finalAggregateFunctions.length + i) - i += 1 - } - functions - } - - // All non-algebraic aggregate functions with mode Complete. - private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - completeAggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - } - - // This projection is used to merge buffer values for all AlgebraicAggregates with mode - // Final. - private val finalAlgebraicMergeProjection = { - // The first initialInputBufferOffset values of the input aggregation buffer is - // for grouping expressions and distinct columns. - val groupingAttributesAndDistinctColumns = inputAttributes.take(initialInputBufferOffset) - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) - - val mergeInputSchema = - finalAggregateFunctions.flatMap(_.bufferAttributes) ++ - completeAggregateFunctions.flatMap(_.bufferAttributes) ++ - groupingAttributesAndDistinctColumns ++ - finalAggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = - finalAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - newMutableProjection(mergeExpressions, mergeInputSchema)() - } - - // This projection is used to update buffer values for all AlgebraicAggregates with mode - // Complete. - private val completeAlgebraicUpdateProjection = { - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) - - val bufferSchema = - finalAggregateFunctions.flatMap(_.bufferAttributes) ++ - completeAggregateFunctions.flatMap(_.bufferAttributes) - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) - } - newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) - } - - // This projection is used to evaluate all AlgebraicAggregates. - private val algebraicEvalProjection = { - val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) - val evalExpressions = aggregateFunctions.map { - case ae: AlgebraicAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp - } - - newMutableProjection(evalExpressions, bufferSchemata)() - } - - override def initialize(): Unit = { - if (inputIter.hasNext) { - initializeBuffer() - val currentRow = inputIter.next().copy() - // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, - // we are making a copy at here. - nextGroupingKey = groupGenerator(currentRow).copy() - firstRowInNextGroup = currentRow - } else { - if (groupingExpressions.isEmpty) { - // If there is no grouping expression, we need to generate a single row as the output. 
- initializeBuffer() - // Right now, the buffer only contains initial buffer values. Because - // merging two buffers with initial values will generate a row that - // still store initial values. We set the currentRow as the copy of the current buffer. - // Because input aggregation buffer has initialInputBufferOffset extra values at the - // beginning, we create a dummy row for this part. - val currentRow = - joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy() - nextGroupingKey = groupGenerator(currentRow).copy() - firstRowInNextGroup = currentRow - } else { - // This iter is an empty one. - hasNewGroup = false - } - } - } - - override protected def processRow(row: InternalRow): Unit = { - val input = joinedRow(buffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeAlgebraicUpdateProjection(input) - var i = 0 - while (i < completeNonAlgebraicAggregateFunctions.length) { - completeNonAlgebraicAggregateFunctions(i).update(buffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffers. - finalAlgebraicMergeProjection.target(buffer)(input) - i = 0 - while (i < finalNonAlgebraicAggregateFunctions.length) { - finalNonAlgebraicAggregateFunctions(i).merge(buffer, row) - i += 1 - } - } - - override protected def generateOutput(): InternalRow = { - // Generate results for all algebraic aggregate functions. - algebraicEvalProjection.target(aggregateResult)(buffer) - // Generate results for all non-algebraic aggregate functions. - var i = 0 - while (i < nonAlgebraicAggregateFunctions.length) { - aggregateResult.update( - nonAlgebraicAggregateFunctionPositions(i), - nonAlgebraicAggregateFunctions(i).eval(buffer)) - i += 1 - } - - resultProjection(joinedRow(currentGroupingKey, aggregateResult)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index cc54319171bdb..5fafc916bfa0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -24,7 +24,154 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjecti import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} -import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType} +import org.apache.spark.sql.types._ + +/** + * A helper trait used to create specialized setter and getter for types supported by + * [[org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap]]'s buffer. + * (see UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema). 
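+ * For example, a buffer field of IntegerType gets a getter backed by getInt and a setter backed
+ * by setInt, instead of the generic get / update calls.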
+ */ +sealed trait BufferSetterGetterUtils { + + def createGetters(schema: StructType): Array[(InternalRow, Int) => Any] = { + val dataTypes = schema.fields.map(_.dataType) + val getters = new Array[(InternalRow, Int) => Any](dataTypes.length) + + var i = 0 + while (i < getters.length) { + getters(i) = dataTypes(i) match { + case BooleanType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal) + + case ByteType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getByte(ordinal) + + case ShortType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getShort(ordinal) + + case IntegerType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getInt(ordinal) + + case LongType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getLong(ordinal) + + case FloatType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getFloat(ordinal) + + case DoubleType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getDouble(ordinal) + + case dt: DecimalType => + val precision = dt.precision + val scale = dt.scale + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale) + + case other => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.get(ordinal, other) + } + + i += 1 + } + + getters + } + + def createSetters(schema: StructType): Array[((MutableRow, Int, Any) => Unit)] = { + val dataTypes = schema.fields.map(_.dataType) + val setters = new Array[(MutableRow, Int, Any) => Unit](dataTypes.length) + + var i = 0 + while (i < setters.length) { + setters(i) = dataTypes(i) match { + case b: BooleanType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setBoolean(ordinal, value.asInstanceOf[Boolean]) + } else { + row.setNullAt(ordinal) + } + + case ByteType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setByte(ordinal, value.asInstanceOf[Byte]) + } else { + row.setNullAt(ordinal) + } + + case ShortType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setShort(ordinal, value.asInstanceOf[Short]) + } else { + row.setNullAt(ordinal) + } + + case IntegerType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setInt(ordinal, value.asInstanceOf[Int]) + } else { + row.setNullAt(ordinal) + } + + case LongType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setLong(ordinal, value.asInstanceOf[Long]) + } else { + row.setNullAt(ordinal) + } + + case FloatType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setFloat(ordinal, value.asInstanceOf[Float]) + } else { + row.setNullAt(ordinal) + } + + case DoubleType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setDouble(ordinal, value.asInstanceOf[Double]) + } else { + row.setNullAt(ordinal) + } + + case dt: DecimalType => + val precision = dt.precision + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + } else { + row.setNullAt(ordinal) + } + + case other => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.update(ordinal, value) + } else { + row.setNullAt(ordinal) + } + } + + i += 1 + } + + setters + } +} /** 
* A Mutable [[Row]] representing an mutable aggregation buffer. @@ -35,7 +182,7 @@ private[sql] class MutableAggregationBufferImpl ( toScalaConverters: Array[Any => Any], bufferOffset: Int, var underlyingBuffer: MutableRow) - extends MutableAggregationBuffer { + extends MutableAggregationBuffer with BufferSetterGetterUtils { private[this] val offsets: Array[Int] = { val newOffsets = new Array[Int](length) @@ -47,6 +194,10 @@ private[sql] class MutableAggregationBufferImpl ( newOffsets } + private[this] val bufferValueGetters = createGetters(schema) + + private[this] val bufferValueSetters = createSetters(schema) + override def length: Int = toCatalystConverters.length override def get(i: Int): Any = { @@ -54,7 +205,7 @@ private[sql] class MutableAggregationBufferImpl ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } - toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType)) + toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i))) } def update(i: Int, value: Any): Unit = { @@ -62,7 +213,15 @@ private[sql] class MutableAggregationBufferImpl ( throw new IllegalArgumentException( s"Could not update ${i}th value in this buffer because it only has $length values.") } - underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value)) + + bufferValueSetters(i)(underlyingBuffer, offsets(i), toCatalystConverters(i)(value)) + } + + // Because get method call specialized getter based on the schema, we cannot use the + // default implementation of the isNullAt (which is get(i) == null). + // We have to override it to call isNullAt of the underlyingBuffer. + override def isNullAt(i: Int): Boolean = { + underlyingBuffer.isNullAt(offsets(i)) } override def copy(): MutableAggregationBufferImpl = { @@ -84,7 +243,7 @@ private[sql] class InputAggregationBuffer private[sql] ( toScalaConverters: Array[Any => Any], bufferOffset: Int, var underlyingInputBuffer: InternalRow) - extends Row { + extends Row with BufferSetterGetterUtils { private[this] val offsets: Array[Int] = { val newOffsets = new Array[Int](length) @@ -96,6 +255,10 @@ private[sql] class InputAggregationBuffer private[sql] ( newOffsets } + private[this] val bufferValueGetters = createGetters(schema) + + def getBufferOffset: Int = bufferOffset + override def length: Int = toCatalystConverters.length override def get(i: Int): Any = { @@ -103,8 +266,14 @@ private[sql] class InputAggregationBuffer private[sql] ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } - // TODO: Use buffer schema to avoid using generic getter. - toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType)) + toScalaConverters(i)(bufferValueGetters(i)(underlyingInputBuffer, offsets(i))) + } + + // Because get method call specialized getter based on the schema, we cannot use the + // default implementation of the isNullAt (which is get(i) == null). + // We have to override it to call isNullAt of the underlyingInputBuffer. 
+ override def isNullAt(i: Int): Boolean = { + underlyingInputBuffer.isNullAt(offsets(i)) } override def copy(): InputAggregationBuffer = { @@ -147,7 +316,7 @@ private[sql] case class ScalaUDAF( override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) - val childrenSchema: StructType = { + private[this] val childrenSchema: StructType = { val inputFields = children.zipWithIndex.map { case (child, index) => StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) @@ -155,7 +324,7 @@ private[sql] case class ScalaUDAF( StructType(inputFields) } - lazy val inputProjection = { + private lazy val inputProjection = { val inputAttributes = childrenSchema.toAttributes log.debug( s"Creating MutableProj: $children, inputSchema: $inputAttributes.") @@ -168,40 +337,68 @@ private[sql] case class ScalaUDAF( } } - val inputToScalaConverters: Any => Any = + private[this] val inputToScalaConverters: Any => Any = CatalystTypeConverters.createToScalaConverter(childrenSchema) - val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field => - CatalystTypeConverters.createToCatalystConverter(field.dataType) + private[this] val bufferValuesToCatalystConverters: Array[Any => Any] = { + bufferSchema.fields.map { field => + CatalystTypeConverters.createToCatalystConverter(field.dataType) + } } - val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field => - CatalystTypeConverters.createToScalaConverter(field.dataType) + private[this] val bufferValuesToScalaConverters: Array[Any => Any] = { + bufferSchema.fields.map { field => + CatalystTypeConverters.createToScalaConverter(field.dataType) + } } - lazy val inputAggregateBuffer: InputAggregationBuffer = - new InputAggregationBuffer( - bufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - inputBufferOffset, - null) - - lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = - new MutableAggregationBufferImpl( - bufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - mutableBufferOffset, - null) + // This buffer is only used at executor side. + private[this] var inputAggregateBuffer: InputAggregationBuffer = null + + // This buffer is only used at executor side. + private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null + + // This buffer is only used at executor side. + private[this] var evalAggregateBuffer: InputAggregationBuffer = null + + /** + * Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of + * `inputAggregateBuffer` based on this new inputBufferOffset. + */ + override def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = { + super.withNewInputBufferOffset(newInputBufferOffset) + // inputBufferOffset has been updated. + inputAggregateBuffer = + new InputAggregationBuffer( + bufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + inputBufferOffset, + null) + } - lazy val evalAggregateBuffer: InputAggregationBuffer = - new InputAggregationBuffer( - bufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - mutableBufferOffset, - null) + /** + * Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of + * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset. 
+ */ + override def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = { + super.withNewMutableBufferOffset(newMutableBufferOffset) + // mutableBufferOffset has been updated. + mutableAggregateBuffer = + new MutableAggregationBufferImpl( + bufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableBufferOffset, + null) + evalAggregateBuffer = + new InputAggregationBuffer( + bufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableBufferOffset, + null) + } override def initialize(buffer: MutableRow): Unit = { mutableAggregateBuffer.underlyingBuffer = buffer diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 03635baae4a5f..960be08f84d94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -17,13 +17,9 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.types.{StructType, MapType, ArrayType} /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -52,13 +48,16 @@ object Utils { agg.aggregateFunction.bufferAttributes } val partialAggregate = - Aggregate2Sort( - None: Option[Seq[Expression]], - namedGroupingExpressions.map(_._2), - partialAggregateExpressions, - partialAggregateAttributes, - namedGroupingAttributes ++ partialAggregateAttributes, - child) + Aggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = namedGroupingExpressions.map(_._2), + nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, + resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes, + child = child) // 2. Create an Aggregate Operator for final aggregations. 
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) @@ -78,13 +77,17 @@ object Utils { }.getOrElse(expression) }.asInstanceOf[NamedExpression] } - val finalAggregate = Aggregate2Sort( - Some(namedGroupingAttributes), - namedGroupingAttributes, - finalAggregateExpressions, - finalAggregateAttributes, - rewrittenResultExpressions, - partialAggregate) + val finalAggregate = + Aggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = namedGroupingAttributes.length, + resultExpressions = rewrittenResultExpressions, + child = partialAggregate) finalAggregate :: Nil } @@ -133,14 +136,21 @@ object Utils { val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => agg.aggregateFunction.bufferAttributes } + val partialAggregateGroupingExpressions = + (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) + val partialAggregateResult = + namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes val partialAggregate = - Aggregate2Sort( - None: Option[Seq[Expression]], - (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2), - partialAggregateExpressions, - partialAggregateAttributes, - namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes, - child) + Aggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = partialAggregateGroupingExpressions, + nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialAggregateResult, + child = child) // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregateExpressions = functionsWithoutDistinct.map { @@ -151,14 +161,19 @@ object Utils { partialMergeAggregateExpressions.flatMap { agg => agg.aggregateFunction.bufferAttributes } + val partialMergeAggregateResult = + namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes val partialMergeAggregate = - Aggregate2Sort( - Some(namedGroupingAttributes), - namedGroupingAttributes ++ distinctColumnAttributes, - partialMergeAggregateExpressions, - partialMergeAggregateAttributes, - namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes, - partialAggregate) + Aggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + nonCompleteAggregateAttributes = partialMergeAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = partialMergeAggregateResult, + child = partialAggregate) // 3. Create an Aggregate Operator for partial merge aggregations. 
val finalAggregateExpressions = functionsWithoutDistinct.map { @@ -199,15 +214,17 @@ object Utils { }.getOrElse(expression) }.asInstanceOf[NamedExpression] } - val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort( - namedGroupingAttributes ++ distinctColumnAttributes, - namedGroupingAttributes, - finalAggregateExpressions, - finalAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - rewrittenResultExpressions, - partialMergeAggregate) + val finalAndCompleteAggregate = + Aggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, + completeAggregateExpressions = completeAggregateExpressions, + completeAggregateAttributes = completeAggregateAttributes, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = rewrittenResultExpressions, + child = partialMergeAggregate) finalAndCompleteAggregate :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 2294a670c735f..5a1b000e89875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -220,7 +220,6 @@ case class TakeOrderedAndProject( override def outputOrdering: Seq[SortOrder] = sortOrder } - /** * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 51fe9d9d98bf3..bbadc202a4f06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.scalatest.BeforeAndAfterAll - import java.sql.Timestamp +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql.execution.aggregate.Aggregate2Sort +import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ @@ -273,7 +273,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { var hasGeneratedAgg = false df.queryExecution.executedPlan.foreach { case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true - case newAggregate: Aggregate2Sort => hasGeneratedAgg = true + case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true case _ => } if (!hasGeneratedAgg) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 54f82f89ed18a..7978ed57a937e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -138,7 +138,14 @@ abstract class SparkSqlSerializer2Suite extends 
QueryTest with BeforeAndAfterAll s"Expected $expectedSerializerClass as the serializer of Exchange. " + s"However, the serializer was not set." val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) - assert(serializer.getClass === expectedSerializerClass) + val isExpectedSerializer = + serializer.getClass == expectedSerializerClass || + serializer.getClass == classOf[UnsafeRowSerializer] + val wrongSerializerErrorMessage = + s"Expected ${expectedSerializerClass.getCanonicalName} or " + + s"${classOf[UnsafeRowSerializer].getCanonicalName}. But " + + s"${serializer.getClass.getCanonicalName} is used." + assert(isExpectedSerializer, wrongSerializerErrorMessage) case _ => // Ignore other nodes. } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 0375eb79add95..6f0db27775e4d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.execution.aggregate.Aggregate2Sort +import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row} import org.scalatest.BeforeAndAfterAll import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} -class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { +abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { override val sqlContext = TestHive import sqlContext.implicits._ @@ -34,7 +34,7 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf override def beforeAll(): Unit = { originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 - sqlContext.sql("set spark.sql.useAggregate2=true") + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") val data1 = Seq[(Integer, Integer)]( (1, 10), (null, -60), @@ -81,7 +81,7 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf sqlContext.sql("DROP TABLE IF EXISTS agg1") sqlContext.sql("DROP TABLE IF EXISTS agg2") sqlContext.dropTempTable("emptyTable") - sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2") + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) } test("empty table") { @@ -454,54 +454,86 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf } test("error handling") { - sqlContext.sql(s"set spark.sql.useAggregate2=false") - var errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | mydoublesum(value), - | mydoubleavg(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + withSQLConf("spark.sql.useAggregate2" -> "false") { + val errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | sum(value + 1.5 * key), + | mydoublesum(value), + | mydoubleavg(value) + |FROM agg1 + |GROUP BY 
key + """.stripMargin).collect() + }.getMessage + assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + } // TODO: once we support Hive UDAF in the new interface, // we can remove the following two tests. - sqlContext.sql(s"set spark.sql.useAggregate2=true") - errorMessage = intercept[AnalysisException] { - sqlContext.sql( + withSQLConf("spark.sql.useAggregate2" -> "true") { + val errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | mydoublesum(value + 1.5 * key), + | stddev_samp(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).collect() + }.getMessage + assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + + // This will fall back to the old aggregate + val newAggregateOperators = sqlContext.sql( """ |SELECT | key, - | mydoublesum(value + 1.5 * key), + | sum(value + 1.5 * key), | stddev_samp(value) |FROM agg1 |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) - - // This will fall back to the old aggregate - val newAggregateOperators = sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).queryExecution.executedPlan.collect { - case agg: Aggregate2Sort => agg + """.stripMargin).queryExecution.executedPlan.collect { + case agg: aggregate.Aggregate => agg + } + val message = + "We should fallback to the old aggregation code path if " + + "there is any aggregate function that cannot be converted to the new interface." + assert(newAggregateOperators.isEmpty, message) } - val message = - "We should fallback to the old aggregation code path if there is any aggregate function " + - "that cannot be converted to the new interface." - assert(newAggregateOperators.isEmpty, message) + } +} + +class SortBasedAggregationQuerySuite extends AggregationQuerySuite { - sqlContext.sql(s"set spark.sql.useAggregate2=true") + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + } +} + +class TungstenAggregationQuerySuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) } } From 95dccc63350c45045f038bab9f8a5080b4e1f8cc Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Mon, 3 Aug 2015 01:55:58 -0700 Subject: [PATCH 095/340] [SPARK-8873] [MESOS] Clean up shuffle files if external shuffle service is used This patch builds directly on #7820, which is largely written by tnachen. The only addition is one commit for cleaning up the code. There should be no functional differences between this and #7820. 
Author: Timothy Chen Author: Andrew Or Closes #7881 from andrewor14/tim-cleanup-mesos-shuffle and squashes the following commits: 8894f7d [Andrew Or] Clean up code 2a5fa10 [Andrew Or] Merge branch 'mesos_shuffle_clean' of github.com:tnachen/spark into tim-cleanup-mesos-shuffle fadff89 [Timothy Chen] Address comments. e4d0f1d [Timothy Chen] Clean up external shuffle data on driver exit with Mesos. --- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/deploy/ExternalShuffleService.scala | 17 ++- .../mesos/MesosExternalShuffleService.scala | 107 ++++++++++++++++++ .../org/apache/spark/rpc/RpcEndpoint.scala | 6 +- .../mesos/CoarseMesosSchedulerBackend.scala | 52 ++++++++- .../CoarseMesosSchedulerBackendSuite.scala | 5 +- .../launcher/SparkClassCommandBuilder.java | 3 +- .../spark/network/client/TransportClient.java | 5 + .../shuffle/ExternalShuffleBlockHandler.java | 6 + .../shuffle/ExternalShuffleClient.java | 12 +- .../mesos/MesosExternalShuffleClient.java | 72 ++++++++++++ .../protocol/BlockTransferMessage.java | 4 +- .../protocol/mesos/RegisterDriver.java | 60 ++++++++++ sbin/start-mesos-shuffle-service.sh | 35 ++++++ sbin/stop-mesos-shuffle-service.sh | 25 ++++ 15 files changed, 394 insertions(+), 17 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java create mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java create mode 100755 sbin/start-mesos-shuffle-service.sh create mode 100755 sbin/stop-mesos-shuffle-service.sh diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a1c66ef4fc5ea..6f336a7c299ab 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2658,7 +2658,7 @@ object SparkContext extends Logging { val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, sc, url) + new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) } else { new MesosSchedulerBackend(scheduler, sc, url) } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 4089c3e771fa8..20a9faa1784b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -27,6 +27,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslServerBootstrap import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.network.util.TransportConf import org.apache.spark.util.Utils /** @@ -45,11 +46,16 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val useSasl: Boolean = securityManager.isAuthenticationEnabled() private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) - private val blockHandler = new ExternalShuffleBlockHandler(transportConf) + private val blockHandler = newShuffleBlockHandler(transportConf) private val 
transportContext: TransportContext = new TransportContext(transportConf, blockHandler) private var server: TransportServer = _ + /** Create a new shuffle block handler. Factored out for subclasses to override. */ + protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = { + new ExternalShuffleBlockHandler(conf) + } + /** Starts the external shuffle service if the user has configured us to. */ def startIfEnabled() { if (enabled) { @@ -93,6 +99,13 @@ object ExternalShuffleService extends Logging { private val barrier = new CountDownLatch(1) def main(args: Array[String]): Unit = { + main(args, (conf: SparkConf, sm: SecurityManager) => new ExternalShuffleService(conf, sm)) + } + + /** A helper main method that allows the caller to call this with a custom shuffle service. */ + private[spark] def main( + args: Array[String], + newShuffleService: (SparkConf, SecurityManager) => ExternalShuffleService): Unit = { val sparkConf = new SparkConf Utils.loadDefaultSparkProperties(sparkConf) val securityManager = new SecurityManager(sparkConf) @@ -100,7 +113,7 @@ object ExternalShuffleService extends Logging { // we override this value since this service is started from the command line // and we assume the user really wants it to be running sparkConf.set("spark.shuffle.service.enabled", "true") - server = new ExternalShuffleService(sparkConf, securityManager) + server = newShuffleService(sparkConf, securityManager) server.start() installShutdownHook() diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala new file mode 100644 index 0000000000000..061857476a8a0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.net.SocketAddress + +import scala.collection.mutable + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.deploy.ExternalShuffleService +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver +import org.apache.spark.network.util.TransportConf + +/** + * An RPC endpoint that receives registration requests from Spark drivers running on Mesos. + * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]]. 
+ */ +private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf) + extends ExternalShuffleBlockHandler(transportConf) with Logging { + + // Stores a map of driver socket addresses to app ids + private val connectedApps = new mutable.HashMap[SocketAddress, String] + + protected override def handleMessage( + message: BlockTransferMessage, + client: TransportClient, + callback: RpcResponseCallback): Unit = { + message match { + case RegisterDriverParam(appId) => + val address = client.getSocketAddress + logDebug(s"Received registration request from app $appId (remote address $address).") + if (connectedApps.contains(address)) { + val existingAppId = connectedApps(address) + if (!existingAppId.equals(appId)) { + logError(s"A new app '$appId' has connected to existing address $address, " + + s"removing previously registered app '$existingAppId'.") + applicationRemoved(existingAppId, true) + } + } + connectedApps(address) = appId + callback.onSuccess(new Array[Byte](0)) + case _ => super.handleMessage(message, client, callback) + } + } + + /** + * On connection termination, clean up shuffle files written by the associated application. + */ + override def connectionTerminated(client: TransportClient): Unit = { + val address = client.getSocketAddress + if (connectedApps.contains(address)) { + val appId = connectedApps(address) + logInfo(s"Application $appId disconnected (address was $address).") + applicationRemoved(appId, true /* cleanupLocalDirs */) + connectedApps.remove(address) + } else { + logWarning(s"Unknown $address disconnected.") + } + } + + /** An extractor object for matching [[RegisterDriver]] message. */ + private object RegisterDriverParam { + def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId) + } +} + +/** + * A wrapper of [[ExternalShuffleService]] that provides an additional endpoint for drivers + * to associate with. This allows the shuffle service to detect when a driver is terminated + * and can clean up the associated shuffle files. + */ +private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManager: SecurityManager) + extends ExternalShuffleService(conf, securityManager) { + + protected override def newShuffleBlockHandler( + conf: TransportConf): ExternalShuffleBlockHandler = { + new MesosExternalShuffleBlockHandler(conf) + } +} + +private[spark] object MesosExternalShuffleService extends Logging { + + def main(args: Array[String]): Unit = { + ExternalShuffleService.main(args, + (conf: SparkConf, sm: SecurityManager) => new MesosExternalShuffleService(conf, sm)) + } +} + + diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index d2b2baef1d8c4..dfcbc51cdf616 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -47,11 +47,11 @@ private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint * * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. * - * The lift-cycle will be: + * The life-cycle of an endpoint is: * - * constructor onStart receive* onStop + * constructor -> onStart -> receive* -> onStop * - * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use + * Note: `receive` can be called concurrently. 
If you want `receive` to be thread-safe, please use * [[ThreadSafeRpcEndpoint]] * * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index b7fde0d9b3265..15a0915708c7c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -26,12 +26,15 @@ import scala.collection.mutable.{HashMap, HashSet} import com.google.common.collect.HashBiMap import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} + +import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -46,7 +49,8 @@ import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, - master: String) + master: String, + securityManager: SecurityManager) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with MesosSchedulerUtils { @@ -56,12 +60,19 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + // If shuffle service is enabled, the Spark driver will register with the shuffle service. + // This is for cleaning up shuffle files reliably. 
+ private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + // Cores we have acquired with each Mesos task ID val coresByTaskId = new HashMap[Int, Int] var totalCoresAcquired = 0 val slaveIdsWithExecutors = new HashSet[String] + // Mapping from slave Id to hostname + private val slaveIdToHost = new HashMap[String, String] + val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] // How many times tasks on each slave failed val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] @@ -90,6 +101,19 @@ private[spark] class CoarseMesosSchedulerBackend( private val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // A client for talking to the external shuffle service, if it is enabled. + private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { + if (shuffleServiceEnabled) { + Some(new MesosExternalShuffleClient( + SparkTransportConf.fromSparkConf(conf), + securityManager, + securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled())) + } else { + None + } + } + var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -188,6 +212,7 @@ private[spark] class CoarseMesosSchedulerBackend( override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { appId = frameworkId.getValue + mesosExternalShuffleClient.foreach(_.init(appId)) logInfo("Registered as framework ID " + appId) markRegistered() } @@ -244,6 +269,7 @@ private[spark] class CoarseMesosSchedulerBackend( // accept the offer and launch the task logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname d.launchTasks( Collections.singleton(offer.getId), Collections.singleton(taskBuilder.build()), filters) @@ -261,7 +287,27 @@ private[spark] class CoarseMesosSchedulerBackend( val taskId = status.getTaskId.getValue.toInt val state = status.getState logInfo(s"Mesos task $taskId is now $state") + val slaveId: String = status.getSlaveId.getValue stateLock.synchronized { + // If the shuffle service is enabled, have the driver register with each one of the + // shuffle services. This allows the shuffle services to clean up state associated with + // this application when the driver exits. There is currently not a great way to detect + // this through Mesos, since the shuffle services are set up independently. + if (TaskState.fromMesos(state).equals(TaskState.RUNNING) && + slaveIdToHost.contains(slaveId) && + shuffleServiceEnabled) { + assume(mesosExternalShuffleClient.isDefined, + "External shuffle client was not instantiated even though shuffle service is enabled.") + // TODO: Remove this and allow the MesosExternalShuffleService to detect + // framework termination when new Mesos Framework HTTP API is available.
+ val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337) + val hostname = slaveIdToHost.remove(slaveId).get + logDebug(s"Connecting to shuffle service on slave $slaveId, " + + s"host $hostname, port $externalShufflePort for app ${conf.getAppId}") + mesosExternalShuffleClient.get + .registerDriverWithShuffleService(hostname, externalShufflePort) + } + if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) slaveIdsWithExecutors -= slaveId diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 4b504df7b8851..525ee0d3bdc5a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SecurityManager, SparkFunSuite} class CoarseMesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext @@ -59,7 +59,8 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite private def createSchedulerBackend( taskScheduler: TaskSchedulerImpl, driver: SchedulerDriver): CoarseMesosSchedulerBackend = { - val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") { + val securityManager = mock[SecurityManager] + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( masterUrl: String, scheduler: Scheduler, diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index de85720febf23..5f95e2c74f902 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -69,7 +69,8 @@ public List buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; - } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService")) { + } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService") || + className.equals("org.apache.spark.deploy.mesos.MesosExternalShuffleService")) { javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); javaOptsKeys.add("SPARK_SHUFFLE_OPTS"); memKey = "SPARK_DAEMON_MEMORY"; diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 37f2e34ceb24d..e8e7f06247d3e 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -19,6 +19,7 @@ import java.io.Closeable; import java.io.IOException; +import java.net.SocketAddress; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -79,6 +80,10 @@ public boolean isActive() { return 
channel.isOpen() || channel.isActive(); } + public SocketAddress getSocketAddress() { + return channel.remoteAddress(); + } + /** * Requests a single chunk from the remote side, from the pre-negotiated streamId. * diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index e4faaf8854fc7..db9dc4f17cee9 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -65,7 +65,13 @@ public ExternalShuffleBlockHandler(TransportConf conf) { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); + handleMessage(msgObj, client, callback); + } + protected void handleMessage( + BlockTransferMessage msgObj, + TransportClient client, + RpcResponseCallback callback) { if (msgObj instanceof OpenBlocks) { OpenBlocks msg = (OpenBlocks) msgObj; List blocks = Lists.newArrayList(); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 612bce571a493..ea6d248d66be3 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -50,8 +50,8 @@ public class ExternalShuffleClient extends ShuffleClient { private final boolean saslEncryptionEnabled; private final SecretKeyHolder secretKeyHolder; - private TransportClientFactory clientFactory; - private String appId; + protected TransportClientFactory clientFactory; + protected String appId; /** * Creates an external shuffle client, with SASL optionally enabled. 
If SASL is not enabled, @@ -71,6 +71,10 @@ public ExternalShuffleClient( this.saslEncryptionEnabled = saslEncryptionEnabled; } + protected void checkInit() { + assert appId != null : "Called before init()"; + } + @Override public void init(String appId) { this.appId = appId; @@ -89,7 +93,7 @@ public void fetchBlocks( final String execId, String[] blockIds, BlockFetchingListener listener) { - assert appId != null : "Called before init()"; + checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = @@ -132,7 +136,7 @@ public void registerWithShuffleServer( int port, String execId, ExecutorShuffleInfo executorInfo) throws IOException { - assert appId != null : "Called before init()"; + checkInit(); TransportClient client = clientFactory.createClient(host, port); byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java new file mode 100644 index 0000000000000..7543b6be4f2a1 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.mesos; + +import java.io.IOException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.shuffle.ExternalShuffleClient; +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.util.TransportConf; + +/** + * A client for talking to the external shuffle service in Mesos coarse-grained mode. + * + * This is used by the Spark driver to register with each external shuffle service on the cluster. + * The reason why the driver has to talk to the service is for cleaning up shuffle files reliably + * after the application exits. Mesos does not provide a great alternative to do this, so Spark + * has to detect this itself. + */ +public class MesosExternalShuffleClient extends ExternalShuffleClient { + private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class); + + /** + * Creates an Mesos external shuffle client that wraps the {@link ExternalShuffleClient}. + * Please refer to docs on {@link ExternalShuffleClient} for more information. 
+ */ + public MesosExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled, + boolean saslEncryptionEnabled) { + super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled); + } + + public void registerDriverWithShuffleService(String host, int port) throws IOException { + checkInit(); + byte[] registerDriver = new RegisterDriver(appId).toByteArray(); + TransportClient client = clientFactory.createClient(host, port); + client.sendRpc(registerDriver, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + logger.info("Successfully registered app " + appId + " with external shuffle service."); + } + + @Override + public void onFailure(Throwable e) { + logger.warn("Unable to register app " + appId + " with external shuffle service. " + + "Please manually remove shuffle data after driver exit. Error: " + e); + } + }); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 6c1210b33268a..fcb52363e632c 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -21,6 +21,7 @@ import io.netty.buffer.Unpooled; import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; /** * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or @@ -37,7 +38,7 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. */ public static enum Type { - OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3); + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4); private final byte id; @@ -60,6 +61,7 @@ public static BlockTransferMessage fromByteArray(byte[] msg) { case 1: return UploadBlock.decode(buf); case 2: return RegisterExecutor.decode(buf); case 3: return StreamHandle.decode(buf); + case 4: return RegisterDriver.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java new file mode 100644 index 0000000000000..1c28fc1dff246 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol.mesos; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; + +/** + * A message sent from the driver to register with the MesosExternalShuffleService. + */ +public class RegisterDriver extends BlockTransferMessage { + private final String appId; + + public RegisterDriver(String appId) { + this.appId = appId; + } + + public String getAppId() { return appId; } + + @Override + protected Type type() { return Type.REGISTER_DRIVER; } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId); + } + + public static RegisterDriver decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + return new RegisterDriver(appId); + } +} diff --git a/sbin/start-mesos-shuffle-service.sh b/sbin/start-mesos-shuffle-service.sh new file mode 100755 index 0000000000000..64580762c5dc4 --- /dev/null +++ b/sbin/start-mesos-shuffle-service.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts the Mesos external shuffle server on the machine this script is executed on. +# The Mesos external shuffle service detects when an application exits and automatically +# cleans up its shuffle files. +# +# Usage: start-mesos-shuffle-server.sh +# +# Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle service configuration. +# + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosExternalShuffleService 1 diff --git a/sbin/stop-mesos-shuffle-service.sh b/sbin/stop-mesos-shuffle-service.sh new file mode 100755 index 0000000000000..0e965d5ec5886 --- /dev/null +++ b/sbin/stop-mesos-shuffle-service.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stops the Mesos external shuffle service on the machine this script is executed on. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosExternalShuffleService 1 From 137f47865df6e98ab70ae5ba30dc4d441fb41166 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 3 Aug 2015 04:21:15 -0700 Subject: [PATCH 096/340] [SPARK-9551][SQL] add a cheap version of copy for UnsafeRow to reuse a copy buffer Author: Wenchen Fan Closes #7885 from cloud-fan/cheap-copy and squashes the following commits: 0900ca1 [Wenchen Fan] replace == with === 73f4ada [Wenchen Fan] add tests 07b865a [Wenchen Fan] add a cheap version of copy --- .../sql/catalyst/expressions/UnsafeRow.java | 32 ++++++++++++++++ .../org/apache/spark/sql/UnsafeRowSuite.scala | 38 +++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index c5d42d73a43a4..f4230cfaba375 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -463,6 +463,38 @@ public UnsafeRow copy() { return rowCopy; } + /** + * Creates an empty UnsafeRow from a byte array with specified numBytes and numFields. + * The returned row is invalid until we call copyFrom on it. + */ + public static UnsafeRow createFromByteArray(int numBytes, int numFields) { + final UnsafeRow row = new UnsafeRow(); + row.pointTo(new byte[numBytes], numFields, numBytes); + return row; + } + + /** + * Copies the input UnsafeRow to this UnsafeRow, and resize the underlying byte[] when the + * input row is larger than this row. + */ + public void copyFrom(UnsafeRow row) { + // copyFrom is only available for UnsafeRow created from byte array. + assert (baseObject instanceof byte[]) && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET; + if (row.sizeInBytes > this.sizeInBytes) { + // resize the underlying byte[] if it's not large enough. + this.baseObject = new byte[row.sizeInBytes]; + } + PlatformDependent.copyMemory( + row.baseObject, + row.baseOffset, + this.baseObject, + this.baseOffset, + row.sizeInBytes + ); + // update the sizeInBytes. + this.sizeInBytes = row.sizeInBytes; + } + /** * Write this UnsafeRow's underlying bytes to the given OutputStream. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index e72a1bc6c4e20..c5faaa663e749 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -82,4 +82,42 @@ class UnsafeRowSuite extends SparkFunSuite { assert(unsafeRow.get(0, dataType) === null) } } + + test("createFromByteArray and copyFrom") { + val row = InternalRow(1, UTF8String.fromString("abc")) + val converter = UnsafeProjection.create(Array[DataType](IntegerType, StringType)) + val unsafeRow = converter.apply(row) + + val emptyRow = UnsafeRow.createFromByteArray(64, 2) + val buffer = emptyRow.getBaseObject + + emptyRow.copyFrom(unsafeRow) + assert(emptyRow.getSizeInBytes() === unsafeRow.getSizeInBytes) + assert(emptyRow.getInt(0) === unsafeRow.getInt(0)) + assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1)) + // make sure we reuse the buffer. + assert(emptyRow.getBaseObject === buffer) + + // make sure we really copied the input row. + unsafeRow.setInt(0, 2) + assert(emptyRow.getInt(0) === 1) + + val longString = UTF8String.fromString((1 to 100).map(_ => "abc").reduce(_ + _)) + val row2 = InternalRow(3, longString) + val unsafeRow2 = converter.apply(row2) + + // make sure we can resize. + emptyRow.copyFrom(unsafeRow2) + assert(emptyRow.getSizeInBytes() === unsafeRow2.getSizeInBytes) + assert(emptyRow.getInt(0) === 3) + assert(emptyRow.getUTF8String(1) === longString) + // make sure we really resized. + assert(emptyRow.getBaseObject != buffer) + + // make sure we can still handle small rows after resize. + emptyRow.copyFrom(unsafeRow) + assert(emptyRow.getSizeInBytes() === unsafeRow.getSizeInBytes) + assert(emptyRow.getInt(0) === unsafeRow.getInt(0)) + assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1)) + } } From 191bf2689d127a9dd328b9cc517362fd51eaed3d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 3 Aug 2015 04:23:26 -0700 Subject: [PATCH 097/340] [SPARK-9518] [SQL] cleanup generated UnsafeRowJoiner and fix bug Currently, when copy the bitsets, we didn't consider that the row1 may not sit in the beginning of byte array. 
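For reference, the word-level merge that the generated joiner performs can be sketched as a small standalone Scala snippet (illustrative only, not code from this patch); it assumes `bits1` has exactly (len1 + 63) / 64 words and that the unused high bits of its last word are zero:

    // Concatenate two little-endian Long-word bitsets: bits1 holds len1 bits,
    // bits2 holds len2 bits. When len1 is not a multiple of 64, each word of
    // bits2 is split across two output words with shifts.
    def concatBitsets(bits1: Array[Long], len1: Int, bits2: Array[Long], len2: Int): Array[Long] = {
      val out = new Array[Long]((len1 + len2 + 63) / 64)
      Array.copy(bits1, 0, out, 0, bits1.length)
      val shift = len1 % 64
      var i = 0
      while (i < bits2.length) {
        val word = bits2(i)
        if (shift == 0) {
          // word-aligned case: copy straight after the words of bits1
          out(bits1.length + i) |= word
        } else {
          // the low (64 - shift) bits fill the upper part of the current output word
          out(bits1.length - 1 + i) |= word << shift
          // the high `shift` bits spill into the next output word, if one exists
          if (bits1.length + i < out.length) {
            out(bits1.length + i) |= word >>> (64 - shift)
          }
        }
        i += 1
      }
      out
    }
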
cc rxin Author: Davies Liu Closes #7892 from davies/clean_join and squashes the following commits: 14cce9e [Davies Liu] cleanup generated UnsafeRowJoiner and fix bug --- .../codegen/GenerateUnsafeRowJoiner.scala | 102 ++++++------------ .../GenerateUnsafeRowJoinerBitsetSuite.scala | 7 +- 2 files changed, 37 insertions(+), 72 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 645eb48d5a51b..5f8a6f8871722 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -40,10 +40,6 @@ abstract class UnsafeRowJoiner { */ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner] { - def dump(word: Long): String = { - Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" }.reverse.mkString - } - override protected def create(in: (StructType, StructType)): UnsafeRowJoiner = { create(in._1, in._2) } @@ -56,76 +52,45 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { - val ctx = newCodeGenContext() val offset = PlatformDependent.BYTE_ARRAY_OFFSET + val getLong = "PlatformDependent.UNSAFE.getLong" + val putLong = "PlatformDependent.UNSAFE.putLong" val bitset1Words = (schema1.size + 63) / 64 val bitset2Words = (schema2.size + 63) / 64 val outputBitsetWords = (schema1.size + schema2.size + 63) / 64 val bitset1Remainder = schema1.size % 64 - val bitset2Remainder = schema2.size % 64 // The number of words we can reduce when we concat two rows together. // The only reduction comes from merging the bitset portion of the two rows, saving 1 word. 
val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords - // --------------------- copy bitset from row 1 ----------------------- // - val copyBitset1 = Seq.tabulate(bitset1Words) { i => - s""" - |PlatformDependent.UNSAFE.putLong(buf, ${offset + i * 8}, - | PlatformDependent.UNSAFE.getLong(obj1, ${offset + i * 8})); - """.stripMargin - }.mkString - - - // --------------------- copy bitset from row 2 ----------------------- // - var copyBitset2 = "" - if (bitset1Remainder == 0) { - copyBitset2 += Seq.tabulate(bitset2Words) { i => - s""" - |PlatformDependent.UNSAFE.putLong(buf, ${offset + (bitset1Words + i) * 8}, - | PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8})); - """.stripMargin - }.mkString - } else { - copyBitset2 = Seq.tabulate(bitset2Words) { i => - s""" - |long bs2w$i = PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8}); - |long bs2w${i}p1 = (bs2w$i << $bitset1Remainder) & ~((1L << $bitset1Remainder) - 1); - |long bs2w${i}p2 = (bs2w$i >>> ${64 - bitset1Remainder}); - """.stripMargin - }.mkString - - copyBitset2 += Seq.tabulate(bitset2Words) { i => - val currentOffset = offset + (bitset1Words + i - 1) * 8 - if (i == 0) { - if (bitset1Words > 0) { - s""" - |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, - | bs2w${i}p1 | PlatformDependent.UNSAFE.getLong(obj1, $currentOffset)); - """.stripMargin - } else { - s""" - |PlatformDependent.UNSAFE.putLong(buf, $currentOffset + 8, bs2w${i}p1); - """.stripMargin - } + // --------------------- copy bitset from row 1 and row 2 --------------------------- // + val copyBitset = Seq.tabulate(outputBitsetWords) { i => + val bits = if (bitset1Remainder > 0) { + if (i < bitset1Words - 1) { + s"$getLong(obj1, offset1 + ${i * 8})" + } else if (i == bitset1Words - 1) { + // combine last work of bitset1 and first word of bitset2 + s"$getLong(obj1, offset1 + ${i * 8}) | ($getLong(obj2, offset2) << $bitset1Remainder)" + } else if (i - bitset1Words < bitset2Words - 1) { + // combine next two words of bitset2 + s"($getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder))" + + s"| ($getLong(obj2, offset2 + ${(i - bitset1Words + 1) * 8}) << $bitset1Remainder)" + } else { + // last word of bitset2 + s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder)" + } + } else { + // they are aligned by word + if (i < bitset1Words) { + s"$getLong(obj1, offset1 + ${i * 8})" } else { - s""" - |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, bs2w${i}p1 | bs2w${i - 1}p2); - """.stripMargin + s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})" } - }.mkString("\n") - - if (bitset2Words > 0 && - (bitset2Remainder == 0 || bitset2Remainder > (64 - bitset1Remainder))) { - val lastWord = bitset2Words - 1 - copyBitset2 += - s""" - |PlatformDependent.UNSAFE.putLong(buf, ${offset + (outputBitsetWords - 1) * 8}, - | bs2w${lastWord}p2); - """.stripMargin } - } + s"$putLong(buf, ${offset + i * 8}, $bits);" + }.mkString("\n") // --------------------- copy fixed length portion from row 1 ----------------------- // var cursor = offset + outputBitsetWords * 8 @@ -149,10 +114,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U cursor += schema2.size * 8 // --------------------- copy variable length portion from row 1 ----------------------- // + val numBytesBitsetAndFixedRow1 = (bitset1Words + schema1.size) * 8 val copyVariableLengthRow1 = s""" |// Copy variable length data for row1 - |long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8}; - |long 
numBytesVariableRow1 = row1.getSizeInBytes() - numBytesBitsetAndFixedRow1; + |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1; |PlatformDependent.copyMemory( | obj1, offset1 + ${(bitset1Words + schema1.size) * 8}, | buf, $cursor, @@ -160,10 +125,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U """.stripMargin // --------------------- copy variable length portion from row 2 ----------------------- // + val numBytesBitsetAndFixedRow2 = (bitset2Words + schema2.size) * 8 val copyVariableLengthRow2 = s""" |// Copy variable length data for row2 - |long numBytesBitsetAndFixedRow2 = ${(bitset2Words + schema2.size) * 8}; - |long numBytesVariableRow2 = row2.getSizeInBytes() - numBytesBitsetAndFixedRow2; + |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2; |PlatformDependent.copyMemory( | obj2, offset2 + ${(bitset2Words + schema2.size) * 8}, | buf, $cursor + numBytesVariableRow1, @@ -183,12 +148,11 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U if (i < schema1.size) { s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L" } else { - s"${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1" + s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)" } val cursor = offset + outputBitsetWords * 8 + i * 8 s""" - |PlatformDependent.UNSAFE.putLong(buf, $cursor, - | PlatformDependent.UNSAFE.getLong(buf, $cursor) + ($shift << 32)); + |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32)); """.stripMargin } }.mkString @@ -217,8 +181,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | final Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | - | $copyBitset1 - | $copyBitset2 + | $copyBitset | $copyFixedLengthRow1 | $copyFixedLengthRow2 | $copyVariableLengthRow1 @@ -233,7 +196,6 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U """.stripMargin logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") - // println(CodeFormatter.format(code)) val c = compile(code) c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index 76d9d991ed0dc..718a2acc8281d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -22,6 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent /** * A test suite for the bitset portion of the row concatenation. 
@@ -91,8 +92,9 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { private def createUnsafeRow(numFields: Int): UnsafeRow = { val row = new UnsafeRow val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8 - val buf = new Array[Byte](sizeInBytes) - row.pointTo(buf, numFields, sizeInBytes) + val offset = numFields * 8 + val buf = new Array[Byte](sizeInBytes + offset) + row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) row } @@ -133,6 +135,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { |input1: ${set1.mkString} |input2: ${set2.mkString} |output: ${out.mkString} + |expect: ${set1.mkString}${set2.mkString} """.stripMargin } From 8be198c86935001907727fd16577231ff776125b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 3 Aug 2015 04:26:18 -0700 Subject: [PATCH 098/340] Two minor comments from code review on 191bf2689. --- .../catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala | 2 +- .../codegen/GenerateUnsafeRowJoinerBitsetSuite.scala | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 5f8a6f8871722..30b51dd83fa9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -76,7 +76,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } else if (i - bitset1Words < bitset2Words - 1) { // combine next two words of bitset2 s"($getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder))" + - s"| ($getLong(obj2, offset2 + ${(i - bitset1Words + 1) * 8}) << $bitset1Remainder)" + s" | ($getLong(obj2, offset2 + ${(i - bitset1Words + 1) * 8}) << $bitset1Remainder)" } else { // last word of bitset2 s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder)" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index 718a2acc8281d..aff1bee99faad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -92,6 +92,8 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { private def createUnsafeRow(numFields: Int): UnsafeRow = { val row = new UnsafeRow val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8 + // Allocate a larger buffer than needed and point the UnsafeRow to somewhere in the middle. + // This way we can test the joiner when the input UnsafeRows are not the entire arrays. val offset = numFields * 8 val buf = new Array[Byte](sizeInBytes + offset) row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) From 69f5a7c934ac553ed52c00679b800bcffe83c1d6 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Mon, 3 Aug 2015 10:46:34 -0700 Subject: [PATCH 099/340] [SPARK-9528] [ML] Changed RandomForestClassifier to extend ProbabilisticClassifier RandomForestClassifier now outputs rawPrediction based on tree probabilities, plus probability column computed from normalized rawPrediction. CC: holdenk Author: Joseph K. Bradley Closes #7859 from jkbradley/rf-prob and squashes the following commits: 6c28f51 [Joseph K. Bradley] Changed RandomForestClassifier to extend ProbabilisticClassifier --- .../DecisionTreeClassifier.scala | 8 +--- .../ProbabilisticClassifier.scala | 27 +++++++++++++- .../RandomForestClassifier.scala | 37 +++++++++++++------ .../RandomForestClassifierSuite.scala | 36 ++++++++++++++---- 4 files changed, 81 insertions(+), 27 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index f27cfd0331419..f2b992f8ba249 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -131,13 +131,7 @@ final class DecisionTreeClassificationModel private[ml] ( override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { rawPrediction match { case dv: DenseVector => - var i = 0 - val size = dv.size - val sum = dv.values.sum - while (i < size) { - dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0 - i += 1 - } + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv) dv case sv: SparseVector => throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index dad451108626d..f9c9c2371f5cd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, DataType, StructType} @@ -175,3 +175,28 @@ private[spark] abstract class ProbabilisticClassificationModel[ */ protected def probability2prediction(probability: Vector): Double = probability.argmax } + +private[ml] object ProbabilisticClassificationModel { + + /** + * Normalize a vector of raw predictions to be a multinomial probability vector, in place. + * + * The input raw predictions should be >= 0. + * The output vector sums to 1, unless the input vector is all-0 (in which case the output is + * all-0 too). + * + * NOTE: This is NOT applicable to all models, only ones which effectively use class + * instance counts for raw predictions. 
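+ *
+ * Worked example (editorial illustration, not part of the original patch): a raw
+ * prediction vector of (2.0, 6.0, 2.0) sums to 10.0 and is normalized in place to
+ * (0.2, 0.6, 0.2); an all-zero vector is left unchanged.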
+ */ + def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = { + val sum = v.values.sum + if (sum != 0) { + var i = 0 + val size = v.size + while (i < size) { + v.values(i) /= sum + i += 1 + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 0c7eb4a662fdb..56e80cc8fe6e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -17,22 +17,19 @@ package org.apache.spark.ml.classification -import scala.collection.mutable - import org.apache.spark.annotation.Experimental import org.apache.spark.ml.tree.impl.RandomForest -import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType + /** * :: Experimental :: @@ -43,7 +40,7 @@ import org.apache.spark.sql.types.DoubleType */ @Experimental final class RandomForestClassifier(override val uid: String) - extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel] + extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("rfc")) @@ -127,7 +124,7 @@ final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], override val numClasses: Int) - extends ClassificationModel[Vector, RandomForestClassificationModel] + extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") @@ -157,15 +154,33 @@ final class RandomForestClassificationModel private[ml] ( override protected def predictRaw(features: Vector): Vector = { // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 // Classifies using majority votes. - // Ignore the weights since all are 1.0 for now. - val votes = new Array[Double](numClasses) + // Ignore the tree weights since all are 1.0 for now. 
+ val votes = Array.fill[Double](numClasses)(0.0) _trees.view.foreach { tree => - val prediction = tree.rootNode.predictImpl(features).prediction.toInt - votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight + val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats + val total = classCounts.sum + if (total != 0) { + var i = 0 + while (i < numClasses) { + votes(i) += classCounts(i) / total + i += 1 + } + } } Vectors.dense(votes) } + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv) + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index dbb2577c6204d..edf848b21a905 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -121,6 +122,33 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, rf2, categoricalFeatures, numClasses) } + test("predictRaw and predictProbability") { + val rdd = orderedLabeledPoints5_20 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + val predictions = model.transform(df) + .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) + .collect() + + predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") + assert(probPred.toArray.sum ~== 1.0 relTol 1E-5) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// @@ -173,13 +201,5 @@ private object RandomForestClassifierSuite { assert(newModel.hasParent) assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) assert(newModel.numClasses == numClasses) - val results = newModel.transform(newData) - results.select("rawPrediction", "prediction").collect().foreach { - case Row(raw: Vector, prediction: Double) => { - assert(raw.size == numClasses) - val predFromRaw = 
raw.toArray.zipWithIndex.maxBy(_._1)._2 - assert(predFromRaw == prediction) - } - } } } From b41a32718d615b304efba146bf97be0229779b01 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 3 Aug 2015 10:58:37 -0700 Subject: [PATCH 100/340] [SPARK-1855] Local checkpointing Certain use cases of Spark involve RDDs with long lineages that must be truncated periodically (e.g. GraphX). The existing way of doing it is through `rdd.checkpoint()`, which is expensive because it writes to HDFS. This patch provides an alternative to truncate lineages cheaply *without providing the same level of fault tolerance*. **Local checkpointing** writes checkpointed data to the local file system through the block manager. It is much faster than replicating to a reliable storage and provides the same semantics as long as executors do not fail. It is accessible through a new operator `rdd.localCheckpoint()` and leaves the old one unchanged. Users may even decide to combine the two and call the reliable one less frequently. The bulk of this patch involves refactoring the checkpointing interface to accept custom implementations of checkpointing. [Design doc](https://issues.apache.org/jira/secure/attachment/12741708/SPARK-7292-design.pdf). Author: Andrew Or Closes #7279 from andrewor14/local-checkpoint and squashes the following commits: 729600f [Andrew Or] Oops, fix tests 34bc059 [Andrew Or] Avoid computing all partitions in local checkpoint e43bbb6 [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint 3be5aea [Andrew Or] Address comments bf846a6 [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint ab003a3 [Andrew Or] Fix compile c2e111b [Andrew Or] Address comments 33f167a [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint e908a42 [Andrew Or] Fix tests f5be0f3 [Andrew Or] Use MEMORY_AND_DISK as the default local checkpoint level a92657d [Andrew Or] Update a few comments e58e3e3 [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint 4eb6eb1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint 1bbe154 [Andrew Or] Simplify LocalCheckpointRDD 48a9996 [Andrew Or] Avoid traversing dependency tree + rewrite tests 62aba3f [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint db70dc2 [Andrew Or] Express local checkpointing through caching the original RDD 87d43c6 [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint c449b38 [Andrew Or] Fix style 4a182f3 [Andrew Or] Add fine-grained tests for local checkpointing 53b363b [Andrew Or] Rename a few more awkwardly named methods (minor) e4cf071 [Andrew Or] Simplify LocalCheckpointRDD + docs + clean ups 4880deb [Andrew Or] Fix style d096c67 [Andrew Or] Fix mima 172cb66 [Andrew Or] Fix mima? 
e53d964 [Andrew Or] Fix style 56831c5 [Andrew Or] Add a few warnings and clear exception messages 2e59646 [Andrew Or] Add local checkpoint clean up tests 4dbbab1 [Andrew Or] Refactor CheckpointSuite to test local checkpointing 4514dc9 [Andrew Or] Clean local checkpoint files through RDD cleanups 0477eec [Andrew Or] Rename a few methods with awkward names (minor) 2e902e5 [Andrew Or] First implementation of local checkpointing 8447454 [Andrew Or] Fix tests 4ac1896 [Andrew Or] Refactor checkpoint interface for modularity --- .../org/apache/spark/ContextCleaner.scala | 9 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../scala/org/apache/spark/TaskContext.scala | 8 + .../org/apache/spark/rdd/CheckpointRDD.scala | 153 +------- .../apache/spark/rdd/LocalCheckpointRDD.scala | 67 ++++ .../spark/rdd/LocalRDDCheckpointData.scala | 83 +++++ .../main/scala/org/apache/spark/rdd/RDD.scala | 128 +++++-- .../apache/spark/rdd/RDDCheckpointData.scala | 106 ++---- .../spark/rdd/ReliableCheckpointRDD.scala | 172 +++++++++ .../spark/rdd/ReliableRDDCheckpointData.scala | 108 ++++++ .../org/apache/spark/CheckpointSuite.scala | 164 +++++---- .../apache/spark/ContextCleanerSuite.scala | 61 +++- .../spark/rdd/LocalCheckpointSuite.scala | 330 ++++++++++++++++++ project/MimaExcludes.scala | 9 +- 14 files changed, 1085 insertions(+), 315 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 37198d887b07b..d23c1533db758 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.{RDDCheckpointData, RDD} +import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} import org.apache.spark.util.Utils /** @@ -231,11 +231,14 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - /** Perform checkpoint cleanup. */ + /** + * Clean up checkpoint files written to a reliable storage. + * Locally checkpointed files are cleaned up separately through RDD cleanups. 
+ */ def doCleanCheckpoint(rddId: Int): Unit = { try { logDebug("Cleaning rdd checkpoint data " + rddId) - RDDCheckpointData.clearRDDCheckpointData(sc, rddId) + ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId) listeners.foreach(_.checkpointCleaned(rddId)) logInfo("Cleaned rdd checkpoint data " + rddId) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6f336a7c299ab..4380cf45cc1b0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1192,7 +1192,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } protected[spark] def checkpointFile[T: ClassTag](path: String): RDD[T] = withScope { - new CheckpointRDD[T](this, path) + new ReliableCheckpointRDD[T](this, path) } /** Build the union of a list of RDDs. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index b48836d5c8897..5d2c551d58514 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -59,6 +59,14 @@ object TaskContext { * Unset the thread local TaskContext. Internal to Spark. */ protected[spark] def unset(): Unit = taskContext.remove() + + /** + * Return an empty task context that is not actually used. + * Internal use only. + */ + private[spark] def empty(): TaskContext = { + new TaskContextImpl(0, 0, 0, 0, null, null) + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index e17bd47905d7a..72fe215dae73e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -17,156 +17,31 @@ package org.apache.spark.rdd -import java.io.IOException - import scala.reflect.ClassTag -import org.apache.hadoop.fs.Path - -import org.apache.spark._ -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Partition, SparkContext, TaskContext} +/** + * An RDD partition used to recover checkpointed data. + */ private[spark] class CheckpointRDDPartition(val index: Int) extends Partition /** - * This RDD represents a RDD checkpoint file (similar to HadoopRDD). + * An RDD that recovers checkpointed data from storage. */ -private[spark] -class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) +private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkContext) extends RDD[T](sc, Nil) { - private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) - - @transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) - - override def getCheckpointFile: Option[String] = Some(checkpointPath) - - override def getPartitions: Array[Partition] = { - val cpath = new Path(checkpointPath) - val numPartitions = - // listStatus can throw exception if path does not exist. - if (fs.exists(cpath)) { - val dirContents = fs.listStatus(cpath).map(_.getPath) - val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.length - if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || - ! 
partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { - throw new SparkException("Invalid checkpoint directory: " + checkpointPath) - } - numPart - } else 0 - - Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - val status = fs.getFileStatus(new Path(checkpointPath, - CheckpointRDD.splitIdToFile(split.index))) - val locations = fs.getFileBlockLocations(status, 0, status.getLen) - locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") - } - - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) - CheckpointRDD.readFromFile(file, broadcastedConf, context) - } - // CheckpointRDD should not be checkpointed again - override def checkpoint(): Unit = { } override def doCheckpoint(): Unit = { } -} - -private[spark] object CheckpointRDD extends Logging { - def splitIdToFile(splitId: Int): String = { - "part-%05d".format(splitId) - } - - def writeToFile[T: ClassTag]( - path: String, - broadcastedConf: Broadcast[SerializableConfiguration], - blockSize: Int = -1 - )(ctx: TaskContext, iterator: Iterator[T]) { - val env = SparkEnv.get - val outputDir = new Path(path) - val fs = outputDir.getFileSystem(broadcastedConf.value.value) - - val finalOutputName = splitIdToFile(ctx.partitionId) - val finalOutputPath = new Path(outputDir, finalOutputName) - val tempOutputPath = - new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber) - - if (fs.exists(tempOutputPath)) { - throw new IOException("Checkpoint failed: temporary path " + - tempOutputPath + " already exists") - } - val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - - val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) - } else { - // This is mainly for testing purpose - fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) - } - val serializer = env.serializer.newInstance() - val serializeStream = serializer.serializeStream(fileOutputStream) - Utils.tryWithSafeFinally { - serializeStream.writeAll(iterator) - } { - serializeStream.close() - } - - if (!fs.rename(tempOutputPath, finalOutputPath)) { - if (!fs.exists(finalOutputPath)) { - logInfo("Deleting tempOutputPath " + tempOutputPath) - fs.delete(tempOutputPath, false) - throw new IOException("Checkpoint failed: failed to save output of task: " - + ctx.attemptNumber + " and final output path does not exist") - } else { - // Some other copy of this task must've finished before us and renamed it - logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") - fs.delete(tempOutputPath, false) - } - } - } - - def readFromFile[T]( - path: Path, - broadcastedConf: Broadcast[SerializableConfiguration], - context: TaskContext - ): Iterator[T] = { - val env = SparkEnv.get - val fs = path.getFileSystem(broadcastedConf.value.value) - val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - val fileInputStream = fs.open(path, bufferSize) - val serializer = env.serializer.newInstance() - val deserializeStream = serializer.deserializeStream(fileInputStream) - - // Register an on-task-completion callback to close the input stream. 
- context.addTaskCompletionListener(context => deserializeStream.close()) - - deserializeStream.asIterator.asInstanceOf[Iterator[T]] - } + override def checkpoint(): Unit = { } + override def localCheckpoint(): this.type = this - // Test whether CheckpointRDD generate expected number of partitions despite - // each split file having multiple blocks. This needs to be run on a - // cluster (mesos or standalone) using HDFS. - def main(args: Array[String]) { - import org.apache.spark._ + // Note: There is a bug in MiMa that complains about `AbstractMethodProblem`s in the + // base [[org.apache.spark.rdd.RDD]] class if we do not override the following methods. + // scalastyle:off + protected override def getPartitions: Array[Partition] = ??? + override def compute(p: Partition, tc: TaskContext): Iterator[T] = ??? + // scalastyle:on - val Array(cluster, hdfsPath) = args - val env = SparkEnv.get - val sc = new SparkContext(cluster, "CheckpointRDD Test") - val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) - val path = new Path(hdfsPath, "temp") - val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf()) - val fs = path.getFileSystem(conf) - val broadcastedConf = sc.broadcast(new SerializableConfiguration(conf)) - sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) - val cpRDD = new CheckpointRDD[Int](sc, path.toString) - assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") - assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") - fs.delete(path, true) - } } diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala new file mode 100644 index 0000000000000..daa5779d688cc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, SparkContext, SparkEnv, SparkException, TaskContext} +import org.apache.spark.storage.RDDBlockId + +/** + * A dummy CheckpointRDD that exists to provide informative error messages during failures. + * + * This is simply a placeholder because the original checkpointed RDD is expected to be + * fully cached. Only if an executor fails or if the user explicitly unpersists the original + * RDD will Spark ever attempt to compute this CheckpointRDD. When this happens, however, + * we must provide an informative error message. 
+ * + * @param sc the active SparkContext + * @param rddId the ID of the checkpointed RDD + * @param numPartitions the number of partitions in the checkpointed RDD + */ +private[spark] class LocalCheckpointRDD[T: ClassTag]( + @transient sc: SparkContext, + rddId: Int, + numPartitions: Int) + extends CheckpointRDD[T](sc) { + + def this(rdd: RDD[T]) { + this(rdd.context, rdd.id, rdd.partitions.size) + } + + protected override def getPartitions: Array[Partition] = { + (0 until numPartitions).toArray.map { i => new CheckpointRDDPartition(i) } + } + + /** + * Throw an exception indicating that the relevant block is not found. + * + * This should only be called if the original RDD is explicitly unpersisted or if an + * executor is lost. Under normal circumstances, however, the original RDD (our child) + * is expected to be fully cached and so all partitions should already be computed and + * available in the block storage. + */ + override def compute(partition: Partition, context: TaskContext): Iterator[T] = { + throw new SparkException( + s"Checkpoint block ${RDDBlockId(rddId, partition.index)} not found! Either the executor " + + s"that originally checkpointed this partition is no longer alive, or the original RDD is " + + s"unpersisted. If this problem persists, you may consider using `rdd.checkpoint()` " + + s"instead, which is slower than local checkpointing but more fault-tolerant.") + } + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala new file mode 100644 index 0000000000000..d6fad896845f6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext} +import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.util.Utils + +/** + * An implementation of checkpointing implemented on top of Spark's caching layer. + * + * Local checkpointing trades off fault tolerance for performance by skipping the expensive + * step of saving the RDD data to a reliable and fault-tolerant storage. Instead, the data + * is written to the local, ephemeral block storage that lives in each executor. This is useful + * for use cases where RDDs build up long lineages that need to be truncated often (e.g. GraphX). + */ +private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) + extends RDDCheckpointData[T](rdd) with Logging { + + /** + * Ensure the RDD is fully cached so the partitions can be recovered later. 
+ */ + protected override def doCheckpoint(): CheckpointRDD[T] = { + val level = rdd.getStorageLevel + + // Assume storage level uses disk; otherwise memory eviction may cause data loss + assume(level.useDisk, s"Storage level $level is not appropriate for local checkpointing") + + // Not all actions compute all partitions of the RDD (e.g. take). For correctness, we + // must cache any missing partitions. TODO: avoid running another job here (SPARK-8582). + val action = (tc: TaskContext, iterator: Iterator[T]) => Utils.getIteratorSize(iterator) + val missingPartitionIndices = rdd.partitions.map(_.index).filter { i => + !SparkEnv.get.blockManager.master.contains(RDDBlockId(rdd.id, i)) + } + if (missingPartitionIndices.nonEmpty) { + rdd.sparkContext.runJob(rdd, action, missingPartitionIndices) + } + + new LocalCheckpointRDD[T](rdd) + } + +} + +private[spark] object LocalRDDCheckpointData { + + val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK + + /** + * Transform the specified storage level to one that uses disk. + * + * This guarantees that the RDD can be recomputed multiple times correctly as long as + * executors do not fail. Otherwise, if the RDD is cached in memory only, for instance, + * the checkpoint data will be lost if the relevant block is evicted from memory. + * + * This method is idempotent. + */ + def transformStorageLevel(level: StorageLevel): StorageLevel = { + // If this RDD is to be cached off-heap, fail fast since we cannot provide any + // correctness guarantees about subsequent computations after the first one + if (level.useOffHeap) { + throw new SparkException("Local checkpointing is not compatible with off-heap caching.") + } + + StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6d61d227382d7..081c721f23687 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -149,23 +149,43 @@ abstract class RDD[T: ClassTag]( } /** - * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. This can only be used to assign a new storage level if the RDD does not - * have a storage level set yet.. + * Mark this RDD for persisting using the specified level. + * + * @param newLevel the target storage level + * @param allowOverride whether to override any existing level with the new one */ - def persist(newLevel: StorageLevel): this.type = { + private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = { // TODO: Handle changes of StorageLevel - if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { + if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) { throw new UnsupportedOperationException( "Cannot change storage level of an RDD after it was already assigned a level") } - sc.persistRDD(this) - // Register the RDD with the ContextCleaner for automatic GC-based cleanup - sc.cleaner.foreach(_.registerRDDForCleanup(this)) + // If this is the first time this RDD is marked for persisting, register it + // with the SparkContext for cleanups and accounting. Do this only once. 
+ if (storageLevel == StorageLevel.NONE) { + sc.cleaner.foreach(_.registerRDDForCleanup(this)) + sc.persistRDD(this) + } storageLevel = newLevel this } + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet. Local checkpointing is an exception. + */ + def persist(newLevel: StorageLevel): this.type = { + if (isLocallyCheckpointed) { + // This means the user previously called localCheckpoint(), which should have already + // marked this RDD for persisting. Here we should override the old storage level with + // one that is explicitly requested by the user (after adapting it to use disk). + persist(LocalRDDCheckpointData.transformStorageLevel(newLevel), allowOverride = true) + } else { + persist(newLevel, allowOverride = false) + } + } + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def persist(): this.type = persist(StorageLevel.MEMORY_ONLY) @@ -1448,33 +1468,99 @@ abstract class RDD[T: ClassTag]( /** * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint - * directory set with SparkContext.setCheckpointDir() and all references to its parent + * directory set with `SparkContext#setCheckpointDir` and all references to its parent * RDDs will be removed. This function must be called before any job has been * executed on this RDD. It is strongly recommended that this RDD is persisted in * memory, otherwise saving it on a file will require recomputation. */ - def checkpoint(): Unit = { + def checkpoint(): Unit = RDDCheckpointData.synchronized { + // NOTE: we use a global lock here due to complexities downstream with ensuring + // children RDD partitions point to the correct parent partitions. In the future + // we should revisit this consideration. if (context.checkpointDir.isEmpty) { throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { - // NOTE: we use a global lock here due to complexities downstream with ensuring - // children RDD partitions point to the correct parent partitions. In the future - // we should revisit this consideration. - RDDCheckpointData.synchronized { - checkpointData = Some(new RDDCheckpointData(this)) - } + checkpointData = Some(new ReliableRDDCheckpointData(this)) + } + } + + /** + * Mark this RDD for local checkpointing using Spark's existing caching layer. + * + * This method is for users who wish to truncate RDD lineages while skipping the expensive + * step of replicating the materialized data in a reliable distributed file system. This is + * useful for RDDs with long lineages that need to be truncated periodically (e.g. GraphX). + * + * Local checkpointing sacrifices fault-tolerance for performance. In particular, checkpointed + * data is written to ephemeral local storage in the executors instead of to a reliable, + * fault-tolerant storage. The effect is that if an executor fails during the computation, + * the checkpointed data may no longer be accessible, causing an irrecoverable job failure. + * + * This is NOT safe to use with dynamic allocation, which removes executors along + * with their cached blocks. If you must use both features, you are advised to set + * `spark.dynamicAllocation.cachedExecutorIdleTimeout` to a high value. + * + * The checkpoint directory set through `SparkContext#setCheckpointDir` is not used. 
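+ *
+ * A minimal usage sketch (editorial illustration, not part of this patch; `expensiveStep`
+ * is a hypothetical user function):
+ * {{{
+ *   val rdd = sc.parallelize(1 to 1000000).map(expensiveStep)
+ *   rdd.localCheckpoint()  // mark the RDD for checkpointing to executor-local storage
+ *   rdd.count()            // the first action materializes the checkpoint
+ * }}}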
+ */ + def localCheckpoint(): this.type = RDDCheckpointData.synchronized { + if (conf.getBoolean("spark.dynamicAllocation.enabled", false) && + conf.contains("spark.dynamicAllocation.cachedExecutorIdleTimeout")) { + logWarning("Local checkpointing is NOT safe to use with dynamic allocation, " + + "which removes executors along with their cached blocks. If you must use both " + + "features, you are advised to set `spark.dynamicAllocation.cachedExecutorIdleTimeout` " + + "to a high value. E.g. If you plan to use the RDD for 1 hour, set the timeout to " + + "at least 1 hour.") + } + + // Note: At this point we do not actually know whether the user will call persist() on + // this RDD later, so we must explicitly call it here ourselves to ensure the cached + // blocks are registered for cleanup later in the SparkContext. + // + // If, however, the user has already called persist() on this RDD, then we must adapt + // the storage level he/she specified to one that is appropriate for local checkpointing + // (i.e. uses disk) to guarantee correctness. + + if (storageLevel == StorageLevel.NONE) { + persist(LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL) + } else { + persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true) } + + checkpointData match { + case Some(reliable: ReliableRDDCheckpointData[_]) => logWarning( + "RDD was already marked for reliable checkpointing: overriding with local checkpoint.") + case _ => + } + checkpointData = Some(new LocalRDDCheckpointData(this)) + this } /** - * Return whether this RDD has been checkpointed or not + * Return whether this RDD is marked for checkpointing, either reliably or locally. */ def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) /** - * Gets the name of the file to which this RDD was checkpointed + * Return whether this RDD is marked for local checkpointing. + * Exposed for testing. */ - def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile) + private[rdd] def isLocallyCheckpointed: Boolean = { + checkpointData match { + case Some(_: LocalRDDCheckpointData[T]) => true + case _ => false + } + } + + /** + * Gets the name of the directory to which this RDD was checkpointed. + * This is not defined if the RDD is checkpointed locally. + */ + def getCheckpointFile: Option[String] = { + checkpointData match { + case Some(reliable: ReliableRDDCheckpointData[T]) => reliable.getCheckpointDir + case _ => None + } + } // ======================================================================= // Other internal methods and fields @@ -1545,7 +1631,7 @@ abstract class RDD[T: ClassTag]( if (!doCheckpointCalled) { doCheckpointCalled = true if (checkpointData.isDefined) { - checkpointData.get.doCheckpoint() + checkpointData.get.checkpoint() } else { dependencies.foreach(_.rdd.doCheckpoint()) } @@ -1557,7 +1643,7 @@ abstract class RDD[T: ClassTag]( * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) * created from the checkpoint file, and forget its old dependencies and partitions. 
*/ - private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { + private[spark] def markCheckpointed(): Unit = { clearDependencies() partitions_ = null deps = null // Forget the constructor argument for dependencies too diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 4f954363bed8e..0e43520870c0a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -19,10 +19,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.hadoop.fs.Path - -import org.apache.spark._ -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.Partition /** * Enumeration to manage state transitions of an RDD through checkpointing @@ -39,39 +36,31 @@ private[spark] object CheckpointState extends Enumeration { * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. */ -private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) - extends Logging with Serializable { +private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) + extends Serializable { import CheckpointState._ // The checkpoint state of the associated RDD. - private var cpState = Initialized - - // The file to which the associated RDD has been checkpointed to - private var cpFile: Option[String] = None + protected var cpState = Initialized - // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - // This is defined if and only if `cpState` is `Checkpointed`. + // The RDD that contains our checkpointed data private var cpRDD: Option[CheckpointRDD[T]] = None // TODO: are we sure we need to use a global lock in the following methods? - // Is the RDD already checkpointed + /** + * Return whether the checkpoint data for this RDD is already persisted. + */ def isCheckpointed: Boolean = RDDCheckpointData.synchronized { cpState == Checkpointed } - // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized { - cpFile - } - /** - * Materialize this RDD and write its content to a reliable DFS. + * Materialize this RDD and persist its content. * This is called immediately after the first action invoked on this RDD has completed. 
*/ - def doCheckpoint(): Unit = { - + final def checkpoint(): Unit = { // Guard against multiple threads checkpointing the same RDD by // atomically flipping the state of this RDDCheckpointData RDDCheckpointData.synchronized { @@ -82,64 +71,41 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } } - // Create the output path for the checkpoint - val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get - val fs = path.getFileSystem(rdd.context.hadoopConfiguration) - if (!fs.mkdirs(path)) { - throw new SparkException(s"Failed to create checkpoint path $path") - } - - // Save to file, and reload it as an RDD - val broadcastedConf = rdd.context.broadcast( - new SerializableConfiguration(rdd.context.hadoopConfiguration)) - val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { - rdd.context.cleaner.foreach { cleaner => - cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) - } - } - - // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) - rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) - if (newRDD.partitions.length != rdd.partitions.length) { - throw new SparkException( - "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " + - "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")") - } + val newRDD = doCheckpoint() - // Change the dependencies and partitions of the RDD + // Update our state and truncate the RDD lineage RDDCheckpointData.synchronized { - cpFile = Some(path.toString) cpRDD = Some(newRDD) - rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed + rdd.markCheckpointed() } - logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}") - } - - def getPartitions: Array[Partition] = RDDCheckpointData.synchronized { - cpRDD.get.partitions } - def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { - cpRDD - } -} + /** + * Materialize this RDD and persist its content. + * + * Subclasses should override this method to define custom checkpointing behavior. + * @return the checkpoint RDD created in the process. + */ + protected def doCheckpoint(): CheckpointRDD[T] -private[spark] object RDDCheckpointData { + /** + * Return the RDD that contains our checkpointed data. + * This is only defined if the checkpoint state is `Checkpointed`. + */ + def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { cpRDD } - /** Return the path of the directory to which this RDD's checkpoint data is written. */ - def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = { - sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") } + /** + * Return the partitions of the resulting checkpoint RDD. + * For tests only. + */ + def getPartitions: Array[Partition] = RDDCheckpointData.synchronized { + cpRDD.map(_.partitions).getOrElse { Array.empty } } - /** Clean up the files associated with the checkpoint data for this RDD. */ - def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = { - rddCheckpointDataPath(sc, rddId).foreach { path => - val fs = path.getFileSystem(sc.hadoopConfiguration) - if (fs.exists(path)) { - fs.delete(path, true) - } - } - } } + +/** + * Global lock for synchronizing checkpoint operations. 
+ */ +private[spark] object RDDCheckpointData diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala new file mode 100644 index 0000000000000..35d8b0bfd18c5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.io.IOException + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path + +import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** + * An RDD that reads from checkpoint files previously written to reliable storage. + */ +private[spark] class ReliableCheckpointRDD[T: ClassTag]( + @transient sc: SparkContext, + val checkpointPath: String) + extends CheckpointRDD[T](sc) { + + @transient private val hadoopConf = sc.hadoopConfiguration + @transient private val cpath = new Path(checkpointPath) + @transient private val fs = cpath.getFileSystem(hadoopConf) + private val broadcastedConf = sc.broadcast(new SerializableConfiguration(hadoopConf)) + + // Fail fast if checkpoint directory does not exist + require(fs.exists(cpath), s"Checkpoint directory does not exist: $checkpointPath") + + /** + * Return the path of the checkpoint directory this RDD reads data from. + */ + override def getCheckpointFile: Option[String] = Some(checkpointPath) + + /** + * Return partitions described by the files in the checkpoint directory. + * + * Since the original RDD may belong to a prior application, there is no way to know a + * priori the number of partitions to expect. This method assumes that the original set of + * checkpoint files are fully preserved in a reliable storage across application lifespans. + */ + protected override def getPartitions: Array[Partition] = { + // listStatus can throw exception if path does not exist. + val inputFiles = fs.listStatus(cpath) + .map(_.getPath) + .filter(_.getName.startsWith("part-")) + .sortBy(_.toString) + // Fail fast if input files are invalid + inputFiles.zipWithIndex.foreach { case (path, i) => + if (!path.toString.endsWith(ReliableCheckpointRDD.checkpointFileName(i))) { + throw new SparkException(s"Invalid checkpoint file: $path") + } + } + Array.tabulate(inputFiles.length)(i => new CheckpointRDDPartition(i)) + } + + /** + * Return the locations of the checkpoint file associated with the given partition. 
+ */ + protected override def getPreferredLocations(split: Partition): Seq[String] = { + val status = fs.getFileStatus( + new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index))) + val locations = fs.getFileBlockLocations(status, 0, status.getLen) + locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + } + + /** + * Read the content of the checkpoint file associated with the given partition. + */ + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val file = new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index)) + ReliableCheckpointRDD.readCheckpointFile(file, broadcastedConf, context) + } + +} + +private[spark] object ReliableCheckpointRDD extends Logging { + + /** + * Return the checkpoint file name for the given partition. + */ + private def checkpointFileName(partitionIndex: Int): String = { + "part-%05d".format(partitionIndex) + } + + /** + * Write this partition's values to a checkpoint file. + */ + def writeCheckpointFile[T: ClassTag]( + path: String, + broadcastedConf: Broadcast[SerializableConfiguration], + blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { + val env = SparkEnv.get + val outputDir = new Path(path) + val fs = outputDir.getFileSystem(broadcastedConf.value.value) + + val finalOutputName = ReliableCheckpointRDD.checkpointFileName(ctx.partitionId()) + val finalOutputPath = new Path(outputDir, finalOutputName) + val tempOutputPath = + new Path(outputDir, s".$finalOutputName-attempt-${ctx.attemptNumber()}") + + if (fs.exists(tempOutputPath)) { + throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists") + } + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) + + val fileOutputStream = if (blockSize < 0) { + fs.create(tempOutputPath, false, bufferSize) + } else { + // This is mainly for testing purpose + fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + } + val serializer = env.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + Utils.tryWithSafeFinally { + serializeStream.writeAll(iterator) + } { + serializeStream.close() + } + + if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.exists(finalOutputPath)) { + logInfo(s"Deleting tempOutputPath $tempOutputPath") + fs.delete(tempOutputPath, false) + throw new IOException("Checkpoint failed: failed to save output of task: " + + s"${ctx.attemptNumber()} and final output path does not exist: $finalOutputPath") + } else { + // Some other copy of this task must've finished before us and renamed it + logInfo(s"Final output path $finalOutputPath already exists; not overwriting it") + fs.delete(tempOutputPath, false) + } + } + } + + /** + * Read the content of the specified checkpoint file. + */ + def readCheckpointFile[T]( + path: Path, + broadcastedConf: Broadcast[SerializableConfiguration], + context: TaskContext): Iterator[T] = { + val env = SparkEnv.get + val fs = path.getFileSystem(broadcastedConf.value.value) + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) + val fileInputStream = fs.open(path, bufferSize) + val serializer = env.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + + // Register an on-task-completion callback to close the input stream. 
+ context.addTaskCompletionListener(context => deserializeStream.close()) + + deserializeStream.asIterator.asInstanceOf[Iterator[T]] + } + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala new file mode 100644 index 0000000000000..1df8eef5ff2b9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path + +import org.apache.spark._ +import org.apache.spark.util.SerializableConfiguration + +/** + * An implementation of checkpointing that writes the RDD data to reliable storage. + * This allows drivers to be restarted on failure with previously computed state. + */ +private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) + extends RDDCheckpointData[T](rdd) with Logging { + + // The directory to which the associated RDD has been checkpointed to + // This is assumed to be a non-local path that points to some reliable storage + private val cpDir: String = + ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id) + .map(_.toString) + .getOrElse { throw new SparkException("Checkpoint dir must be specified.") } + + /** + * Return the directory to which this RDD was checkpointed. + * If the RDD is not checkpointed yet, return None. + */ + def getCheckpointDir: Option[String] = RDDCheckpointData.synchronized { + if (isCheckpointed) { + Some(cpDir.toString) + } else { + None + } + } + + /** + * Materialize this RDD and write its content to a reliable DFS. + * This is called immediately after the first action invoked on this RDD has completed. 
+ */ + protected override def doCheckpoint(): CheckpointRDD[T] = { + + // Create the output path for the checkpoint + val path = new Path(cpDir) + val fs = path.getFileSystem(rdd.context.hadoopConfiguration) + if (!fs.mkdirs(path)) { + throw new SparkException(s"Failed to create checkpoint path $cpDir") + } + + // Save to file, and reload it as an RDD + val broadcastedConf = rdd.context.broadcast( + new SerializableConfiguration(rdd.context.hadoopConfiguration)) + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) + rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _) + val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir) + if (newRDD.partitions.length != rdd.partitions.length) { + throw new SparkException( + s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + + s"number of partitions from original RDD $rdd(${rdd.partitions.length})") + } + + // Optionally clean our checkpoint files if the reference is out of scope + if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { + rdd.context.cleaner.foreach { cleaner => + cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) + } + } + + logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}") + + newRDD + } + +} + +private[spark] object ReliableRDDCheckpointData { + + /** Return the path of the directory to which this RDD's checkpoint data is written. */ + def checkpointPath(sc: SparkContext, rddId: Int): Option[Path] = { + sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") } + } + + /** Clean up the files associated with the checkpoint data for this RDD. */ + def cleanCheckpoint(sc: SparkContext, rddId: Int): Unit = { + checkpointPath(sc, rddId).foreach { path => + val fs = path.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(path)) { + fs.delete(path, true) + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index cc50e6d79a3e2..d343bb95cb68c 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -25,11 +25,15 @@ import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils +/** + * Test suite for end-to-end checkpointing functionality. + * This tests both reliable checkpoints and local checkpoints. 
+ */ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { - var checkpointDir: File = _ - val partitioner = new HashPartitioner(2) + private var checkpointDir: File = _ + private val partitioner = new HashPartitioner(2) - override def beforeEach() { + override def beforeEach(): Unit = { super.beforeEach() checkpointDir = File.createTempFile("temp", "", Utils.createTempDir()) checkpointDir.delete() @@ -37,40 +41,43 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging sc.setCheckpointDir(checkpointDir.toString) } - override def afterEach() { + override def afterEach(): Unit = { super.afterEach() Utils.deleteRecursively(checkpointDir) } - test("basic checkpointing") { + runTest("basic checkpointing") { reliableCheckpoint: Boolean => val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) - flatMappedRDD.checkpoint() + checkpoint(flatMappedRDD, reliableCheckpoint) assert(flatMappedRDD.dependencies.head.rdd === parCollection) val result = flatMappedRDD.collect() assert(flatMappedRDD.dependencies.head.rdd != parCollection) assert(flatMappedRDD.collect() === result) } - test("RDDs with one-to-one dependencies") { - testRDD(_.map(x => x.toString)) - testRDD(_.flatMap(x => 1 to x)) - testRDD(_.filter(_ % 2 == 0)) - testRDD(_.sample(false, 0.5, 0)) - testRDD(_.glom()) - testRDD(_.mapPartitions(_.map(_.toString))) - testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) - testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) - testRDD(_.pipe(Seq("cat"))) + runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean => + testRDD(_.map(x => x.toString), reliableCheckpoint) + testRDD(_.flatMap(x => 1 to x), reliableCheckpoint) + testRDD(_.filter(_ % 2 == 0), reliableCheckpoint) + testRDD(_.sample(false, 0.5, 0), reliableCheckpoint) + testRDD(_.glom(), reliableCheckpoint) + testRDD(_.mapPartitions(_.map(_.toString)), reliableCheckpoint) + testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), reliableCheckpoint) + testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), + reliableCheckpoint) + testRDD(_.pipe(Seq("cat")), reliableCheckpoint) } - test("ParallelCollection") { + runTest("ParallelCollectionRDD") { reliableCheckpoint: Boolean => val parCollection = sc.makeRDD(1 to 4, 2) val numPartitions = parCollection.partitions.size - parCollection.checkpoint() + checkpoint(parCollection, reliableCheckpoint) assert(parCollection.dependencies === Nil) val result = parCollection.collect() - assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) + if (reliableCheckpoint) { + assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) + } assert(parCollection.dependencies != Nil) assert(parCollection.partitions.length === numPartitions) assert(parCollection.partitions.toList === @@ -78,44 +85,46 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(parCollection.collect() === result) } - test("BlockRDD") { + runTest("BlockRDD") { reliableCheckpoint: Boolean => val blockId = TestBlockId("id") val blockManager = SparkEnv.get.blockManager blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) val blockRDD = new BlockRDD[String](sc, Array(blockId)) val numPartitions = blockRDD.partitions.size - blockRDD.checkpoint() + checkpoint(blockRDD, reliableCheckpoint) val result = blockRDD.collect() - 
assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) + if (reliableCheckpoint) { + assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) + } assert(blockRDD.dependencies != Nil) assert(blockRDD.partitions.length === numPartitions) assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList) assert(blockRDD.collect() === result) } - test("ShuffledRDD") { + runTest("ShuffleRDD") { reliableCheckpoint: Boolean => testRDD(rdd => { // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner) - }) + }, reliableCheckpoint) } - test("UnionRDD") { + runTest("UnionRDD") { reliableCheckpoint: Boolean => def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) - testRDD(_.union(otherRDD)) - testRDDPartitions(_.union(otherRDD)) + testRDD(_.union(otherRDD), reliableCheckpoint) + testRDDPartitions(_.union(otherRDD), reliableCheckpoint) } - test("CartesianRDD") { + runTest("CartesianRDD") { reliableCheckpoint: Boolean => def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) - testRDD(new CartesianRDD(sc, _, otherRDD)) - testRDDPartitions(new CartesianRDD(sc, _, otherRDD)) + testRDD(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint) + testRDDPartitions(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint) // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after // the parent RDD has been checkpointed and parent partitions have been changed. // Note that this test is very specific to the current implementation of CartesianRDD. val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint() // checkpoint that MappedRDD + checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD val cartesian = new CartesianRDD(sc, ones, ones) val splitBeforeCheckpoint = serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) @@ -129,16 +138,16 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging ) } - test("CoalescedRDD") { - testRDD(_.coalesce(2)) - testRDDPartitions(_.coalesce(2)) + runTest("CoalescedRDD") { reliableCheckpoint: Boolean => + testRDD(_.coalesce(2), reliableCheckpoint) + testRDDPartitions(_.coalesce(2), reliableCheckpoint) // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) // after the parent RDD has been checkpointed and parent partitions have been changed. // Note that this test is very specific to the current implementation of // CoalescedRDDPartitions. 
val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint() // checkpoint that MappedRDD + checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD val coalesced = new CoalescedRDD(ones, 2) val splitBeforeCheckpoint = serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) @@ -151,7 +160,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging ) } - test("CoGroupedRDD") { + runTest("CoGroupedRDD") { reliableCheckpoint: Boolean => val longLineageRDD1 = generateFatPairRDD() // Collect the RDD as sequences instead of arrays to enable equality tests in testRDD @@ -160,26 +169,26 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging testRDD(rdd => { CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner) - }, seqCollectFunc) + }, reliableCheckpoint, seqCollectFunc) val longLineageRDD2 = generateFatPairRDD() testRDDPartitions(rdd => { CheckpointSuite.cogroup( longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) - }, seqCollectFunc) + }, reliableCheckpoint, seqCollectFunc) } - test("ZippedPartitionsRDD") { - testRDD(rdd => rdd.zip(rdd.map(x => x))) - testRDDPartitions(rdd => rdd.zip(rdd.map(x => x))) + runTest("ZippedPartitionsRDD") { reliableCheckpoint: Boolean => + testRDD(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint) + testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint) // Test that ZippedPartitionsRDD updates parent partitions after parent RDDs have // been checkpointed and parent partitions have been changed. // Note that this test is very specific to the implementation of ZippedPartitionsRDD. val rdd = generateFatRDD() val zippedRDD = rdd.zip(rdd.map(x => x)).asInstanceOf[ZippedPartitionsRDD2[_, _, _]] - zippedRDD.rdd1.checkpoint() - zippedRDD.rdd2.checkpoint() + checkpoint(zippedRDD.rdd1, reliableCheckpoint) + checkpoint(zippedRDD.rdd2, reliableCheckpoint) val partitionBeforeCheckpoint = serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition]) zippedRDD.count() @@ -194,27 +203,27 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging ) } - test("PartitionerAwareUnionRDD") { + runTest("PartitionerAwareUnionRDD") { reliableCheckpoint: Boolean => testRDD(rdd => { new PartitionerAwareUnionRDD[(Int, Int)](sc, Array( generateFatPairRDD(), rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) )) - }) + }, reliableCheckpoint) testRDDPartitions(rdd => { new PartitionerAwareUnionRDD[(Int, Int)](sc, Array( generateFatPairRDD(), rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) )) - }) + }, reliableCheckpoint) // Test that the PartitionerAwareUnionRDD updates parent partitions // (PartitionerAwareUnionRDD.parents) after the parent RDD has been checkpointed and parent // partitions have been changed. Note that this test is very specific to the current // implementation of PartitionerAwareUnionRDD. 
val pairRDD = generateFatPairRDD() - pairRDD.checkpoint() + checkpoint(pairRDD, reliableCheckpoint) val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD)) val partitionBeforeCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) @@ -228,17 +237,34 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging ) } - test("CheckpointRDD with zero partitions") { + runTest("CheckpointRDD with zero partitions") { reliableCheckpoint: Boolean => val rdd = new BlockRDD[Int](sc, Array[BlockId]()) assert(rdd.partitions.size === 0) assert(rdd.isCheckpointed === false) - rdd.checkpoint() + checkpoint(rdd, reliableCheckpoint) assert(rdd.count() === 0) assert(rdd.isCheckpointed === true) assert(rdd.partitions.size === 0) } - def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() + // Utility test methods + + /** Checkpoint the RDD either locally or reliably. */ + private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { + if (reliableCheckpoint) { + rdd.checkpoint() + } else { + rdd.localCheckpoint() + } + } + + /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ + private def runTest(name: String)(body: Boolean => Unit): Unit = { + test(name + " [reliable checkpoint]")(body(true)) + test(name + " [local checkpoint]")(body(false)) + } + + private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() /** * Test checkpointing of the RDD generated by the given operation. It tests whether the @@ -246,11 +272,14 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). * * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints * @param collectFunc a function for collecting the values in the RDD, in case there are * non-comparable types like arrays that we want to convert to something that supports == */ - def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U], - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) { + private def testRDD[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { // Generate the final RDD using given RDD operation val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) @@ -267,14 +296,16 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging // Find serialized sizes before and after the checkpoint logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - operatedRDD.checkpoint() + checkpoint(operatedRDD, reliableCheckpoint) val result = collectFunc(operatedRDD) operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) // Test whether the checkpoint file has been created - assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + if (reliableCheckpoint) { + assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + } // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) 
@@ -310,11 +341,14 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * partitions (i.e., do not call it on simple RDD like MappedRDD). * * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints * @param collectFunc a function for collecting the values in the RDD, in case there are * non-comparable types like arrays that we want to convert to something that supports == */ - def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U], - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) { + private def testRDDPartitions[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { // Generate the final RDD using given RDD operation val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) @@ -328,7 +362,10 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging // Find serialized sizes before and after the checkpoint logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - parentRDDs.foreach(_.checkpoint()) // checkpoint the parent RDD, not the generated one + // checkpoint the parent RDD, not the generated one + parentRDDs.foreach { rdd => + checkpoint(rdd, reliableCheckpoint) + } val result = collectFunc(operatedRDD) // force checkpointing operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) @@ -350,7 +387,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging /** * Generate an RDD such that both the RDD and its partitions have large size. */ - def generateFatRDD(): RDD[Int] = { + private def generateFatRDD(): RDD[Int] = { new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x) } @@ -358,7 +395,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * Generate an pair RDD (with partitioner) such that both the RDD and its partitions * have large size. */ - def generateFatPairRDD(): RDD[(Int, Int)] = { + private def generateFatPairRDD(): RDD[(Int, Int)] = { new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) } @@ -366,7 +403,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. */ - def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { val rddSize = Utils.serialize(rdd).size val rddCpDataSize = Utils.serialize(rdd.checkpointData).size val rddPartitionSize = Utils.serialize(rdd.partitions).size @@ -394,7 +431,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * contents after deserialization (e.g., the contents of an RDD split after * it is sent to a slave along with a task) */ - def serializeDeserialize[T](obj: T): T = { + private def serializeDeserialize[T](obj: T): T = { val bytes = Utils.serialize(obj) Utils.deserialize[T](bytes) } @@ -402,10 +439,11 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging /** * Recursively force the initialization of the all members of an RDD and it parents. 
*/ - def initializeRdd(rdd: RDD[_]) { + private def initializeRdd(rdd: RDD[_]): Unit = { rdd.partitions // forces the - rdd.dependencies.map(_.rdd).foreach(initializeRdd(_)) + rdd.dependencies.map(_.rdd).foreach(initializeRdd) } + } /** RDD partition that has large serialized size. */ diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 26858ef2774fc..0c14bef7befd8 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -24,12 +24,11 @@ import scala.language.existentials import scala.util.Random import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.{PatienceConfiguration, Eventually} +import org.scalatest.concurrent.PatienceConfiguration import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{RDDCheckpointData, RDD} +import org.apache.spark.rdd.{ReliableRDDCheckpointData, RDD} import org.apache.spark.storage._ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager @@ -52,6 +51,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[Ha .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") + .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true") .set("spark.shuffle.manager", shuffleManager.getName) before { @@ -209,11 +209,11 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { postGCTester.assertCleanup() } - test("automatically cleanup checkpoint") { + test("automatically cleanup normal checkpoint") { val checkpointDir = java.io.File.createTempFile("temp", "") checkpointDir.deleteOnExit() checkpointDir.delete() - var rdd = newPairRDD + var rdd = newPairRDD() sc.setCheckpointDir(checkpointDir.toString) rdd.checkpoint() rdd.cache() @@ -221,23 +221,26 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { var rddId = rdd.id // Confirm the checkpoint directory exists - assert(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).isDefined) - val path = RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get + assert(ReliableRDDCheckpointData.checkpointPath(sc, rddId).isDefined) + val path = ReliableRDDCheckpointData.checkpointPath(sc, rddId).get val fs = path.getFileSystem(sc.hadoopConfiguration) assert(fs.exists(path)) // the checkpoint is not cleaned by default (without the configuration set) - var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Nil) + var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) rdd = null // Make RDD out of scope, ok if collected earlier runGC() postGCTester.assertCleanup() - assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + assert(!fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) + // Verify that checkpoints are NOT cleaned up if the config is not enabled sc.stop() - val conf = new SparkConf().setMaster("local[2]").setAppName("cleanupCheckpoint"). 
- set("spark.cleaner.referenceTracking.cleanCheckpoints", "true") + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("cleanupCheckpoint") + .set("spark.cleaner.referenceTracking.cleanCheckpoints", "false") sc = new SparkContext(conf) - rdd = newPairRDD + rdd = newPairRDD() sc.setCheckpointDir(checkpointDir.toString) rdd.checkpoint() rdd.cache() @@ -245,17 +248,40 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { rddId = rdd.id // Confirm the checkpoint directory exists - assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) // Reference rdd to defeat any early collection by the JVM rdd.count() // Test that GC causes checkpoint data cleanup after dereferencing the RDD - postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) + postGCTester = new CleanerTester(sc, Seq(rddId)) rdd = null // Make RDD out of scope runGC() postGCTester.assertCleanup() - assert(!fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) + } + + test("automatically clean up local checkpoint") { + // Note that this test is similar to the RDD cleanup + // test because the same underlying mechanism is used! + var rdd = newPairRDD().localCheckpoint() + assert(rdd.checkpointData.isDefined) + assert(rdd.checkpointData.get.checkpointRDD.isEmpty) + rdd.count() + assert(rdd.checkpointData.get.checkpointRDD.isDefined) + + // Test that GC does not cause checkpoint cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that RDD going out of scope does cause the checkpoint blocks to be cleaned up + val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + rdd = null + runGC() + postGCTester.assertCleanup() } test("automatically cleanup RDD + shuffle + broadcast") { @@ -408,7 +434,10 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor } -/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */ +/** + * Class to test whether RDDs, shuffles, etc. have been successfully cleaned. + * The checkpoint here refers only to normal (reliable) checkpoints, not local checkpoints. + */ class CleanerTester( sc: SparkContext, rddIds: Seq[Int] = Seq.empty, diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala new file mode 100644 index 0000000000000..5103eb74b2457 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{SparkException, SparkContext, LocalSparkContext, SparkFunSuite} + +import org.mockito.Mockito.spy +import org.apache.spark.storage.{RDDBlockId, StorageLevel} + +/** + * Fine-grained tests for local checkpointing. + * For end-to-end tests, see CheckpointSuite. + */ +class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { + + override def beforeEach(): Unit = { + sc = new SparkContext("local[2]", "test") + } + + test("transform storage level") { + val transform = LocalRDDCheckpointData.transformStorageLevel _ + assert(transform(StorageLevel.NONE) === StorageLevel.DISK_ONLY) + assert(transform(StorageLevel.MEMORY_ONLY) === StorageLevel.MEMORY_AND_DISK) + assert(transform(StorageLevel.MEMORY_ONLY_SER) === StorageLevel.MEMORY_AND_DISK_SER) + assert(transform(StorageLevel.MEMORY_ONLY_2) === StorageLevel.MEMORY_AND_DISK_2) + assert(transform(StorageLevel.MEMORY_ONLY_SER_2) === StorageLevel.MEMORY_AND_DISK_SER_2) + assert(transform(StorageLevel.DISK_ONLY) === StorageLevel.DISK_ONLY) + assert(transform(StorageLevel.DISK_ONLY_2) === StorageLevel.DISK_ONLY_2) + assert(transform(StorageLevel.MEMORY_AND_DISK) === StorageLevel.MEMORY_AND_DISK) + assert(transform(StorageLevel.MEMORY_AND_DISK_SER) === StorageLevel.MEMORY_AND_DISK_SER) + assert(transform(StorageLevel.MEMORY_AND_DISK_2) === StorageLevel.MEMORY_AND_DISK_2) + assert(transform(StorageLevel.MEMORY_AND_DISK_SER_2) === StorageLevel.MEMORY_AND_DISK_SER_2) + // Off-heap is not supported and Spark should fail fast + intercept[SparkException] { + transform(StorageLevel.OFF_HEAP) + } + } + + test("basic lineage truncation") { + val numPartitions = 4 + val parallelRdd = sc.parallelize(1 to 100, numPartitions) + val mappedRdd = parallelRdd.map { i => i + 1 } + val filteredRdd = mappedRdd.filter { i => i % 2 == 0 } + val expectedPartitionIndices = (0 until numPartitions).toArray + assert(filteredRdd.checkpointData.isEmpty) + assert(filteredRdd.getStorageLevel === StorageLevel.NONE) + assert(filteredRdd.partitions.map(_.index) === expectedPartitionIndices) + assert(filteredRdd.dependencies.size === 1) + assert(filteredRdd.dependencies.head.rdd === mappedRdd) + assert(mappedRdd.dependencies.size === 1) + assert(mappedRdd.dependencies.head.rdd === parallelRdd) + assert(parallelRdd.dependencies.size === 0) + + // Mark the RDD for local checkpointing + filteredRdd.localCheckpoint() + assert(filteredRdd.checkpointData.isDefined) + assert(!filteredRdd.checkpointData.get.isCheckpointed) + assert(!filteredRdd.checkpointData.get.checkpointRDD.isDefined) + assert(filteredRdd.getStorageLevel === LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL) + + // After an action, the lineage is truncated + val result = filteredRdd.collect() + assert(filteredRdd.checkpointData.get.isCheckpointed) + assert(filteredRdd.checkpointData.get.checkpointRDD.isDefined) + val checkpointRdd = filteredRdd.checkpointData.flatMap(_.checkpointRDD).get + assert(filteredRdd.dependencies.size === 1) + assert(filteredRdd.dependencies.head.rdd === checkpointRdd) + assert(filteredRdd.partitions.map(_.index) === expectedPartitionIndices) + assert(checkpointRdd.partitions.map(_.index) === expectedPartitionIndices) + + // Recomputation should yield the same result + assert(filteredRdd.collect() === result) + assert(filteredRdd.collect() === result) + } + + test("basic lineage truncation - caching before checkpointing") 
{ + testBasicLineageTruncationWithCaching( + newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(), + StorageLevel.MEMORY_AND_DISK) + } + + test("basic lineage truncation - caching after checkpointing") { + testBasicLineageTruncationWithCaching( + newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY), + StorageLevel.MEMORY_AND_DISK) + } + + test("indirect lineage truncation") { + testIndirectLineageTruncation( + newRdd.localCheckpoint(), + LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL) + } + + test("indirect lineage truncation - caching before checkpointing") { + testIndirectLineageTruncation( + newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(), + StorageLevel.MEMORY_AND_DISK) + } + + test("indirect lineage truncation - caching after checkpointing") { + testIndirectLineageTruncation( + newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY), + StorageLevel.MEMORY_AND_DISK) + } + + test("checkpoint without draining iterator") { + testWithoutDrainingIterator( + newSortedRdd.localCheckpoint(), + LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL, + 50) + } + + test("checkpoint without draining iterator - caching before checkpointing") { + testWithoutDrainingIterator( + newSortedRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(), + StorageLevel.MEMORY_AND_DISK, + 50) + } + + test("checkpoint without draining iterator - caching after checkpointing") { + testWithoutDrainingIterator( + newSortedRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY), + StorageLevel.MEMORY_AND_DISK, + 50) + } + + test("checkpoint blocks exist") { + testCheckpointBlocksExist( + newRdd.localCheckpoint(), + LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL) + } + + test("checkpoint blocks exist - caching before checkpointing") { + testCheckpointBlocksExist( + newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(), + StorageLevel.MEMORY_AND_DISK) + } + + test("checkpoint blocks exist - caching after checkpointing") { + testCheckpointBlocksExist( + newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY), + StorageLevel.MEMORY_AND_DISK) + } + + test("missing checkpoint block fails with informative message") { + val rdd = newRdd.localCheckpoint() + val numPartitions = rdd.partitions.size + val partitionIndices = rdd.partitions.map(_.index) + val bmm = sc.env.blockManager.master + + // After an action, the blocks should be found somewhere in the cache + rdd.collect() + partitionIndices.foreach { i => + assert(bmm.contains(RDDBlockId(rdd.id, i))) + } + + // Remove one of the blocks to simulate executor failure + // Collecting the RDD should now fail with an informative exception + val blockId = RDDBlockId(rdd.id, numPartitions - 1) + bmm.removeBlock(blockId) + try { + rdd.collect() + fail("Collect should have failed if local checkpoint block is removed...") + } catch { + case se: SparkException => + assert(se.getMessage.contains(s"Checkpoint block $blockId not found")) + assert(se.getMessage.contains("rdd.checkpoint()")) // suggest an alternative + assert(se.getMessage.contains("fault-tolerant")) // justify the alternative + } + } + + /** + * Helper method to create a simple RDD. + */ + private def newRdd: RDD[Int] = { + sc.parallelize(1 to 100, 4) + .map { i => i + 1 } + .filter { i => i % 2 == 0 } + } + + /** + * Helper method to create a simple sorted RDD. + */ + private def newSortedRdd: RDD[Int] = newRdd.sortBy(identity) + + /** + * Helper method to test basic lineage truncation with caching. 
+ * + * @param rdd an RDD that is both marked for caching and local checkpointing + */ + private def testBasicLineageTruncationWithCaching[T]( + rdd: RDD[T], + targetStorageLevel: StorageLevel): Unit = { + require(targetStorageLevel !== StorageLevel.NONE) + require(rdd.getStorageLevel !== StorageLevel.NONE) + require(rdd.isLocallyCheckpointed) + val result = rdd.collect() + assert(rdd.getStorageLevel === targetStorageLevel) + assert(rdd.checkpointData.isDefined) + assert(rdd.checkpointData.get.isCheckpointed) + assert(rdd.checkpointData.get.checkpointRDD.isDefined) + assert(rdd.dependencies.head.rdd === rdd.checkpointData.get.checkpointRDD.get) + assert(rdd.collect() === result) + assert(rdd.collect() === result) + } + + /** + * Helper method to test indirect lineage truncation. + * + * Indirect lineage truncation here means the action is called on one of the + * checkpointed RDD's descendants, but not on the checkpointed RDD itself. + * + * @param rdd a locally checkpointed RDD + */ + private def testIndirectLineageTruncation[T]( + rdd: RDD[T], + targetStorageLevel: StorageLevel): Unit = { + require(targetStorageLevel !== StorageLevel.NONE) + require(rdd.isLocallyCheckpointed) + val rdd1 = rdd.map { i => i + "1" } + val rdd2 = rdd1.map { i => i + "2" } + val rdd3 = rdd2.map { i => i + "3" } + val rddDependencies = rdd.dependencies + val rdd1Dependencies = rdd1.dependencies + val rdd2Dependencies = rdd2.dependencies + val rdd3Dependencies = rdd3.dependencies + assert(rdd1Dependencies.size === 1) + assert(rdd1Dependencies.head.rdd === rdd) + assert(rdd2Dependencies.size === 1) + assert(rdd2Dependencies.head.rdd === rdd1) + assert(rdd3Dependencies.size === 1) + assert(rdd3Dependencies.head.rdd === rdd2) + + // Only the locally checkpointed RDD should have special storage level + assert(rdd.getStorageLevel === targetStorageLevel) + assert(rdd1.getStorageLevel === StorageLevel.NONE) + assert(rdd2.getStorageLevel === StorageLevel.NONE) + assert(rdd3.getStorageLevel === StorageLevel.NONE) + + // After an action, only the dependencies of the checkpointed RDD changes + val result = rdd3.collect() + assert(rdd.dependencies !== rddDependencies) + assert(rdd1.dependencies === rdd1Dependencies) + assert(rdd2.dependencies === rdd2Dependencies) + assert(rdd3.dependencies === rdd3Dependencies) + assert(rdd3.collect() === result) + assert(rdd3.collect() === result) + } + + /** + * Helper method to test checkpointing without fully draining the iterator. + * + * Not all RDD actions fully consume the iterator. As a result, a subset of the partitions + * may not be cached. However, since we want to truncate the lineage safely, we explicitly + * ensure that *all* partitions are fully cached. This method asserts this behavior. 
+ * + * @param rdd a locally checkpointed RDD + */ + private def testWithoutDrainingIterator[T]( + rdd: RDD[T], + targetStorageLevel: StorageLevel, + targetCount: Int): Unit = { + require(targetCount > 0) + require(targetStorageLevel !== StorageLevel.NONE) + require(rdd.isLocallyCheckpointed) + + // This does not drain the iterator, but checkpointing should still work + val first = rdd.first() + assert(rdd.count() === targetCount) + assert(rdd.count() === targetCount) + assert(rdd.first() === first) + assert(rdd.first() === first) + + // Test the same thing by calling actions on a descendant instead + val rdd1 = rdd.repartition(10) + val rdd2 = rdd1.repartition(100) + val rdd3 = rdd2.repartition(1000) + val first2 = rdd3.first() + assert(rdd3.count() === targetCount) + assert(rdd3.count() === targetCount) + assert(rdd3.first() === first2) + assert(rdd3.first() === first2) + assert(rdd.getStorageLevel === targetStorageLevel) + assert(rdd1.getStorageLevel === StorageLevel.NONE) + assert(rdd2.getStorageLevel === StorageLevel.NONE) + assert(rdd3.getStorageLevel === StorageLevel.NONE) + } + + /** + * Helper method to test whether the checkpoint blocks are found in the cache. + * + * @param rdd a locally checkpointed RDD + */ + private def testCheckpointBlocksExist[T]( + rdd: RDD[T], + targetStorageLevel: StorageLevel): Unit = { + val bmm = sc.env.blockManager.master + val partitionIndices = rdd.partitions.map(_.index) + + // The blocks should not exist before the action + partitionIndices.foreach { i => + assert(!bmm.contains(RDDBlockId(rdd.id, i))) + } + + // After an action, the blocks should be found in the cache with the expected level + rdd.collect() + partitionIndices.foreach { i => + val blockId = RDDBlockId(rdd.id, i) + val status = bmm.getBlockStatus(blockId) + assert(status.nonEmpty) + assert(status.values.head.storageLevel === targetStorageLevel) + } + } + +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index f9384c4c3c9d6..280aac931915d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -80,8 +80,13 @@ object MimaExcludes { "org.apache.spark.mllib.linalg.Matrix.numActives") ) ++ Seq( // SPARK-8914 Remove RDDApi - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.RDDApi") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi") + ) ++ Seq( + // SPARK-7292 Provide operator to truncate lineage cheaply + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.rdd.RDDCheckpointData"), + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.rdd.CheckpointRDD") ) ++ Seq( // SPARK-8701 Add input metadata in the batch page. ProblemFilters.exclude[MissingClassProblem]( From dfe7bd168d9bcf8c53f993f459ab473d893457b0 Mon Sep 17 00:00:00 2001 From: Joseph Batchik Date: Mon, 3 Aug 2015 11:17:38 -0700 Subject: [PATCH 101/340] [SPARK-9511] [SQL] Fixed Table Name Parsing The issue was that the tokenizer was parsing "1one" into the numeric 1 using the code on line 110. I added another case to accept strings that start with a number and then have a letter somewhere else in it as well. 
Author: Joseph Batchik Closes #7844 from JDrit/parse_error and squashes the following commits: b8ca12f [Joseph Batchik] fixed parsing issue by adding another case --- .../spark/sql/catalyst/AbstractSparkSQLParser.scala | 2 ++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index d494ae7b71d16..5898a5f93f381 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -104,6 +104,8 @@ class SqlLexical extends StdLexical { override lazy val token: Parser[Token] = ( identChar ~ (identChar | digit).* ^^ { case first ~ rest => processIdent((first :: rest).mkString) } + | digit.* ~ identChar ~ (identChar | digit).* ^^ + { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } | rep1(digit) ~ ('.' ~> digit.*).? ^^ { case i ~ None => NumericLit(i.mkString) case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bbadc202a4f06..f1abae0720058 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1604,4 +1604,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(df.select(-df("i")), Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } + + test("SPARK-9511: error with table starting with number") { + val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) + .toDF("num", "str") + df.registerTempTable("1one") + + checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10)) + + sqlContext.dropTempTable("1one") + } } From 7a9d09f0bb472a1671d3457e1f7108f4c2eb4121 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 3 Aug 2015 11:22:02 -0700 Subject: [PATCH 102/340] [SQL][minor] Simplify UnsafeRow.calculateBitSetWidthInBytes. Author: Reynold Xin Closes #7897 from rxin/calculateBitSetWidthInBytes and squashes the following commits: 2e73b3a [Reynold Xin] [SQL][minor] Simplify UnsafeRow.calculateBitSetWidthInBytes. --- .../spark/sql/catalyst/expressions/UnsafeRow.java | 2 +- .../scala/org/apache/spark/sql/UnsafeRowSuite.scala | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index f4230cfaba375..e6750fce4fa80 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -59,7 +59,7 @@ public final class UnsafeRow extends MutableRow { ////////////////////////////////////////////////////////////////////////////// public static int calculateBitSetWidthInBytes(int numFields) { - return ((numFields / 64) + (numFields % 64 == 0 ? 
0 : 1)) * 8; + return ((numFields + 63)/ 64) * 8; } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index c5faaa663e749..89bad1bfdab0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -28,6 +28,16 @@ import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String class UnsafeRowSuite extends SparkFunSuite { + + test("bitset width calculation") { + assert(UnsafeRow.calculateBitSetWidthInBytes(0) === 0) + assert(UnsafeRow.calculateBitSetWidthInBytes(1) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(32) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(64) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(65) === 16) + assert(UnsafeRow.calculateBitSetWidthInBytes(128) === 16) + } + test("writeToStream") { val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) val arrayBackedUnsafeRow: UnsafeRow = From 703e44bff19f4c394f6f9bff1ce9152cdc68c51e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 3 Aug 2015 12:06:58 -0700 Subject: [PATCH 103/340] [SPARK-9554] [SQL] Enables in-memory partition pruning by default Author: Cheng Lian Closes #7895 from liancheng/spark-9554/enable-in-memory-partition-pruning and squashes the following commits: 67c403e [Cheng Lian] Enables in-memory partition pruning by default --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 387960c4b482b..41ba1c7fe0574 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -200,7 +200,7 @@ private[spark] object SQLConf { val IN_MEMORY_PARTITION_PRUNING = booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, enable partition pruning for in-memory columnar tables.", isPublic = false) From ff9169a002f1b75231fd25b7d04157a912503038 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 3 Aug 2015 12:17:46 -0700 Subject: [PATCH 104/340] [SPARK-5133] [ML] Added featureImportance to RandomForestClassifier and Regressor Added featureImportance to RandomForestClassifier and Regressor. This follows the scikit-learn implementation here: [https://github.com/scikit-learn/scikit-learn/blob/a95203b249c1cf392f86d001ad999e29b2392739/sklearn/tree/_tree.pyx#L3341] CC: yanboliang Would you mind taking a look? Thanks! Author: Joseph K. Bradley Author: Feynman Liang Closes #7838 from jkbradley/dt-feature-importance and squashes the following commits: 72a167a [Joseph K. Bradley] fixed unit test 86cea5f [Joseph K. Bradley] Modified RF featuresImportances to return Vector instead of Map 5aa74f0 [Joseph K. Bradley] finally fixed unit test for real 33df5db [Joseph K. Bradley] fix unit test 42a2d3b [Joseph K. Bradley] fix unit test fe94e72 [Joseph K. Bradley] modified feature importance unit tests cc693ee [Feynman Liang] Add classifier tests 79a6f87 [Feynman Liang] Compare dense vectors in test 21d01fc [Feynman Liang] Added failing SKLearn test ac0b254 [Joseph K. Bradley] Added featureImportance to RandomForestClassifier/Regressor. 
Need to add unit tests --- .../RandomForestClassifier.scala | 30 ++++- .../ml/regression/RandomForestRegressor.scala | 33 +++++- .../scala/org/apache/spark/ml/tree/Node.scala | 19 +++- .../spark/ml/tree/impl/RandomForest.scala | 92 +++++++++++++++ .../org/apache/spark/ml/tree/treeModels.scala | 6 + .../JavaRandomForestClassifierSuite.java | 2 + .../JavaRandomForestRegressorSuite.java | 2 + .../RandomForestClassifierSuite.scala | 31 ++++- .../org/apache/spark/ml/impl/TreeTests.scala | 18 +++ .../RandomForestRegressorSuite.scala | 27 ++++- .../ml/tree/impl/RandomForestSuite.scala | 107 ++++++++++++++++++ 11 files changed, 351 insertions(+), 16 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 56e80cc8fe6e1..b59826a59499a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -95,7 +95,8 @@ final class RandomForestClassifier(override val uid: String) val trees = RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - new RandomForestClassificationModel(trees, numClasses) + val numFeatures = oldDataset.first().features.size + new RandomForestClassificationModel(trees, numFeatures, numClasses) } override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) @@ -118,11 +119,13 @@ object RandomForestClassifier { * features. * @param _trees Decision trees in the ensemble. * Warning: These have null parents. + * @param numFeatures Number of features used by this model */ @Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], + val numFeatures: Int, override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -133,8 +136,8 @@ final class RandomForestClassificationModel private[ml] ( * Construct a random forest classification model, with all trees weighted equally. * @param trees Component trees */ - def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) = - this(Identifiable.randomUID("rfc"), trees, numClasses) + def this(trees: Array[DecisionTreeClassificationModel], numFeatures: Int, numClasses: Int) = + this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] @@ -182,13 +185,30 @@ final class RandomForestClassificationModel private[ml] ( } override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) + copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) } override def toString: String = { s"RandomForestClassificationModel with $numTrees trees" } + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. 
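+ * (The same computation backs RandomForestRegressionModel.featureImportances; both delegate + * to RandomForest.featureImportances.)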
+ * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. + */ + lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) @@ -210,6 +230,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees, numClasses) + new RandomForestClassificationModel(uid, newTrees, -1, numClasses) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 17fb1ad5e15d4..1ee43c8725732 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType + /** * :: Experimental :: @@ -87,7 +87,8 @@ final class RandomForestRegressor(override val uid: String) val trees = RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) .map(_.asInstanceOf[DecisionTreeRegressionModel]) - new RandomForestRegressionModel(trees) + val numFeatures = oldDataset.first().features.size + new RandomForestRegressionModel(trees, numFeatures) } override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra) @@ -108,11 +109,13 @@ object RandomForestRegressor { * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. * @param _trees Decision trees in the ensemble. + * @param numFeatures Number of features used by this model */ @Experimental final class RandomForestRegressionModel private[ml] ( override val uid: String, - private val _trees: Array[DecisionTreeRegressionModel]) + private val _trees: Array[DecisionTreeRegressionModel], + val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] with TreeEnsembleModel with Serializable { @@ -122,7 +125,8 @@ final class RandomForestRegressionModel private[ml] ( * Construct a random forest regression model, with all trees weighted equally. 
* @param trees Component trees */ - def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees) + def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = + this(Identifiable.randomUID("rfr"), trees, numFeatures) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] @@ -147,13 +151,30 @@ final class RandomForestRegressionModel private[ml] ( } override def copy(extra: ParamMap): RandomForestRegressionModel = { - copyValues(new RandomForestRegressionModel(uid, _trees), extra) + copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra) } override def toString: String = { s"RandomForestRegressionModel with $numTrees trees" } + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. + */ + lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) @@ -173,6 +194,6 @@ private[ml] object RandomForestRegressionModel { // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent.uid, newTrees) + new RandomForestRegressionModel(parent.uid, newTrees, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 8879352a600a9..cd24931293903 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -44,7 +44,7 @@ sealed abstract class Node extends Serializable { * and probabilities. * For classification, the array of class counts must be normalized to a probability distribution. */ - private[tree] def impurityStats: ImpurityCalculator + private[ml] def impurityStats: ImpurityCalculator /** Recursive prediction helper method */ private[ml] def predictImpl(features: Vector): LeafNode @@ -72,6 +72,12 @@ sealed abstract class Node extends Serializable { * @param id Node ID using old format IDs */ private[ml] def toOld(id: Int): OldNode + + /** + * Trace down the tree, and return the largest feature index used in any split. + * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). 
+ */ + private[ml] def maxSplitFeatureIndex(): Int } private[ml] object Node { @@ -109,7 +115,7 @@ private[ml] object Node { final class LeafNode private[ml] ( override val prediction: Double, override val impurity: Double, - override val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) extends Node { override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" @@ -129,6 +135,8 @@ final class LeafNode private[ml] ( new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity, isLeaf = true, None, None, None, None) } + + override private[ml] def maxSplitFeatureIndex(): Int = -1 } /** @@ -150,7 +158,7 @@ final class InternalNode private[ml] ( val leftChild: Node, val rightChild: Node, val split: Split, - override val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) extends Node { override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" @@ -190,6 +198,11 @@ final class InternalNode private[ml] ( new OldPredict(leftChild.prediction, prob = 0.0), new OldPredict(rightChild.prediction, prob = 0.0)))) } + + override private[ml] def maxSplitFeatureIndex(): Int = { + math.max(split.featureIndex, + math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) + } } private object InternalNode { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index a8b90d9d266a1..4ac51a475474a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -26,6 +26,7 @@ import org.apache.spark.Logging import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, @@ -34,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} @@ -1113,4 +1115,94 @@ private[ml] object RandomForest extends Logging { } } + /** + * Given a Random Forest model, compute the importance of each feature. + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. + * + * Note: This should not be used with Gradient-Boosted Trees. 
It only makes sense for + * independently trained trees. + * @param trees Unweighted forest of trees + * @param numFeatures Number of features in model (even if not all are explicitly used by + * the model). + * If -1, then numFeatures is set based on the max feature index in all trees. + * @return Feature importance values, of length numFeatures. + */ + private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = { + val totalImportances = new OpenHashMap[Int, Double]() + trees.foreach { tree => + // Aggregate feature importance vector for this tree + val importances = new OpenHashMap[Int, Double]() + computeFeatureImportance(tree.rootNode, importances) + // Normalize importance vector for this tree, and add it to total. + // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count? + val treeNorm = importances.map(_._2).sum + if (treeNorm != 0) { + importances.foreach { case (idx, impt) => + val normImpt = impt / treeNorm + totalImportances.changeValue(idx, normImpt, _ + normImpt) + } + } + } + // Normalize importances + normalizeMapValues(totalImportances) + // Construct vector + val d = if (numFeatures != -1) { + numFeatures + } else { + // Find max feature index used in trees + val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max + maxFeatureIndex + 1 + } + if (d == 0) { + assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" + + s" importance: No splits in forest, but some non-zero importances.") + } + val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip + Vectors.sparse(d, indices.toArray, values.toArray) + } + + /** + * Recursive method for computing feature importances for one tree. + * This walks down the tree, adding to the importance of 1 feature at each node. + * @param node Current node in recursion + * @param importances Aggregate feature importances, modified by this method + */ + private[impl] def computeFeatureImportance( + node: Node, + importances: OpenHashMap[Int, Double]): Unit = { + node match { + case n: InternalNode => + val feature = n.split.featureIndex + val scaledGain = n.gain * n.impurityStats.count + importances.changeValue(feature, scaledGain, _ + scaledGain) + computeFeatureImportance(n.leftChild, importances) + computeFeatureImportance(n.rightChild, importances) + case n: LeafNode => + // do nothing + } + } + + /** + * Normalize the values of this map to sum to 1, in place. + * If all values are 0, this method does nothing. + * @param map Map with non-negative values. + */ + private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = { + val total = map.map(_._2).sum + if (total != 0) { + val keys = map.iterator.map(_._1).toArray + keys.foreach { key => map.changeValue(key, 0.0, _ / total) } + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 22873909c33fa..b77191156f68f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -53,6 +53,12 @@ private[ml] trait DecisionTreeModel { val header = toString + "\n" header + rootNode.subtreeToString(2) } + + /** + * Trace down the tree, and return the largest feature index used in any split. + * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). 
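+ * (RandomForest.featureImportances relies on this to size the importance vector when the + * caller passes numFeatures = -1.)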
+ */ + private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex() } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 32d0b3856b7e2..a66a1e12927be 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; @@ -85,6 +86,7 @@ public void runDT() { model.toDebugString(); model.trees(); model.treeWeights(); + Vector importances = model.featureImportances(); /* // TODO: Add test once save/load are implemented. SPARK-6725 diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index e306ebadfe7cf..a00ce5e249c34 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; @@ -85,6 +86,7 @@ public void runDT() { model.toDebugString(); model.trees(); model.treeWeights(); + Vector importances = model.featureImportances(); /* // TODO: Add test once save/load are implemented. SPARK-6725 diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index edf848b21a905..6ca4b5aa5fde8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -67,7 +67,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2) ParamsSuite.checkParams(model) } @@ -149,6 +149,35 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } } + ///////////////////////////////////////////////////////////////////////////// + // Tests of feature importance + ///////////////////////////////////////////////////////////////////////////// + test("Feature importance with toy data") { + val numClasses = 2 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setFeatureSubsetStrategy("all") + .setSubsamplingRate(1.0) + .setSeed(123) + + // In this data, feature 1 is very important. 
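+ // Feature 1 equals the label in every row below, while features 0 and 2 are constant, + // so essentially all of the Gini gain (and hence the importance) should fall on feature 1.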
+ val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = rf.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 778abcba22c10..460849c79f04f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite { "checkEqual failed since the two tree ensembles were not identical") } } + + /** + * Helper method for constructing a tree for testing. + * Given left, right children, construct a parent node. + * @param split Split for parent node + * @return Parent node with children attached + */ + def buildParentNode(left: Node, right: Node, split: Split): Node = { + val leftImp = left.impurityStats + val rightImp = right.impurityStats + val parentImp = leftImp.copy.add(rightImp) + val leftWeight = leftImp.count / parentImp.count.toDouble + val rightWeight = rightImp.count / parentImp.count.toDouble + val gain = parentImp.calculate() - + (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) + val pred = parentImp.predict + new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index b24ecaa57c89b..992ce9562434e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -26,7 +27,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * Test suite for [[RandomForestRegressor]]. */ @@ -71,6 +71,31 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex regressionTestWithContinuousFeatures(rf) } + test("Feature importance with toy data") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setNumTrees(3) + .setFeatureSubsetStrategy("all") + .setSubsamplingRate(1.0) + .setSeed(123) + + // In this data, feature 1 is very important. 
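+ // As in the classifier suite, feature 1 coincides exactly with the target in every row, + // so the variance-based gain should concentrate on feature 1.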
+ val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val importances = rf.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala new file mode 100644 index 0000000000000..dc852795c7f62 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.tree.impurity.GiniCalculator +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.collection.OpenHashMap + +/** + * Test suite for [[RandomForest]]. + */ +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { + + import RandomForestSuite.mapToVec + + test("computeFeatureImportance, featureImportances") { + /* Build tree for testing, with this structure: + grandParent + left2 parent + left right + */ + val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) + val left = new LeafNode(0.0, leftImp.calculate(), leftImp) + + val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) + val right = new LeafNode(2.0, rightImp.calculate(), rightImp) + + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) + val parentImp = parent.impurityStats + + val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) + val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) + + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) + val grandImp = grandParent.impurityStats + + // Test feature importance computed at different subtrees. 
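+ // Each expected importance below is an unnormalized scaled gain: the parent's impurity times + // its instance count, minus the sum over children of child impurity times child count, + // matching what computeFeatureImportance accumulates per split.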
+ def testNode(node: Node, expected: Map[Int, Double]): Unit = { + val map = new OpenHashMap[Int, Double]() + RandomForest.computeFeatureImportance(node, map) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + + // Leaf node + testNode(left, Map.empty[Int, Double]) + + // Internal node with 2 leaf children + val feature0importance = parentImp.calculate() * parentImp.count - + (leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count) + testNode(parent, Map(0 -> feature0importance)) + + // Full tree + val feature1importance = grandImp.calculate() * grandImp.count - + (left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count) + testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance)) + + // Forest consisting of (full tree) + (internal node with 2 leafs) + val trees = Array(parent, grandParent).map { root => + new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel] + } + val importances: Vector = RandomForest.featureImportances(trees, 2) + val tree2norm = feature0importance + feature1importance + val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, + (feature1importance / tree2norm) / 2.0) + assert(importances ~== expected relTol 0.01) + } + + test("normalizeMapValues") { + val map = new OpenHashMap[Int, Double]() + map(0) = 1.0 + map(2) = 2.0 + RandomForest.normalizeMapValues(map) + val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + +} + +private object RandomForestSuite { + + def mapToVec(map: Map[Int, Double]): Vector = { + val size = (map.keys.toSeq :+ 0).max + 1 + val (indices, values) = map.toSeq.sortBy(_._1).unzip + Vectors.sparse(size, indices.toArray, values.toArray) + } +} From ba1c4e138de2ea84b55def4eed2bd363e60aea4d Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 3 Aug 2015 12:53:44 -0700 Subject: [PATCH 105/340] [SPARK-9558][DOCS]Update docs to follow the increase of memory defaults. Now the memory defaults of master and slave in Standalone mode and History Server is 1g, not 512m. So let's update docs. Author: Kousuke Saruta Closes #7896 from sarutak/update-doc-for-daemon-memory and squashes the following commits: a77626c [Kousuke Saruta] Fix docs to follow the update of increase of memory defaults --- conf/spark-env.sh.template | 1 + docs/monitoring.md | 2 +- docs/spark-standalone.md | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 192d3ae091134..c05fe381a36a7 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -38,6 +38,7 @@ # - SPARK_WORKER_INSTANCES, to set the number of worker processes per node # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") +# - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g). # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") # - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. 
"-Dx=y") diff --git a/docs/monitoring.md b/docs/monitoring.md index bcf885fe4e681..cedceb2958023 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -48,7 +48,7 @@ follows: Environment VariableMeaning SPARK_DAEMON_MEMORY - Memory to allocate to the history server (default: 512m). + Memory to allocate to the history server (default: 1g). SPARK_DAEMON_JAVA_OPTS diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 4f71fbc086cd0..2fe9ec3542b28 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -152,7 +152,7 @@ You can optionally configure the cluster further by setting environment variable SPARK_DAEMON_MEMORY - Memory to allocate to the Spark master and worker daemons themselves (default: 512m). + Memory to allocate to the Spark master and worker daemons themselves (default: 1g). SPARK_DAEMON_JAVA_OPTS From 8ca287ebbd58985a568341b08040d0efa9d3641a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 3 Aug 2015 13:58:00 -0700 Subject: [PATCH 106/340] [SPARK-9191] [ML] [Doc] Add ml.PCA user guide and code examples Add ml.PCA user guide document and code examples for Scala/Java/Python. Author: Yanbo Liang Closes #7522 from yanboliang/ml-pca-md and squashes the following commits: 60dec05 [Yanbo Liang] address comments f992abe [Yanbo Liang] Add ml.PCA doc and examples --- docs/ml-features.md | 86 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 54068debe2159..fa0ad1f00ab12 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -461,6 +461,92 @@ for binarized_feature, in binarizedFeatures.collect(): +## PCA + +[PCA](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical procedure that uses an orthogonal transformation to convert a set of observations of possibly correlated variables into a set of values of linearly uncorrelated variables called principal components. A [PCA](api/scala/index.html#org.apache.spark.ml.feature.PCA) class trains a model to project vectors to a low-dimensional space using PCA. The example below shows how to project 5-dimensional feature vectors into 3-dimensional principal components. + +

+
+See the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.feature.PCA) for API details. +{% highlight scala %} +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors + +val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) +) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) +val pcaDF = pca.transform(df) +val result = pcaDF.select("pcaFeatures") +result.show() +{% endhighlight %} +
+ +
+See the [Java API documentation](api/java/org/apache/spark/ml/feature/PCA.html) for API details. +{% highlight java %} +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.PCA; +import org.apache.spark.ml.feature.PCAModel; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaSparkContext jsc = ... +SQLContext jsql = ... +JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); +DataFrame result = pca.transform(df).select("pcaFeatures"); +result.show(); +{% endhighlight %} +
+ +
+See the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for API details. +{% highlight python %} +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] +df = sqlContext.createDataFrame(data,["features"]) +pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") +model = pca.fit(df) +result = model.transform(df).select("pcaFeatures") +result.show(truncate=False) +{% endhighlight %} +
+
+ ## PolynomialExpansion [Polynomial expansion](http://en.wikipedia.org/wiki/Polynomial_expansion) is the process of expanding your features into a polynomial space, which is formulated by an n-degree combination of original dimensions. A [PolynomialExpansion](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) class provides this functionality. The example below shows how to expand your features into a 3-degree polynomial space. From e4765a46833baff1dd7465c4cf50e947de7e8f21 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Aug 2015 13:59:35 -0700 Subject: [PATCH 107/340] [SPARK-9544] [MLLIB] add Python API for RFormula Add Python API for RFormula. Similar to other feature transformers in Python. This is just a thin wrapper over the Scala implementation. ericl MechCoder Author: Xiangrui Meng Closes #7879 from mengxr/SPARK-9544 and squashes the following commits: 3d5ff03 [Xiangrui Meng] add an doctest for . and - 5e969a5 [Xiangrui Meng] fix pydoc 1cd41f8 [Xiangrui Meng] organize imports 3c18b10 [Xiangrui Meng] add Python API for RFormula --- .../apache/spark/ml/feature/RFormula.scala | 21 ++--- python/pyspark/ml/feature.py | 85 ++++++++++++++++++- 2 files changed, 91 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d1726917e4517..d5360c9217ea9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.util.parsing.combinator.RegexParsers import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage} +import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** @@ -63,31 +61,26 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R */ val formula: Param[String] = new Param(this, "formula", "R model formula") - private var parsedFormula: Option[ParsedRFormula] = None - /** * Sets the formula to use for this transformer. Must be called before use. * @group setParam * @param value an R formula in string form (e.g. "y ~ x + z") */ - def setFormula(value: String): this.type = { - parsedFormula = Some(RFormulaParser.parse(value)) - set(formula, value) - this - } + def setFormula(value: String): this.type = set(formula, value) /** @group getParam */ def getFormula: String = $(formula) /** Whether the formula specifies fitting an intercept. 
*/ private[ml] def hasIntercept: Boolean = { - require(parsedFormula.isDefined, "Must call setFormula() first.") - parsedFormula.get.hasIntercept + require(isDefined(formula), "Formula must be defined first.") + RFormulaParser.parse($(formula)).hasIntercept } override def fit(dataset: DataFrame): RFormulaModel = { - require(parsedFormula.isDefined, "Must call setFormula() first.") - val resolvedFormula = parsedFormula.get.resolve(dataset.schema) + require(isDefined(formula), "Formula must be defined first.") + val parsedFormula = RFormulaParser.parse($(formula)) + val resolvedFormula = parsedFormula.resolve(dataset.schema) // StringType terms and terms representing interactions need to be encoded before assembly. // TODO(ekl) add support for feature interactions val encoderStages = ArrayBuffer[PipelineStage]() diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 015e7a9d4900a..3f04c41ac5ab6 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -24,7 +24,7 @@ __all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', - 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel'] + 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] @inherit_doc @@ -1110,6 +1110,89 @@ class PCAModel(JavaModel): """ +@inherit_doc +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): + """ + .. note:: Experimental + + Implements the transforms required for fitting a dataset against an + R model formula. Currently we support a limited subset of the R + operators, including '~', '+', '-', and '.'. Also see the R formula + docs: + http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + + >>> df = sqlContext.createDataFrame([ + ... (1.0, 1.0, "a"), + ... (0.0, 2.0, "b"), + ... (0.0, 0.0, "a") + ... ], ["y", "x", "s"]) + >>> rf = RFormula(formula="y ~ x + s") + >>> rf.fit(df).transform(df).show() + +---+---+---+---------+-----+ + | y| x| s| features|label| + +---+---+---+---------+-----+ + |1.0|1.0| a|[1.0,1.0]| 1.0| + |0.0|2.0| b|[2.0,0.0]| 0.0| + |0.0|0.0| a|[0.0,1.0]| 0.0| + +---+---+---+---------+-----+ + ... + >>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show() + +---+---+---+--------+-----+ + | y| x| s|features|label| + +---+---+---+--------+-----+ + |1.0|1.0| a| [1.0]| 1.0| + |0.0|2.0| b| [2.0]| 0.0| + |0.0|0.0| a| [0.0]| 0.0| + +---+---+---+--------+-----+ + ... + """ + + # a placeholder to make it appear in the generated doc + formula = Param(Params._dummy(), "formula", "R model formula") + + @keyword_only + def __init__(self, formula=None, featuresCol="features", labelCol="label"): + """ + __init__(self, formula=None, featuresCol="features", labelCol="label") + """ + super(RFormula, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) + self.formula = Param(self, "formula", "R model formula") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, formula=None, featuresCol="features", labelCol="label"): + """ + setParams(self, formula=None, featuresCol="features", labelCol="label") + Sets params for RFormula. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setFormula(self, value): + """ + Sets the value of :py:attr:`formula`. 
+ """ + self._paramMap[self.formula] = value + return self + + def getFormula(self): + """ + Gets the value of :py:attr:`formula`. + """ + return self.getOrDefault(self.formula) + + def _create_model(self, java_model): + return RFormulaModel(java_model) + + +class RFormulaModel(JavaModel): + """ + Model fitted by :py:class:`RFormula`. + """ + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 702aa9d7fb16c98a50e046edfd76b8a7861d0391 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 3 Aug 2015 14:22:07 -0700 Subject: [PATCH 108/340] [SPARK-8735] [SQL] Expose memory usage for shuffles, joins and aggregations This patch exposes the memory used by internal data structures on the SparkUI. This tracks memory used by all spilling operations and SQL operators backed by Tungsten, e.g. `BroadcastHashJoin`, `ExternalSort`, `GeneratedAggregate` etc. The metric exposed is "peak execution memory", which broadly refers to the peak in-memory sizes of each of these data structure. A separate patch will extend this by linking the new information to the SQL operators themselves. screen shot 2015-07-29 at 7 43 17 pm screen shot 2015-07-29 at 7 43 05 pm [Review on Reviewable](https://reviewable.io/reviews/apache/spark/7770) Author: Andrew Or Closes #7770 from andrewor14/expose-memory-metrics and squashes the following commits: 9abecb9 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics f5b0d68 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics d7df332 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics 8eefbc5 [Andrew Or] Fix non-failing tests 9de2a12 [Andrew Or] Fix tests due to another logical merge conflict 876bfa4 [Andrew Or] Fix failing test after logical merge conflict 361a359 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics 40b4802 [Andrew Or] Fix style? d0fef87 [Andrew Or] Fix tests? b3b92f6 [Andrew Or] Address comments 0625d73 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics c00a197 [Andrew Or] Fix potential NPEs 10da1cd [Andrew Or] Fix compile 17f4c2d [Andrew Or] Fix compile? a87b4d0 [Andrew Or] Fix compile? 
d70874d [Andrew Or] Fix test compile + address comments 2840b7d [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics 6aa2f7a [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics b889a68 [Andrew Or] Minor changes: comments, spacing, style 663a303 [Andrew Or] UnsafeShuffleWriter: update peak memory before close d090a94 [Andrew Or] Fix style 2480d84 [Andrew Or] Expand test coverage 5f1235b [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics 1ecf678 [Andrew Or] Minor changes: comments, style, unused imports 0b6926c [Andrew Or] Oops 111a05e [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics a7a39a5 [Andrew Or] Strengthen presence check for accumulator a919eb7 [Andrew Or] Add tests for unsafe shuffle writer 23c845d [Andrew Or] Add tests for SQL operators a757550 [Andrew Or] Address comments b5c51c1 [Andrew Or] Re-enable test in JavaAPISuite 5107691 [Andrew Or] Add tests for internal accumulators 59231e4 [Andrew Or] Fix tests 9528d09 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics 5b5e6f3 [Andrew Or] Add peak execution memory to summary table + tooltip 92b4b6b [Andrew Or] Display peak execution memory on the UI eee5437 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics d9b9015 [Andrew Or] Track execution memory in unsafe shuffles 770ee54 [Andrew Or] Track execution memory in broadcast joins 9c605a4 [Andrew Or] Track execution memory in GeneratedAggregate 9e824f2 [Andrew Or] Add back execution memory tracking for *ExternalSort 4ef4cb1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics e6c3e2f [Andrew Or] Move internal accumulators creation to Stage a417592 [Andrew Or] Expose memory metrics in UnsafeExternalSorter 3c4f042 [Andrew Or] Track memory usage in ExternalAppendOnlyMap / ExternalSorter bd7ab3f [Andrew Or] Add internal accumulators to TaskContext --- .../unsafe/UnsafeShuffleExternalSorter.java | 27 ++- .../shuffle/unsafe/UnsafeShuffleWriter.java | 38 +++- .../spark/unsafe/map/BytesToBytesMap.java | 8 +- .../unsafe/sort/UnsafeExternalSorter.java | 29 ++- .../org/apache/spark/ui/static/webui.css | 2 +- .../scala/org/apache/spark/Accumulators.scala | 60 +++++- .../scala/org/apache/spark/Aggregator.scala | 24 +-- .../scala/org/apache/spark/TaskContext.scala | 13 +- .../org/apache/spark/TaskContextImpl.scala | 8 + .../org/apache/spark/rdd/CoGroupedRDD.scala | 9 +- .../spark/scheduler/AccumulableInfo.scala | 9 +- .../apache/spark/scheduler/DAGScheduler.scala | 28 ++- .../apache/spark/scheduler/ResultTask.scala | 6 +- .../spark/scheduler/ShuffleMapTask.scala | 10 +- .../org/apache/spark/scheduler/Stage.scala | 16 ++ .../org/apache/spark/scheduler/Task.scala | 18 +- .../shuffle/hash/HashShuffleReader.scala | 8 +- .../scala/org/apache/spark/ui/ToolTips.scala | 7 + .../org/apache/spark/ui/jobs/StagePage.scala | 140 +++++++++---- .../spark/ui/jobs/TaskDetailsClassNames.scala | 1 + .../collection/ExternalAppendOnlyMap.scala | 13 +- .../util/collection/ExternalSorter.scala | 20 +- .../java/org/apache/spark/JavaAPISuite.java | 3 +- .../unsafe/UnsafeShuffleWriterSuite.java | 54 +++++ .../map/AbstractBytesToBytesMapSuite.java | 39 ++++ .../sort/UnsafeExternalSorterSuite.java | 46 +++++ .../org/apache/spark/AccumulatorSuite.scala | 193 +++++++++++++++++- .../org/apache/spark/CacheManagerSuite.scala | 10 +- .../org/apache/spark/rdd/PipedRDDSuite.scala | 2 
+- .../org/apache/spark/scheduler/FakeTask.scala | 6 +- .../scheduler/NotSerializableFakeTask.scala | 2 +- .../spark/scheduler/TaskContextSuite.scala | 7 +- .../spark/scheduler/TaskSetManagerSuite.scala | 2 +- .../shuffle/hash/HashShuffleReaderSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 8 +- .../org/apache/spark/ui/StagePageSuite.scala | 76 +++++++ .../ExternalAppendOnlyMapSuite.scala | 15 ++ .../util/collection/ExternalSorterSuite.scala | 14 +- .../execution/UnsafeExternalRowSorter.java | 7 + .../UnsafeFixedWidthAggregationMap.java | 8 + .../sql/execution/GeneratedAggregate.scala | 11 +- .../execution/joins/BroadcastHashJoin.scala | 10 +- .../joins/BroadcastHashOuterJoin.scala | 8 + .../joins/BroadcastLeftSemiJoinHash.scala | 10 +- .../sql/execution/joins/HashedRelation.scala | 22 +- .../org/apache/spark/sql/execution/sort.scala | 12 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 60 ++++-- .../sql/execution/TungstenSortSuite.scala | 12 ++ .../UnsafeFixedWidthAggregationMapSuite.scala | 3 +- .../UnsafeKVExternalSorterSuite.scala | 3 +- .../execution/joins/BroadcastJoinSuite.scala | 94 +++++++++ 51 files changed, 1070 insertions(+), 163 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 1aa6ba4201261..bf4eaa59ff589 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; import java.util.LinkedList; +import javax.annotation.Nullable; import scala.Tuple2; @@ -86,9 +87,12 @@ final class UnsafeShuffleExternalSorter { private final LinkedList spills = new LinkedList(); + /** Peak memory used by this sorter so far, in bytes. **/ + private long peakMemoryUsedBytes; + // These variables are reset after spilling: - private UnsafeShuffleInMemorySorter sorter; - private MemoryBlock currentPage = null; + @Nullable private UnsafeShuffleInMemorySorter sorter; + @Nullable private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; @@ -106,6 +110,7 @@ public UnsafeShuffleExternalSorter( this.blockManager = blockManager; this.taskContext = taskContext; this.initialSize = initialSize; + this.peakMemoryUsedBytes = initialSize; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; @@ -279,10 +284,26 @@ private long getMemoryUsage() { for (MemoryBlock page : allocatedPages) { totalPageSize += page.size(); } - return sorter.getMemoryUsage() + totalPageSize; + return ((sorter == null) ? 0 : sorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. 
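+ * The value is refreshed from the current memory usage on each call, so it never decreases.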
+ */ + long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } private long freeMemory() { + updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { memoryManager.freePage(block); diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index d47d6fc9c2ac4..6e2eeb37c86f1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -27,6 +27,7 @@ import scala.collection.JavaConversions; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; +import scala.collection.immutable.Map; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; @@ -78,8 +79,9 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SparkConf sparkConf; private final boolean transferToEnabled; - private MapStatus mapStatus = null; - private UnsafeShuffleExternalSorter sorter = null; + @Nullable private MapStatus mapStatus; + @Nullable private UnsafeShuffleExternalSorter sorter; + private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { @@ -131,9 +133,28 @@ public UnsafeShuffleWriter( @VisibleForTesting public int maxRecordSizeBytes() { + assert(sorter != null); return sorter.maxRecordSizeBytes; } + private void updatePeakMemoryUsed() { + // sorter can be null if this writer is closed + if (sorter != null) { + long mem = sorter.getPeakMemoryUsedBytes(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + /** * This convenience method should only be called in test code. */ @@ -144,7 +165,7 @@ public void write(Iterator> records) throws IOException { @Override public void write(scala.collection.Iterator> records) throws IOException { - // Keep track of success so we know if we ecountered an exception + // Keep track of success so we know if we encountered an exception // We do this rather than a standard try/catch/re-throw to handle // generic throwables. 
boolean success = false; @@ -189,6 +210,8 @@ private void open() throws IOException { @VisibleForTesting void closeAndWriteOutput() throws IOException { + assert(sorter != null); + updatePeakMemoryUsed(); serBuffer = null; serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); @@ -209,6 +232,7 @@ void closeAndWriteOutput() throws IOException { @VisibleForTesting void insertRecordIntoSorter(Product2 record) throws IOException { + assert(sorter != null); final K key = record._1(); final int partitionId = partitioner.getPartition(key); serBuffer.reset(); @@ -431,6 +455,14 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th @Override public Option stop(boolean success) { try { + // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite) + Map> internalAccumulators = + taskContext.internalMetricsToAccumulators(); + if (internalAccumulators != null) { + internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY()) + .add(getPeakMemoryUsedBytes()); + } + if (stopping) { return Option.apply(null); } else { diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 01a66084e918e..20347433e16b2 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -505,7 +505,7 @@ public boolean putNewKey( // Here, we'll copy the data into our data pages. Because we only store a relative offset from // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. - // (8 byte key length) (key) (8 byte value length) (value) + // (8 byte key length) (key) (value) final long requiredSize = 8 + keyLengthBytes + valueLengthBytes; // --- Figure out where to insert the new record --------------------------------------------- @@ -655,7 +655,10 @@ public long getPageSizeBytes() { return pageSizeBytes; } - /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */ + /** + * Returns the total amount of memory, in bytes, consumed by this map's managed structures. + * Note that this is also the peak memory used by this map, since the map is append-only. + */ public long getTotalMemoryConsumption() { long totalDataPagesSize = 0L; for (MemoryBlock dataPage : dataPages) { @@ -674,7 +677,6 @@ public long getTimeSpentResizingNs() { return timeSpentResizingNs; } - /** * Returns the average number of probes per key lookup. */ diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index b984301cbbf2b..bf5f965a9d8dc 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -70,13 +70,14 @@ public final class UnsafeExternalSorter { private final LinkedList spillWriters = new LinkedList<>(); // These variables are reset after spilling: - private UnsafeInMemorySorter inMemSorter; + @Nullable private UnsafeInMemorySorter inMemSorter; // Whether the in-mem sorter is created internally, or passed in from outside. // If it is passed in from outside, we shouldn't release the in-mem sorter's memory. 
private boolean isInMemSorterExternal = false; private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; + private long peakMemoryUsedBytes = 0; public static UnsafeExternalSorter createWithExistingInMemorySorter( TaskMemoryManager taskMemoryManager, @@ -183,6 +184,7 @@ public void closeCurrentPage() { * Sort and spill the current records in response to memory pressure. */ public void spill() throws IOException { + assert(inMemSorter != null); logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), @@ -219,7 +221,22 @@ private long getMemoryUsage() { for (MemoryBlock page : allocatedPages) { totalPageSize += page.size(); } - return inMemSorter.getMemoryUsage() + totalPageSize; + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } @VisibleForTesting @@ -233,6 +250,7 @@ public int getNumberOfAllocatedPages() { * @return the number of bytes freed. */ public long freeMemory() { + updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { taskMemoryManager.freePage(block); @@ -277,7 +295,8 @@ public void deleteSpillFiles() { * @return true if the record can be inserted without requiring more allocations, false otherwise. */ private boolean haveSpaceForRecord(int requiredSpace) { - assert (requiredSpace > 0); + assert(requiredSpace > 0); + assert(inMemSorter != null); return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); } @@ -290,6 +309,7 @@ private boolean haveSpaceForRecord(int requiredSpace) { * the record size. */ private void allocateSpaceForRecord(int requiredSpace) throws IOException { + assert(inMemSorter != null); // TODO: merge these steps to first calculate total memory requirements for this insert, // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the // data page. @@ -350,6 +370,7 @@ public void insertRecord( if (!haveSpaceForRecord(totalSpaceRequired)) { allocateSpaceForRecord(totalSpaceRequired); } + assert(inMemSorter != null); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); @@ -382,6 +403,7 @@ public void insertKVRecord( if (!haveSpaceForRecord(totalSpaceRequired)) { allocateSpaceForRecord(totalSpaceRequired); } + assert(inMemSorter != null); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); @@ -405,6 +427,7 @@ public void insertKVRecord( } public UnsafeSorterIterator getSortedIterator() throws IOException { + assert(inMemSorter != null); final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator(); int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 
1 : 0); if (spillWriters.isEmpty()) { diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index b1cef47042247..648cd1b104802 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -207,7 +207,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, -.serialization_time, .getting_result_time { +.serialization_time, .getting_result_time, .peak_execution_memory { display: none; } diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index eb75f26718e19..b6a0119c696fd 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -152,8 +152,14 @@ class Accumulable[R, T] private[spark] ( in.defaultReadObject() value_ = zero deserialized = true - val taskContext = TaskContext.get() - taskContext.registerAccumulator(this) + // Automatically register the accumulator when it is deserialized with the task closure. + // Note that internal accumulators are deserialized before the TaskContext is created and + // are registered in the TaskContext constructor. + if (!isInternal) { + val taskContext = TaskContext.get() + assume(taskContext != null, "Task context was null when deserializing user accumulators") + taskContext.registerAccumulator(this) + } } override def toString: String = if (value_ == null) "null" else value_.toString @@ -248,10 +254,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T, T](initialValue, param, name) { +class Accumulator[T] private[spark] ( + @transient initialValue: T, + param: AccumulatorParam[T], + name: Option[String], + internal: Boolean) + extends Accumulable[T, T](initialValue, param, name, internal) { + + def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { + this(initialValue, param, name, false) + } - def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) + def this(initialValue: T, param: AccumulatorParam[T]) = { + this(initialValue, param, None, false) + } } /** @@ -342,3 +358,37 @@ private[spark] object Accumulators extends Logging { } } + +private[spark] object InternalAccumulator { + val PEAK_EXECUTION_MEMORY = "peakExecutionMemory" + val TEST_ACCUMULATOR = "testAccumulator" + + // For testing only. + // This needs to be a def since we don't want to reuse the same accumulator across stages. + private def maybeTestAccumulator: Option[Accumulator[Long]] = { + if (sys.props.contains("spark.testing")) { + Some(new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true)) + } else { + None + } + } + + /** + * Accumulators for tracking internal metrics. + * + * These accumulators are created with the stage such that all tasks in the stage will + * add to the same set of accumulators. 
We do this to report the distribution of accumulator + * values across all tasks within each stage. + */ + def create(): Seq[Accumulator[Long]] = { + Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. + new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + } +} diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index ceeb58075d345..289aab9bd9e51 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -58,12 +58,7 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } + updateMetrics(context, combiners) combiners.iterator } } @@ -89,13 +84,18 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non-optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } + updateMetrics(context, combiners) combiners.iterator } } + + /** Update task metrics after populating the external map. */ + private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = { + Option(context).foreach { c => + c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + c.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + } + } } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 5d2c551d58514..63cca80b2d734 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -61,12 +61,12 @@ object TaskContext { protected[spark] def unset(): Unit = taskContext.remove() /** - * Return an empty task context that is not actually used. - * Internal use only. + * An empty task context that does not represent an actual task. */ - private[spark] def empty(): TaskContext = { - new TaskContextImpl(0, 0, 0, 0, null, null) + private[spark] def empty(): TaskContextImpl = { + new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty) } + } @@ -187,4 +187,9 @@ abstract class TaskContext extends Serializable { * accumulator id and the value of the Map is the latest accumulator local value. */ private[spark] def collectAccumulators(): Map[Long, Any] + + /** + * Accumulators for tracking internal metrics indexed by the name. 
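To make the intent of the internalMetricsToAccumulators lookup concrete, here is a hedged usage sketch modelled on the AccumulatorSuite tests added later in this patch; TEST_ACCUMULATOR only exists when spark.testing is set, sc stands for an already-created SparkContext, and the APIs involved are private[spark]:

    // Every task in the stage shares the accumulators created by
    // InternalAccumulator.create(), so after the job completes the stage-level
    // value is the sum of the per-task updates (here, one per partition).
    sc.parallelize(1 to 100, 10).mapPartitions { iter =>
      TaskContext.get().internalMetricsToAccumulators(
        InternalAccumulator.TEST_ACCUMULATOR) += 1
      iter
    }.count()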
+ */ + private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 9ee168ae016f8..5df94c6d3a103 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -32,6 +32,7 @@ private[spark] class TaskContextImpl( override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, @transient private val metricsSystem: MetricsSystem, + internalAccumulators: Seq[Accumulator[Long]], val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext @@ -114,4 +115,11 @@ private[spark] class TaskContextImpl( private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { accumulators.mapValues(_.localValue).toMap } + + private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = { + // Explicitly register internal accumulators here because these are + // not captured in the task closure and are already deserialized + internalAccumulators.foreach(registerAccumulator) + internalAccumulators.map { a => (a.name.get, a) }.toMap + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 130b58882d8ee..9c617fc719cb5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,8 +23,7 @@ import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} -import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.util.Utils @@ -169,8 +168,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: for ((it, depNum) <- rddIterators) { map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } - context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index e0edd7d4ae968..11d123eec43ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,12 @@ import org.apache.spark.annotation.DeveloperApi * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. 
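A hedged sketch of how the new TaskContextImpl wiring behaves, modelled on the "internal accumulators in TaskContext" test added further down (all private[spark] APIs, shown here only to illustrate the data flow):

    // The constructor registers each internal accumulator, so the accumulators
    // are reachable both by name (internalMetricsToAccumulators) and through
    // the generic collection path used when a task finishes.
    val accums = InternalAccumulator.create()
    val context = new TaskContextImpl(0, 0, 0, 0, null, null, accums)
    val peakAccum = context.internalMetricsToAccumulators(
      InternalAccumulator.PEAK_EXECUTION_MEMORY)
    peakAccum += 1024L
    assert(context.collectAccumulators().contains(peakAccum.id))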
*/ @DeveloperApi -class AccumulableInfo ( +class AccumulableInfo private[spark] ( val id: Long, val name: String, val update: Option[String], // represents a partial update within a task - val value: String) { + val value: String, + val internal: Boolean) { override def equals(other: Any): Boolean = other match { case acc: AccumulableInfo => @@ -40,10 +41,10 @@ class AccumulableInfo ( object AccumulableInfo { def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { - new AccumulableInfo(id, name, update, value) + new AccumulableInfo(id, name, update, value, internal = false) } def apply(id: Long, name: String, value: String): AccumulableInfo = { - new AccumulableInfo(id, name, None, value) + new AccumulableInfo(id, name, None, value, internal = false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c4fa277c21254..bb489c6b6e98f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -773,16 +773,26 @@ class DAGScheduler( stage.pendingTasks.clear() // First figure out the indexes of partition ids to compute. - val partitionsToCompute: Seq[Int] = { + val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { stage match { case stage: ShuffleMapStage => - (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty) + val allPartitions = 0 until stage.numPartitions + val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty } + (allPartitions, filteredPartitions) case stage: ResultStage => val job = stage.resultOfJob.get - (0 until job.numPartitions).filter(id => !job.finished(id)) + val allPartitions = 0 until job.numPartitions + val filteredPartitions = allPartitions.filter { id => !job.finished(id) } + (allPartitions, filteredPartitions) } } + // Reset internal accumulators only if this stage is not partially submitted + // Otherwise, we may override existing accumulator values from some tasks + if (allPartitions == partitionsToCompute) { + stage.resetInternalAccumulators() + } + val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull runningStages += stage @@ -852,7 +862,8 @@ class DAGScheduler( partitionsToCompute.map { id => val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs) + new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, stage.internalAccumulators) } case stage: ResultStage => @@ -861,7 +872,8 @@ class DAGScheduler( val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) - new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id) + new ResultTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, id, stage.internalAccumulators) } } } catch { @@ -916,9 +928,11 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}") + val value = s"${acc.value}" + stage.latestInfo.accumulables(id) = + new AccumulableInfo(id, name, None, value, acc.isInternal) event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}") + new AccumulableInfo(id, name, 
Some(s"$partialValue"), value, acc.isInternal) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 9c2606e278c54..c4dc080e2b22b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -45,8 +45,10 @@ private[spark] class ResultTask[T, U]( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient locs: Seq[TaskLocation], - val outputId: Int) - extends Task[U](stageId, stageAttemptId, partition.index) with Serializable { + val outputId: Int, + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) + with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 14c8c00961487..f478f9982afef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -43,12 +43,14 @@ private[spark] class ShuffleMapTask( stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging { + @transient private var locs: Seq[TaskLocation], + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators) + with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, 0, null, new Partition { override def index: Int = 0 }, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, null) } @transient private val preferredLocs: Seq[TaskLocation] = { @@ -69,7 +71,7 @@ private[spark] class ShuffleMapTask( val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) - return writer.stop(success = true).get + writer.stop(success = true).get } catch { case e: Exception => try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 40a333a3e06b2..de05ee256dbfc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -68,6 +68,22 @@ private[spark] abstract class Stage( val name = callSite.shortForm val details = callSite.longForm + private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty + + /** Internal accumulators shared across all tasks in this stage. */ + def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators + + /** + * Re-initialize the internal accumulators associated with this stage. + * + * This is called every time the stage is submitted, *except* when a subset of tasks + * belonging to this stage has already finished. Otherwise, reinitializing the internal + * accumulators here again will override partial values from the finished tasks. 
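The scheduler/stage interaction described above can be summarised in a short hedged sketch (variable names follow the DAGScheduler hunk earlier in this patch; this is an illustration, not additional code in the change):

    // Internal accumulators are re-created only when every partition of the
    // stage is being computed. On a partial resubmission the existing
    // accumulators are kept, so values already contributed by finished tasks
    // are not overwritten.
    if (allPartitions == partitionsToCompute) {
      stage.resetInternalAccumulators()
    }
    // Each task is then handed the stage's (possibly reused) accumulators:
    new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
      taskBinary, part, locs, stage.internalAccumulators)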
+ */ + def resetInternalAccumulators(): Unit = { + _internalAccumulators = InternalAccumulator.create() + } + /** * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized * here, before any attempts have actually been created, because the DAGScheduler uses this diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 1978305cfefbd..9edf9f048f9fd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} +import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.unsafe.memory.TaskMemoryManager @@ -47,7 +47,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, - var partitionId: Int) extends Serializable { + val partitionId: Int, + internalAccumulators: Seq[Accumulator[Long]]) extends Serializable { /** * The key of the Map is the accumulator id and the value of the Map is the latest accumulator @@ -68,12 +69,13 @@ private[spark] abstract class Task[T]( metricsSystem: MetricsSystem) : (T, AccumulatorUpdates) = { context = new TaskContextImpl( - stageId = stageId, - partitionId = partitionId, - taskAttemptId = taskAttemptId, - attemptNumber = attemptNumber, - taskMemoryManager = taskMemoryManager, - metricsSystem = metricsSystem, + stageId, + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + metricsSystem, + internalAccumulators, runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index de79fa56f017b..0c8f08f0f3b1b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} @@ -100,8 +100,10 @@ private[spark] class HashShuffleReader[K, C]( // the ExternalSorter won't spill to disk. 
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) sorter.insertAll(aggregatedIter) - context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) sorter.iterator case None => aggregatedIter diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index e2d25e36365fa..cb122eaed83d1 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -62,6 +62,13 @@ private[spark] object ToolTips { """Time that the executor spent paused for Java garbage collection while the task was running.""" + val PEAK_EXECUTION_MEMORY = + """Execution memory refers to the memory used by internal data structures created during + shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator + should be approximately the sum of the peak sizes across all such data structures created + in this task. For SQL jobs, this only tracks all unsafe operators, broadcast joins, and + external sort.""" + val JOB_TIMELINE = """Shows when jobs started and ended and when executors joined or left. Drag to scroll. Click Enable Zooming and use mouse wheel to zoom in/out.""" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index cf04b5e59239b..3954c3d1ef894 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -26,6 +26,7 @@ import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.{InternalAccumulator, SparkConf} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.ui._ @@ -67,6 +68,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // if we find that it's okay. private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) + private val displayPeakExecutionMemory = + parent.conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean) def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { @@ -114,10 +117,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val stageData = stageDataOption.get val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) - val numCompleted = tasks.count(_.taskInfo.finished) - val accumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables - val hasAccumulators = accumulables.size > 0 + + val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables + val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } + val hasAccumulators = externalAccumulables.size > 0 val summary =
@@ -221,6 +225,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Getting Result Time + {if (displayPeakExecutionMemory) { +
+ <li> + <span data-toggle="tooltip" title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right"> + <input type="checkbox" name={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}/> + <span class="additional-metric-title">Peak Execution Memory</span> + </span> + </li> + }}
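The checkbox above, the extra table column, and the new summary quantiles are all gated on one flag; a condensed sketch of that pattern using names from this patch:

    // Only surface the new metric when the unsafe/Tungsten code path that
    // feeds it is enabled; otherwise the stage page is unchanged.
    private val displayPeakExecutionMemory =
      parent.conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean)

    val extraHeaders =
      if (displayPeakExecutionMemory) {
        Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY))
      } else {
        Nil
      }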
    @@ -241,11 +254,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, - accumulables.values.toSeq) + externalAccumulables.toSeq) val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( + parent.conf, UIUtils.prependBaseUri(parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", tasks, @@ -294,12 +308,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { else { def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = Distribution(data).get.getQuantiles() - def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { getDistributionQuantiles(times).map { millis => {UIUtils.formatDuration(millis.toLong)} } } + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { + getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + } val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorDeserializeTime.toDouble @@ -349,6 +365,23 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(gettingResultTimes) + + val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => + info.accumulables + .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } + .map { acc => acc.value.toLong } + .getOrElse(0L) + .toDouble + } + val peakExecutionMemoryQuantiles = { + + + Peak Execution Memory + + +: getFormattedSizeQuantiles(peakExecutionMemory) + } + // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). @@ -359,10 +392,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay val schedulerDelayQuantiles = schedulerDelayTitle +: getFormattedTimeQuantiles(schedulerDelays) - - def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = - getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) : Seq[Elem] = { val recordDist = getDistributionQuantiles(records).iterator @@ -466,6 +495,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {serializationQuantiles} , {gettingResultQuantiles}, + if (displayPeakExecutionMemory) { + + {peakExecutionMemoryQuantiles} + + } else { + Nil + }, if (stageData.hasInput) {inputQuantiles} else Nil, if (stageData.hasOutput) {outputQuantiles} else Nil, if (stageData.hasShuffleRead) { @@ -499,7 +535,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) val maybeAccumulableTable: Seq[Node] = - if (accumulables.size > 0) {

<h4>Accumulators</h4> ++ accumulableTable } else Seq() + if (hasAccumulators) { <h4>Accumulators</h4>
    ++ accumulableTable } else Seq() val content = summary ++ @@ -750,29 +786,30 @@ private[ui] case class TaskTableRowBytesSpilledData( * Contains all data that needs for sorting and generating HTML. Using this one rather than * TaskUIData to avoid creating duplicate contents during sorting the data. */ -private[ui] case class TaskTableRowData( - index: Int, - taskId: Long, - attempt: Int, - speculative: Boolean, - status: String, - taskLocality: String, - executorIdAndHost: String, - launchTime: Long, - duration: Long, - formatDuration: String, - schedulerDelay: Long, - taskDeserializationTime: Long, - gcTime: Long, - serializationTime: Long, - gettingResultTime: Long, - accumulators: Option[String], // HTML - input: Option[TaskTableRowInputData], - output: Option[TaskTableRowOutputData], - shuffleRead: Option[TaskTableRowShuffleReadData], - shuffleWrite: Option[TaskTableRowShuffleWriteData], - bytesSpilled: Option[TaskTableRowBytesSpilledData], - error: String) +private[ui] class TaskTableRowData( + val index: Int, + val taskId: Long, + val attempt: Int, + val speculative: Boolean, + val status: String, + val taskLocality: String, + val executorIdAndHost: String, + val launchTime: Long, + val duration: Long, + val formatDuration: String, + val schedulerDelay: Long, + val taskDeserializationTime: Long, + val gcTime: Long, + val serializationTime: Long, + val gettingResultTime: Long, + val peakExecutionMemoryUsed: Long, + val accumulators: Option[String], // HTML + val input: Option[TaskTableRowInputData], + val output: Option[TaskTableRowOutputData], + val shuffleRead: Option[TaskTableRowShuffleReadData], + val shuffleWrite: Option[TaskTableRowShuffleWriteData], + val bytesSpilled: Option[TaskTableRowBytesSpilledData], + val error: String) private[ui] class TaskDataSource( tasks: Seq[TaskUIData], @@ -816,10 +853,15 @@ private[ui] class TaskDataSource( val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = getGettingResultTime(info, currentTime) - val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map { acc => + val (taskInternalAccumulables, taskExternalAccumulables) = + info.accumulables.partition(_.internal) + val externalAccumulableReadable = taskExternalAccumulables.map { acc => StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") } + val peakExecutionMemoryUsed = taskInternalAccumulables + .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } + .map { acc => acc.value.toLong } + .getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) @@ -923,7 +965,7 @@ private[ui] class TaskDataSource( None } - TaskTableRowData( + new TaskTableRowData( info.index, info.taskId, info.attempt, @@ -939,7 +981,8 @@ private[ui] class TaskDataSource( gcTime, serializationTime, gettingResultTime, - if (hasAccumulators) Some(accumulatorsReadable.mkString("
    ")) else None, + peakExecutionMemoryUsed, + if (hasAccumulators) Some(externalAccumulableReadable.mkString("
    ")) else None, input, output, shuffleRead, @@ -1006,6 +1049,10 @@ private[ui] class TaskDataSource( override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime) } + case "Peak Execution Memory" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed) + } case "Accumulators" => if (hasAccumulators) { new Ordering[TaskTableRowData] { @@ -1132,6 +1179,7 @@ private[ui] class TaskDataSource( } private[ui] class TaskPagedTable( + conf: SparkConf, basePath: String, data: Seq[TaskUIData], hasAccumulators: Boolean, @@ -1143,7 +1191,11 @@ private[ui] class TaskPagedTable( currentTime: Long, pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedTable[TaskTableRowData]{ + desc: Boolean) extends PagedTable[TaskTableRowData] { + + // We only track peak memory used for unsafe operators + private val displayPeakExecutionMemory = + conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean) override def tableId: String = "" @@ -1195,6 +1247,13 @@ private[ui] class TaskPagedTable( ("GC Time", ""), ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + { + if (displayPeakExecutionMemory) { + Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) + } else { + Nil + } + } ++ {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ @@ -1271,6 +1330,11 @@ private[ui] class TaskPagedTable( {UIUtils.formatDuration(task.gettingResultTime)} + {if (displayPeakExecutionMemory) { + + {Utils.bytesToString(task.peakExecutionMemoryUsed)} + + }} {if (task.accumulators.nonEmpty) { {Unparsed(task.accumulators.get)} }} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 9bf67db8acde1..d2dfc5a32915c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -31,4 +31,5 @@ private[spark] object TaskDetailsClassNames { val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" + val PEAK_EXECUTION_MEMORY = "peak_execution_memory" } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index d166037351c31..f929b12606f0a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,6 +89,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L + def diskBytesSpilled: Long = _diskBytesSpilled // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = @@ -97,6 +98,10 @@ class ExternalAppendOnlyMap[K, V, C]( // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ + // Peak size of the in-memory map observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L 
+ def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes + private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() @@ -126,7 +131,11 @@ class ExternalAppendOnlyMap[K, V, C]( while (entries.hasNext) { curEntry = entries.next() - if (maybeSpill(currentMap, currentMap.estimateSize())) { + val estimatedSize = currentMap.estimateSize() + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } + if (maybeSpill(currentMap, estimatedSize)) { currentMap = new SizeTrackingAppendOnlyMap[K, C] } currentMap.changeValue(curEntry._1, update) @@ -207,8 +216,6 @@ class ExternalAppendOnlyMap[K, V, C]( spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) } - def diskBytesSpilled: Long = _diskBytesSpilled - /** * Return an iterator that merges the in-memory map with the spilled maps. * If no spill has occurred, simply return the in-memory map's iterator. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index ba7ec834d622d..19287edbaf166 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -152,6 +152,9 @@ private[spark] class ExternalSorter[K, V, C]( private var _diskBytesSpilled = 0L def diskBytesSpilled: Long = _diskBytesSpilled + // Peak size of the in-memory data structure observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the @@ -224,15 +227,22 @@ private[spark] class ExternalSorter[K, V, C]( return } + var estimatedSize = 0L if (usingMap) { - if (maybeSpill(map, map.estimateSize())) { + estimatedSize = map.estimateSize() + if (maybeSpill(map, estimatedSize)) { map = new PartitionedAppendOnlyMap[K, C] } } else { - if (maybeSpill(buffer, buffer.estimateSize())) { + estimatedSize = buffer.estimateSize() + if (maybeSpill(buffer, estimatedSize)) { buffer = newBuffer() } } + + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } } /** @@ -684,8 +694,10 @@ private[spark] class ExternalSorter[K, V, C]( } } - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes) lengths } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e948ca33471a4..ffe4b4baffb2a 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -51,7 +51,6 @@ import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.*; -import org.apache.spark.executor.TaskMetrics; import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; import org.apache.spark.partial.PartialResult; @@ -1011,7 +1010,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new 
TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics()); + TaskContext context = TaskContext$.MODULE$.empty(); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 04fc09b323dbb..98c32bbc298d7 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -190,6 +190,7 @@ public Tuple2 answer( }); when(taskContext.taskMetrics()).thenReturn(taskMetrics); + when(taskContext.internalMetricsToAccumulators()).thenReturn(null); when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); @@ -542,4 +543,57 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { writer.stop(false); assertSpillFilesWereCleanedUp(); } + + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + final SparkConf conf = new SparkConf().set("spark.buffer.pageSize", pageSizeBytes + "b"); + final UnsafeShuffleWriter writer = + new UnsafeShuffleWriter( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + shuffleMemoryManager, + new UnsafeShuffleHandle(0, 1, shuffleDep), + 0, // map id + taskContext, + conf); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. + long previousPeakMemory = writer.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + writer.forceSorterToSpill(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + } + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + + // Closing the writer should not change peak memory + writer.closeAndWriteOutput(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + writer.stop(false); + } + } } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index dbb7c662d7871..0e23a64fb74bb 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -25,6 +25,7 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.*; import static org.mockito.AdditionalMatchers.geq; import static 
org.mockito.Mockito.*; @@ -495,4 +496,42 @@ public void resizingLargeMap() { map.growAndRehash(); map.free(); } + + @Test + public void testTotalMemoryConsumption() { + final long recordLengthBytes = 24; + final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker + final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes; + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes); + + // Since BytesToBytesMap is append-only, we expect the total memory consumption to be + // monotonically increasing. More specifically, every time we allocate a new page it + // should increase by exactly the size of the page. In this regard, the memory usage + // at any given time is also the peak memory used. + long previousMemory = map.getTotalMemoryConsumption(); + long newMemory; + try { + for (long i = 0; i < numRecordsPerPage * 10; i++) { + final long[] value = new long[]{i}; + map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8).putNewKey( + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8); + newMemory = map.getTotalMemoryConsumption(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousMemory + pageSizeBytes, newMemory); + } else { + assertEquals(previousMemory, newMemory); + } + previousMemory = newMemory; + } + } finally { + map.free(); + } + } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 52fa8bcd57e79..c11949d57a0ea 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -247,4 +247,50 @@ public void testFillingPage() throws Exception { assertSpillFilesWereCleanedUp(); } + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + taskMemoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + pageSizeBytes); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. 
+ long previousPeakMemory = sorter.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + insertNumber(sorter, i); + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + sorter.spill(); + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + insertNumber(sorter, i); + } + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + sorter.freeMemory(); + } + } + } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index e942d6579b2fd..48f549575f4d1 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.ref.WeakReference import org.scalatest.Matchers +import org.scalatest.exceptions.TestFailedException +import org.apache.spark.scheduler._ -class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { +class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { + import InternalAccumulator._ implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = new AccumulableParam[mutable.Set[A], A] { @@ -155,4 +159,191 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(!Accumulators.originals.get(accId).isDefined) } + test("internal accumulators in TaskContext") { + val accums = InternalAccumulator.create() + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) + val internalMetricsToAccums = taskContext.internalMetricsToAccumulators + val collectedInternalAccums = taskContext.collectInternalAccumulators() + val collectedAccums = taskContext.collectAccumulators() + assert(internalMetricsToAccums.size > 0) + assert(internalMetricsToAccums.values.forall(_.isInternal)) + assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR)) + val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR) + assert(collectedInternalAccums.size === internalMetricsToAccums.size) + assert(collectedInternalAccums.size === collectedAccums.size) + assert(collectedInternalAccums.contains(testAccum.id)) + assert(collectedAccums.contains(testAccum.id)) + } + + test("internal accumulators in a stage") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Have each task add 1 to the internal accumulator + sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + iter + }.count() + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // The accumulator values should be merged in the stage + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + assert(stageAccum.value.toLong === 
numPartitions) + // The accumulator should be updated locally on each task + val taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + taskAccum.value.toLong + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + + test("internal accumulators in multiple stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Each stage creates its own set of internal accumulators so the + // values for the same metric should not be mixed up across stages + sc.parallelize(1 to 100, numPartitions) + .map { i => (i, i) } + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + iter + } + .reduceByKey { case (x, y) => x + y } + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10 + iter + } + .repartition(numPartitions * 2) + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 + iter + } + .count() + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val firstStageAccum = findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR) + val secondStageAccum = findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR) + val thirdStageAccum = findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR) + assert(firstStageAccum.value.toLong === numPartitions) + assert(secondStageAccum.value.toLong === numPartitions * 10) + assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + } + + test("internal accumulators in fully resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks + } + + test("internal accumulators in partially resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + } + + /** + * Return the accumulable info that matches the specified name. + */ + private def findAccumulableInfo( + accums: Iterable[AccumulableInfo], + name: String): AccumulableInfo = { + accums.find { a => a.name == name }.getOrElse { + throw new TestFailedException(s"internal accumulator '$name' not found", 0) + } + } + + /** + * Test whether internal accumulators are merged properly if some tasks fail. 
+ */ + private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { + val listener = new SaveInfoListener + val numPartitions = 10 + val numFailedPartitions = (0 until numPartitions).count(failCondition) + // This says use 1 core and retry tasks up to 2 times + sc = new SparkContext("local[1, 2]", "test") + sc.addSparkListener(listener) + sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => + val taskContext = TaskContext.get() + taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + // Fail the first attempts of a subset of the tasks + if (failCondition(i) && taskContext.attemptNumber() == 0) { + throw new Exception("Failing a task intentionally.") + } + iter + }.count() + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions + numFailedPartitions) + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + // We should not double count values in the merged accumulator + assert(stageAccum.value.toLong === numPartitions) + val taskAccumValues = taskInfos.flatMap { taskInfo => + if (!taskInfo.failed) { + // If a task succeeded, its update value should always be 1 + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + Some(taskAccum.value.toLong) + } else { + // If a task failed, we should not get its accumulator values + assert(taskInfo.accumulables.isEmpty) + None + } + } + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + +} + +private[spark] object AccumulatorSuite { + + /** + * Run one or more Spark jobs and verify that the peak execution memory accumulator + * is updated afterwards. + */ + def verifyPeakExecutionMemorySet( + sc: SparkContext, + testName: String)(testBody: => Unit): Unit = { + val listener = new SaveInfoListener + sc.addSparkListener(listener) + // Verify that the accumulator does not already exist + sc.parallelize(1 to 10).count() + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) + testBody + // Verify that peak execution memory is updated + val accum = listener.getCompletedStageInfos + .flatMap(_.accumulables.values) + .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .getOrElse { + throw new TestFailedException( + s"peak execution memory accumulator not set in '$testName'", 0) + } + assert(accum.value.toLong > 0) + } +} + +/** + * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs. 
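A hedged usage sketch for the verifyPeakExecutionMemorySet helper defined above (the job body is an arbitrary example; any job that exercises ExternalSorter or the unsafe operators should populate the accumulator):

    // Run a shuffle that sorts on the read side, then assert that the
    // PEAK_EXECUTION_MEMORY accumulator was recorded for the stage.
    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") {
      sc.parallelize(1 to 1000, 10).map(i => (i, i)).sortByKey().count()
    }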
+ */ +private class SaveInfoListener extends SparkListener { + private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] + private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] + + def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq + def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + completedStageInfos += stageCompleted.stageInfo + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + completedTaskInfos += taskEnd.taskInfo + } } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 618a5fb24710f..cb8bd04e496a7 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -21,7 +21,7 @@ import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.scalatest.mock.MockitoSugar -import org.apache.spark.executor.DataReadMethod +import org.apache.spark.executor.{DataReadMethod, TaskMetrics} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ @@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // Local computation should not persist the resulting value, so don't expect a put(). 
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, null, null, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 3e8816a4c65be..5f73ec8675966 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val tContext = TaskContext.empty() val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index b3ca150195a5f..f7e16af9d3a92 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,9 +19,11 @@ package org.apache.spark.scheduler import org.apache.spark.TaskContext -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) { +class FakeTask( + stageId: Int, + prefLocs: Seq[TaskLocation] = Nil) + extends Task[Int](stageId, 0, 0, Seq.empty) { override def runTask(context: TaskContext): Int = 0 - override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 383855caefa2f..f33324792495b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -25,7 +25,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. 
*/ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0, 0) { + extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 9201d1e1f328b..450ab7b9fe92b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -57,8 +57,9 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val task = new ResultTask[String, String](0, 0, - sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) + val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array) + val task = new ResultTask[String, String]( + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) intercept[RuntimeException] { task.run(0, 0, null) } @@ -66,7 +67,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 3abb99c4b2b54..f7cc4bb61d574 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. 
*/ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala index db718ecabbdb9..05b3afef5b839 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -138,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { shuffleHandle, reduceId, reduceId + 1, - new TaskContextImpl(0, 0, 0, 0, null, null), + TaskContext.empty(), blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index cf8bd8ae69625..828153bdbfc44 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -29,7 +29,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkFunSuite, TaskContextImpl} +import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener @@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0, null, null), + TaskContext.empty(), transfer, blockManager, blocksByAddress, @@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala new file mode 100644 index 0000000000000..98f9314f31dff --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} +import org.apache.spark.ui.scope.RDDOperationGraphListener + +class StagePageSuite extends SparkFunSuite with LocalSparkContext { + + test("peak execution memory only displayed if unsafe is enabled") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf().set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + val targetString = "peak execution memory" + assert(html.contains(targetString)) + // Disable unsafe and make sure it's not there + val conf2 = new SparkConf().set(unsafeConf, "false") + val html2 = renderStagePage(conf2).toString().toLowerCase + assert(!html2.contains(targetString)) + } + + /** + * Render a stage page started with the given conf and return the HTML. + * This also runs a dummy stage to populate the page with useful content. + */ + private def renderStagePage(conf: SparkConf): Seq[Node] = { + val jobListener = new JobProgressListener(conf) + val graphListener = new RDDOperationGraphListener(conf) + val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) + val request = mock(classOf[HttpServletRequest]) + when(tab.conf).thenReturn(conf) + when(tab.progressListener).thenReturn(jobListener) + when(tab.operationGraphListener).thenReturn(graphListener) + when(tab.appName).thenReturn("testing") + when(tab.headerTabs).thenReturn(Seq.empty) + when(request.getParameter("id")).thenReturn("0") + when(request.getParameter("attempt")).thenReturn("0") + val page = new StagePage(tab) + + // Simulate a stage in job progress listener + val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") + val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false) + jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) + jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markSuccessful() + jobListener.onTaskEnd( + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) + page.render(request) + } + +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 9c362f0de7076..12e9bafcc92c1 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -399,4 +399,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } + test("external aggregation updates peak execution memory") { + val conf = createSparkConf(loadDefaults = false) + 
.set("spark.shuffle.memoryFraction", "0.001") + .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter + sc = new SparkContext("local", "test", conf) + // No spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") { + sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + // With spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") { + sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + } + } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 986cd8623d145..bdb0f4d507a7e 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -692,7 +692,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { sortWithoutBreakingSortingContracts(createSparkConf(true, false)) } - def sortWithoutBreakingSortingContracts(conf: SparkConf) { + private def sortWithoutBreakingSortingContracts(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "sort") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) @@ -743,5 +743,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } sorter2.stop() - } + } + + test("sorting updates peak execution memory") { + val conf = createSparkConf(loadDefaults = false, kryo = false) + .set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") { + sc.parallelize(1 to 1000, 2).repartition(100).count() + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 5e4c6232c9471..193906d24790e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -106,6 +106,13 @@ void spill() throws IOException { sorter.spill(); } + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsage() { + return sorter.getPeakMemoryUsedBytes(); + } + private void cleanupResources() { sorter.freeMemory(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 9e2c9334a7bee..43d06ce9bdfa3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -208,6 +208,14 @@ public void close() { }; } + /** + * The memory used by this map's managed structures, in bytes. + * Note that this is also the peak memory used by this map, since the map is append-only. + */ + public long getMemoryUsage() { + return map.getTotalMemoryConsumption(); + } + /** * Free the memory associated with this map. This is idempotent and can be called multiple times. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index cd87b8deba0c2..bf4905dc1eef9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import java.io.IOException -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -263,11 +263,12 @@ case class GeneratedAggregate( assert(iter.hasNext, "There should be at least one row for this path") log.info("Using Unsafe-based aggregator") val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") + val taskContext = TaskContext.get() val aggregationMap = new UnsafeFixedWidthAggregationMap( newAggregationBuffer(EmptyRow), aggregationBufferSchema, groupKeySchema, - TaskContext.get.taskMemoryManager(), + taskContext.taskMemoryManager(), SparkEnv.get.shuffleMemoryManager, 1024 * 16, // initial capacity pageSizeBytes, @@ -284,6 +285,10 @@ case class GeneratedAggregate( updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) } + // Record memory used in the process + taskContext.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage) + new Iterator[InternalRow] { private[this] val mapIterator = aggregationMap.iterator() private[this] val resultProjection = resultProjectionBuilder() @@ -300,7 +305,7 @@ case class GeneratedAggregate( } else { // This is the last element in the iterator, so let's free the buffer. Before we do, // though, we need to make a defensive copy of the result so that we don't return an - // object that might contain dangling pointers to the freed memory + // object that might contain dangling pointers to the freed memory. 
val resultCopy = result.copy() aggregationMap.free() resultCopy diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 624efc1b1d734..e73e2523a777f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -70,7 +71,14 @@ case class BroadcastHashJoin( val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => - hashJoin(streamedIter, broadcastRelation.value) + val hashedRelation = broadcastRelation.value + hashedRelation match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + hashJoin(streamedIter, hashedRelation) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 309716a0efcc0..c35e439cc9deb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -75,6 +76,13 @@ case class BroadcastHashOuterJoin( val hashTable = broadcastRelation.value val keyGenerator = streamedKeyGenerator + hashTable match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index a60593911f94f..5bd06fbdca605 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -51,7 +52,14 @@ case class BroadcastLeftSemiJoinHash( val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value) + val hashedRelation = broadcastedRelation.value + hashedRelation match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + 
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + hashSemiJoin(streamIter, hashedRelation) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cc8bbfd2f8943..58b4236f7b5b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -183,8 +183,27 @@ private[joins] final class UnsafeHashedRelation( private[joins] def this() = this(null) // Needed for serialization // Use BytesToBytesMap in executor for better performance (it's created when deserialization) + // This is used in broadcast joins and distributed mode only @transient private[this] var binaryMap: BytesToBytesMap = _ + /** + * Return the size of the unsafe map on the executors. + * + * For broadcast joins, this hashed relation is bigger on the driver because it is + * represented as a Java hash map there. While serializing the map to the executors, + * however, we rehash the contents in a binary map to reduce the memory footprint on + * the executors. + * + * For non-broadcast joins or in local mode, return 0. + */ + def getUnsafeSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + 0 + } + } + override def get(key: InternalRow): Seq[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] @@ -214,7 +233,7 @@ private[joins] final class UnsafeHashedRelation( } } else { - // Use the JavaHashMap in Local mode or ShuffleHashJoin + // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin) hashTable.get(unsafeKey) } } @@ -316,6 +335,7 @@ private[joins] object UnsafeHashedRelation { keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { + // Use a Java hash table here because unsafe maps expect fixed size records val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) // Create a mapping of buildKeys -> rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 92cf328c76cbc..3192b6ebe9075 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -76,6 +77,11 @@ case class ExternalSort( val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r.copy(), null))) val baseIterator = sorter.iterator.map(_._1) + val context = TaskContext.get() + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) // TODO(marmbrus): The complex type signature below thwarts inference for no reason. 
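The three broadcast join operators above repeat the same pattern match when reporting the relation's executor-side size. Purely as an illustration (this helper is not part of the patch), the repetition could be captured with one method built on the APIs the diff already uses; as the new getUnsafeSize doc explains, only the executor-side BytesToBytesMap representation contributes, and the Java-HashMap-backed form reports 0:

    // Hypothetical helper, not in the patch: record a broadcast hashed
    // relation's unsafe-map size into the peak-execution-memory accumulator.
    // Assumes it is called inside a running task, so TaskContext.get() is
    // non-null; non-unsafe relations are ignored.
    private def reportBroadcastRelationSize(relation: HashedRelation): Unit =
      relation match {
        case unsafe: UnsafeHashedRelation =>
          TaskContext.get().internalMetricsToAccumulators(
            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
        case _ => // Java HashMap-backed relation: size not tracked here
      }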
CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) }, preservesPartitioning = true) @@ -137,7 +143,11 @@ case class TungstenSort( if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + val taskContext = TaskContext.get() + taskContext.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) + sortedIterator }, preservesPartitioning = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f1abae0720058..29dfcf2575227 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,6 +21,7 @@ import java.sql.Timestamp import org.scalatest.BeforeAndAfterAll +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException @@ -258,6 +259,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } } + private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + val hasGeneratedAgg = df.queryExecution.executedPlan + .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true } + .nonEmpty + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -267,26 +285,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") - def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { - val df = sql(sqlText) - // First, check if we have GeneratedAggregate. - var hasGeneratedAgg = false - df.queryExecution.executedPlan.foreach { - case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true - case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true - case _ => - } - if (!hasGeneratedAgg) { - fail( - s""" - |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. - |${df.queryExecution.simpleString} - """.stripMargin) - } - // Then, check results. - checkAnswer(df, expectedResults) - } - try { // Just to group rows. 
testCodeGen( @@ -1605,6 +1603,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } + test("aggregation with codegen updates peak execution memory") { + withSQLConf( + (SQLConf.CODEGEN_ENABLED.key, "true"), + (SQLConf.USE_SQL_AGGREGATE2.key, "false")) { + val sc = sqlContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") { + testCodeGen( + "SELECT key, count(value) FROM testData GROUP BY key", + (1 to 100).map(i => Row(i, 1))) + } + } + } + + test("external sorting updates peak execution memory") { + withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { + val sc = sqlContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") { + sortTest() + } + } + } + test("SPARK-9511: error with table starting with number") { val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) .toDF("num", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index c7949848513cf..88bce0e319f9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.BeforeAndAfterAll +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext @@ -59,6 +60,17 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { ) } + test("sorting updates peak execution memory") { + val sc = TestSQLContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), + (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child), + sortAnswers = false) + } + } + // Test sorting on different data types for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 7c591f6143b9e..ef827b0fe9b5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -69,7 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { taskAttemptId = Random.nextInt(10000), attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = null)) + metricsSystem = null, + internalAccumulators = Seq.empty)) try { f diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 0282b25b9dd50..601a5a07ad002 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -76,7 +76,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { taskAttemptId = 98456, 
attemptNumber = 0, taskMemoryManager = taskMemMgr, - metricsSystem = null)) + metricsSystem = null, + internalAccumulators = Seq.empty)) // Create the data converters val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala new file mode 100644 index 0000000000000..0554e11d252ba --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -0,0 +1,94 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +// TODO: uncomment the test here! It is currently failing due to +// bad interaction with org.apache.spark.sql.test.TestSQLContext. + +// scalastyle:off +//package org.apache.spark.sql.execution.joins +// +//import scala.reflect.ClassTag +// +//import org.scalatest.BeforeAndAfterAll +// +//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +//import org.apache.spark.sql.functions._ +//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} +// +///** +// * Test various broadcast join operators with unsafe enabled. +// * +// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs +// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode. +// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in +// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without +// * serializing the hashed relation, which does not happen in local mode. +// */ +//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { +// private var sc: SparkContext = null +// private var sqlContext: SQLContext = null +// +// /** +// * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. +// */ +// override def beforeAll(): Unit = { +// super.beforeAll() +// val conf = new SparkConf() +// .setMaster("local-cluster[2,1,1024]") +// .setAppName("testing") +// sc = new SparkContext(conf) +// sqlContext = new SQLContext(sc) +// sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) +// sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) +// } +// +// override def afterAll(): Unit = { +// sc.stop() +// sc = null +// sqlContext = null +// } +// +// /** +// * Test whether the specified broadcast join updates the peak execution memory accumulator. 
+// */ +// private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { +// AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { +// val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") +// val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") +// // Comparison at the end is for broadcast left semi join +// val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") +// val df3 = df1.join(broadcast(df2), joinExpression, joinType) +// val plan = df3.queryExecution.executedPlan +// assert(plan.collect { case p: T => p }.size === 1) +// plan.executeCollect() +// } +// } +// +// test("unsafe broadcast hash join updates peak execution memory") { +// testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") +// } +// +// test("unsafe broadcast hash outer join updates peak execution memory") { +// testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") +// } +// +// test("unsafe broadcast left semi join updates peak execution memory") { +// testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") +// } +// +//} +// scalastyle:on From b2e4b85d2db0320e9cbfaf5a5542f749f1f11cf4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 3 Aug 2015 14:51:15 -0700 Subject: [PATCH 109/340] Revert "[SPARK-9372] [SQL] Filter nulls in join keys" This reverts commit 687c8c37150f4c93f8e57d86bb56321a4891286b. --- .../catalyst/expressions/nullFunctions.scala | 48 +--- .../sql/catalyst/optimizer/Optimizer.scala | 64 ++--- .../plans/logical/basicOperators.scala | 32 +-- .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/MathFunctionsSuite.scala | 3 +- .../expressions/NullFunctionsSuite.scala | 49 +--- .../spark/sql/DataFrameNaFunctions.scala | 2 +- .../scala/org/apache/spark/sql/SQLConf.scala | 6 - .../org/apache/spark/sql/SQLContext.scala | 5 +- .../extendedOperatorOptimizations.scala | 160 ------------ .../optimizer/FilterNullsInJoinKeySuite.scala | 236 ------------------ 11 files changed, 37 insertions(+), 572 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index d58c4756938c7..287718fab7f0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -210,58 +210,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } -/** - * A predicate that is evaluated to be true if there are at least `n` null values. 
- */ -case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate { - override def nullable: Boolean = false - override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" - - private[this] val childrenArray = children.toArray - - override def eval(input: InternalRow): Boolean = { - var numNulls = 0 - var i = 0 - while (i < childrenArray.length && numNulls < n) { - val evalC = childrenArray(i).eval(input) - if (evalC == null) { - numNulls += 1 - } - i += 1 - } - numNulls >= n - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val numNulls = ctx.freshName("numNulls") - val code = children.map { e => - val eval = e.gen(ctx) - s""" - if ($numNulls < $n) { - ${eval.code} - if (${eval.isNull}) { - $numNulls += 1; - } - } - """ - }.mkString("\n") - s""" - int $numNulls = 0; - $code - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = $numNulls >= $n; - """ - } -} /** * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. */ -case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate { +case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})" + override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e4b6294dc7b8e..29d706dcb39a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -31,14 +31,8 @@ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] -class DefaultOptimizer extends Optimizer { - - /** - * Override to provide additional rules for the "Operator Optimizations" batch. - */ - val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil - - lazy val batches = +object DefaultOptimizer extends Optimizer { + val batches = // SubQueries are only needed for analysis and can be removed before execution. 
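The Optimizer hunk here (continued below) swaps the operator-optimization rule list back from the chained `rule1 :: rule2 :: ... : _*` form to plain varargs. For readers unfamiliar with the idiom, a small self-contained Scala illustration of why the pre-revert code needed the `::`/`: _*` construction (names are made up for the example, not taken from the patch):

    // A varargs parameter accepts either literal arguments or a sequence
    // expanded with `: _*`; the latter is required when some elements are
    // only known at runtime (here, the extended optimization rules).
    def batch(name: String, rules: String*): String =
      s"$name: ${rules.mkString(", ")}"

    val extendedRules = List("FilterNullsInJoinKey")
    batch("ops", "ColumnPruning", "CombineFilters")                          // fixed list
    batch("ops", ("ColumnPruning" :: "CombineFilters" :: extendedRules): _*) // fixed ++ dynamic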
Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: @@ -47,27 +41,26 @@ class DefaultOptimizer extends Optimizer { RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - SetOperationPushDown :: - SamplePushDown :: - PushPredicateThroughJoin :: - PushPredicateThroughProject :: - PushPredicateThroughGenerate :: - ColumnPruning :: + SetOperationPushDown, + SamplePushDown, + PushPredicateThroughJoin, + PushPredicateThroughProject, + PushPredicateThroughGenerate, + ColumnPruning, // Operator combine - ProjectCollapsing :: - CombineFilters :: - CombineLimits :: + ProjectCollapsing, + CombineFilters, + CombineLimits, // Constant folding - NullPropagation :: - OptimizeIn :: - ConstantFolding :: - LikeSimplification :: - BooleanSimplification :: - RemovePositive :: - SimplifyFilters :: - SimplifyCasts :: - SimplifyCaseConversionExpressions :: - extendedOperatorOptimizationRules.toList : _*) :: + NullPropagation, + OptimizeIn, + ConstantFolding, + LikeSimplification, + BooleanSimplification, + RemovePositive, + SimplifyFilters, + SimplifyCasts, + SimplifyCaseConversionExpressions) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -229,18 +222,12 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = { + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - // We need to preserve the nullability of c's output. - // So, we first create a outputMap and if a reference is from the output of - // c, we use that output attribute from c. - val outputMap = AttributeMap(c.output.map(attr => (attr, attr))) - val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq - Project(projectList, c) + Project(allReferences.filter(c.outputSet.contains).toSeq, c) } else { c } - } } /** @@ -530,13 +517,6 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { */ object CombineFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) => - // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and - // Not(AtLeastNNulls(1, e2)) - // (this is used to make sure there is no null in the result of e1 and e2 and - // they are added by FilterNullsInJoinKey optimziation rule), we can - // just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)). 
- Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild) case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 54b5f49772664..aacfc86ab0e49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -86,37 +86,7 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - /** - * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children - * have at least one null value and atLeastNNulls.children are all attributes. - */ - private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = { - val expressions = atLeastNNulls.children - val n = atLeastNNulls.n - if (n != 1) { - // AtLeastNNulls is not used to check if atLeastNNulls.children have - // at least one null value. - false - } else { - // AtLeastNNulls is used to check if atLeastNNulls.children have - // at least one null value. We need to make sure all atLeastNNulls.children - // are attributes. - expressions.forall(_.isInstanceOf[Attribute]) - } - } - - override def output: Seq[Attribute] = condition match { - case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) => - // The condition is used to make sure that there is no null value in - // a.children. - val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]]) - child.output.map { - case attr if nonNullableAttributes.contains(attr) => - attr.withNullability(false) - case attr => attr - } - case _ => child.output - } + override def output: Seq[Attribute] = child.output } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3e55151298741..a41185b4d8754 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,8 +31,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} trait ExpressionEvalHelper { self: SparkFunSuite => - protected val defaultOptimizer = new DefaultOptimizer - protected def create_row(values: Any*): InternalRow = { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } @@ -188,7 +186,7 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = defaultOptimizer.execute(plan) + val optimizedPlan = DefaultOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 649a5b44dc036..9fcb548af6bbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -148,7 +149,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = defaultOptimizer.execute(plan) + val optimizedPlan = DefaultOptimizer.execute(plan) checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index bf197124d8dbc..ace6c15dc8418 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("AtLeastNNonNullNans") { + test("AtLeastNNonNulls") { val mix = Seq(Literal("x"), Literal.create(null, StringType), Literal.create(null, DoubleType), @@ -96,46 +96,11 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(Float.MaxValue), Literal(false)) - checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow) - } - - test("AtLeastNNull") { - val mix = Seq(Literal("x"), - Literal.create(null, StringType), - Literal.create(null, DoubleType), - Literal(Double.NaN), - Literal(5f)) - - val nanOnly = Seq(Literal("x"), - Literal(10.0), - Literal(Float.NaN), - Literal(math.log(-2)), - Literal(Double.MaxValue)) - - val nullOnly = Seq(Literal("x"), - Literal.create(null, DoubleType), - Literal.create(null, DecimalType.USER_DEFAULT), - Literal(Float.MaxValue), - Literal(false)) - - checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow) + 
checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ea85f0657a726..a4fd4cf3b330b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. - val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name))) + val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) df.filter(Column(predicate)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 41ba1c7fe0574..f836122b3e0e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -413,10 +413,6 @@ private[spark] object SQLConf { "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) - val ADVANCED_SQL_OPTIMIZATION = booleanConf( - "spark.sql.advancedOptimization", - defaultValue = Some(true), isPublic = false) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -488,8 +484,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) - private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 31e2b508d485e..dbb2a09846548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,7 +41,6 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.optimizer.FilterNullsInJoinKey import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -157,9 +156,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } @transient - protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer { - override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil - } + protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala deleted file mode 100644 index 5a4dde5756964..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.optimizer - -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi} -import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * An optimization rule used to insert Filters to filter out rows whose equal join keys - * have at least one null values. For this kind of rows, they will not contribute to - * the join results of equal joins because a null does not equal another null. We can - * filter them out before shuffling join input rows. For example, we have two tables - * - * table1(key String, value Int) - * "str1"|1 - * null |2 - * - * table2(key String, value Int) - * "str1"|3 - * null |4 - * - * For a inner equal join, the result will be - * "str1"|1|"str1"|3 - * - * those two rows having null as the value of key will not contribute to the result. - * So, we can filter them out early. - * - * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false. - * - */ -case class FilterNullsInJoinKey( - sqlContext: SQLContext) - extends Rule[LogicalPlan] { - - /** - * Checks if we need to add a Filter operator. We will add a Filter when - * there is any attribute in `keys` whose corresponding attribute of `keys` - * in `plan.output` is still nullable (`nullable` field is `true`). - */ - private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = { - val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute])) - plan.output.filter(keyAttributeSet.contains).exists(_.nullable) - } - - /** - * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable. - */ - private def addFilterIfNecessary( - keys: Seq[Expression], - child: LogicalPlan): LogicalPlan = { - // We get all attributes from keys. - val attributes = keys.filter(_.isInstanceOf[Attribute]) - - // Then, we create a Filter to make sure these attributes are non-nullable. - val filter = - if (attributes.nonEmpty) { - Filter(Not(AtLeastNNulls(1, attributes)), child) - } else { - child - } - - filter - } - - /** - * We reconstruct the join condition. 
- */ - private def reconstructJoinCondition( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - otherPredicate: Option[Expression]): Expression = { - // First, we rewrite the equal condition part. When we extract those keys, - // we use splitConjunctivePredicates. So, it is safe to use .reduce(And). - val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map { - case (l, r) => EqualTo(l, r) - }.reduce(And) - - // Then, we add otherPredicate. When we extract those equal condition part, - // we use splitConjunctivePredicates. So, it is safe to use - // And(rewrittenEqualJoinCondition, c). - val rewrittenJoinCondition = otherPredicate - .map(c => And(rewrittenEqualJoinCondition, c)) - .getOrElse(rewrittenEqualJoinCondition) - - rewrittenJoinCondition - } - - def apply(plan: LogicalPlan): LogicalPlan = { - if (!sqlContext.conf.advancedSqlOptimizations) { - plan - } else { - plan transform { - case join: Join => join match { - // For a inner join having equal join condition part, we can add filters - // to both sides of the join operator. - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition)) - - // For a left outer join having equal join condition part, we can add a filter - // to the right side of the join operator. - case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) - if needsFilter(rightKeys, right) => - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition)) - - // For a right outer join having equal join condition part, we can add a filter - // to the left side of the join operator. - case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition)) - - // For a left semi join having equal join condition part, we can add filters - // to both sides of the join operator. - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition)) - - case other => other - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala deleted file mode 100644 index f98e4acafbf2c..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.optimizer - -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls} -import org.apache.spark.sql.catalyst.optimizer._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.test.TestSQLContext - -/** This is the test suite for FilterNullsInJoinKey optimization rule. */ -class FilterNullsInJoinKeySuite extends PlanTest { - - // We add predicate pushdown rules at here to make sure we do not - // create redundant Filter operators. Also, because the attribute ordering of - // the Project operator added by ColumnPruning may be not deterministic - // (the ordering may depend on the testing environment), - // we first construct the plan with expected Filter operators and then - // run the optimizer to add the the Project for column pruning. - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubQueries) :: - Batch("Operator Optimizations", FixedPoint(100), - FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite. - CombineFilters, - PushPredicateThroughProject, - BooleanSimplification, - PushPredicateThroughJoin, - PushPredicateThroughGenerate, - ColumnPruning, - ProjectCollapsing) :: Nil - } - - val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int) - - val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int) - - test("inner join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For an inner join, FilterNullsInJoinKey add filter to both side. 
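
[Editor's note, illustration only, not part of the patch] The FilterNullsInJoinKey rule being removed here performs, inside the optimizer, the same transformation a user could express by hand with the public DataFrame API: rows whose equi-join keys are null can never match in an inner or left-semi join, so they can be dropped before the shuffle. A minimal standalone sketch of that manual equivalent, assuming a Spark 1.x SQLContext named sqlContext is in scope:

    // Illustration only: the manual, DataFrame-level equivalent of FilterNullsInJoinKey.
    import sqlContext.implicits._

    val left  = Seq(("str1", 1), (null.asInstanceOf[String], 2)).toDF("key", "lval")
    val right = Seq(("str1", 3), (null.asInstanceOf[String], 4)).toDF("key", "rval")

    // Null keys cannot satisfy the equi-join condition, so filter them out before joining.
    val prunedLeft  = left.where(left("key").isNotNull)
    val prunedRight = right.where(right("key").isNotNull)

    // Returns the single row ("str1", 1, "str1", 3), the same result as joining the
    // unfiltered inputs, but with less data shuffled.
    prunedLeft.join(prunedRight, prunedLeft("key") === prunedRight("key")).show()

The rule applied this automatically, and only when the key attributes were still nullable in the child plan's output.
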
- val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(correctRight, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("make sure we do not keep adding filters") { - val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int) - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some('a === 'e)) - .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j)) - - val optimized = Optimize.execute(joinedPlan.analyze) - val conditions = optimized.collect { - case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs - } - - // Make sure that we have three Not(AtLeastNNulls(1, exprs)) for those three tables. - assert(conditions.length === 3) - - // Make sure attribtues are indeed a, b, e, i, and j. - assert( - conditions.flatMap(exprs => exprs).toSet === - joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet) - } - - test("inner join (partially optimized)") { - val joinCondition = - ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // We cannot extract attribute from the left join key. - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - leftRelation - .join(correctRight, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("inner join (not optimized)") { - val nonOptimizedJoinConditions = - Some('c - 100 + 'd === 'g + 1 - 'h) :: - Some('d > 'h || 'c === 'g) :: - Some('d + 'g + 'c > 'd - 'h) :: Nil - - nonOptimizedJoinConditions.foreach { joinCondition => - val joinedPlan = - leftRelation - .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition) - .select('a, 'c, 'f, 'd, 'h, 'g) - - val optimized = Optimize.execute(joinedPlan.analyze) - - comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) - } - } - - test("left outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, LeftOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a left outer join, FilterNullsInJoinKey add filter to the right side. - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - leftRelation - .join(correctRight, LeftOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("right outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, RightOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a right outer join, FilterNullsInJoinKey add filter to the left side. 
- val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(rightRelation, RightOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("full outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, FullOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - // FilterNullsInJoinKey does not fire for a full outer join. - val optimized = Optimize.execute(joinedPlan.analyze) - - comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) - } - - test("left semi join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, LeftSemi, Some(joinCondition)) - .select('a, 'd) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a left semi join, FilterNullsInJoinKey add filter to both side. - val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(correctRight, LeftSemi, Some(joinCondition)) - .select('a, 'd) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } -} From a2409d1c8e8ddec04b529ac6f6a12b5993f0eeda Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Mon, 3 Aug 2015 15:24:34 -0700 Subject: [PATCH 110/340] [SPARK-8064] [SQL] Build against Hive 1.2.1 Cherry picked the parts of the initial SPARK-8064 WiP branch needed to get sql/hive to compile against hive 1.2.1. That's the ASF release packaged under org.apache.hive, not any fork. Tests not run yet: that's what the machines are for Author: Steve Loughran Author: Cheng Lian Author: Michael Armbrust Author: Patrick Wendell Closes #7191 from steveloughran/stevel/feature/SPARK-8064-hive-1.2-002 and squashes the following commits: 7556d85 [Cheng Lian] Updates .q files and corresponding golden files ef4af62 [Steve Loughran] Merge commit '6a92bb09f46a04d6cd8c41bdba3ecb727ebb9030' into stevel/feature/SPARK-8064-hive-1.2-002 6a92bb0 [Cheng Lian] Overrides HiveConf time vars dcbb391 [Cheng Lian] Adds com.twitter:parquet-hadoop-bundle:1.6.0 for Hive Parquet SerDe 0bbe475 [Steve Loughran] SPARK-8064 scalastyle rejects the standard Hadoop ASF license header... fdf759b [Steve Loughran] SPARK-8064 classpath dependency suite to be in sync with shading in final (?) hive-exec spark 7a6c727 [Steve Loughran] SPARK-8064 switch to second staging repo of the spark-hive artifacts. This one has the protobuf-shaded hive-exec jar 376c003 [Steve Loughran] SPARK-8064 purge duplicate protobuf declaration 2c74697 [Steve Loughran] SPARK-8064 switch to the protobuf shaded hive-exec jar with tests to chase it down cc44020 [Steve Loughran] SPARK-8064 remove hadoop.version from runtest.py, as profile will fix that automatically. 6901fa9 [Steve Loughran] SPARK-8064 explicit protobuf import da310dc [Michael Armbrust] Fixes for Hive tests. 
a775a75 [Steve Loughran] SPARK-8064 cherry-pick-incomplete 7404f34 [Patrick Wendell] Add spark-hive staging repo 832c164 [Steve Loughran] SPARK-8064 try to supress compiler warnings on Complex.java pasted-thrift-code 312c0d4 [Steve Loughran] SPARK-8064 maven/ivy dependency purge; calcite declaration needed fa5ae7b [Steve Loughran] HIVE-8064 fix up hive-thriftserver dependencies and cut back on evicted references in the hive- packages; this keeps mvn and ivy resolution compatible, as the reconciliation policy is "by hand" c188048 [Steve Loughran] SPARK-8064 manage the Hive depencencies to that -things that aren't needed are excluded -sql/hive built with ivy is in sync with the maven reconciliation policy, rather than latest-first 4c8be8d [Cheng Lian] WIP: Partial fix for Thrift server and CLI tests 314eb3c [Steve Loughran] SPARK-8064 deprecation warning noise in one of the tests 17b0341 [Steve Loughran] SPARK-8064 IDE-hinted cleanups of Complex.java to reduce compiler warnings. It's all autogenerated code, so still ugly. d029b92 [Steve Loughran] SPARK-8064 rely on unescaping to have already taken place, so go straight to map of serde options 23eca7e [Steve Loughran] HIVE-8064 handle raw and escaped property tokens 54d9b06 [Steve Loughran] SPARK-8064 fix compilation regression surfacing from rebase 0b12d5f [Steve Loughran] HIVE-8064 use subset of hive complex type whose types deserialize fce73b6 [Steve Loughran] SPARK-8064 poms rely implicitly on the version of kryo chill provides fd3aa5d [Steve Loughran] SPARK-8064 version of hive to d/l from ivy is 1.2.1 dc73ece [Steve Loughran] SPARK-8064 revert to master's determinstic pushdown strategy d3c1e4a [Steve Loughran] SPARK-8064 purge UnionType 051cc21 [Steve Loughran] SPARK-8064 switch to an unshaded version of hive-exec-core, which must have been built with Kryo 2.21. This currently looks for a (locally built) version 1.2.1.spark 6684c60 [Steve Loughran] SPARK-8064 ignore RTE raised in blocking process.exitValue() call e6121e5 [Steve Loughran] SPARK-8064 address review comments aa43dc6 [Steve Loughran] SPARK-8064 more robust teardown on JavaMetastoreDatasourcesSuite f2bff01 [Steve Loughran] SPARK-8064 better takeup of asynchronously caught error text 8b1ef38 [Steve Loughran] SPARK-8064: on failures executing spark-submit in HiveSparkSubmitSuite, print command line and all logged output. 5a9ce6b [Steve Loughran] SPARK-8064 add explicit reason for kv split failure, rather than array OOB. *does not address the issue* 642b63a [Steve Loughran] SPARK-8064 reinstate something cut briefly during rebasing 97194dc [Steve Loughran] SPARK-8064 add extra logging to the YarnClusterSuite classpath test. There should be no reason why this is failing on jenkins, but as it is (and presumably its CP-related), improve the logging including any exception raised. 335357f [Steve Loughran] SPARK-8064 fail fast on thrive process spawning tests on exit codes and/or error string patterns seen in log. 3ed872f [Steve Loughran] SPARK-8064 rename field double to dbl bca55e5 [Steve Loughran] SPARK-8064 missed one of the `date` escapes 41d6479 [Steve Loughran] SPARK-8064 wrap tests with withTable() calls to avoid table-exists exceptions 2bc29a4 [Steve Loughran] SPARK-8064 ParquetSuites to escape `date` field name 1ab9bc4 [Steve Loughran] SPARK-8064 TestHive to use sered2.thrift.test.Complex bf3a249 [Steve Loughran] SPARK-8064: more resubmit than fix; tighten startup timeout to 60s. 
Still no obvious reason why jersey server code in spark-assembly isn't being picked up -it hasn't been shaded c829b8f [Steve Loughran] SPARK-8064: reinstate yarn-rm-server dependencies to hive-exec to ensure that jersey server is on classpath on hadoop versions < 2.6 0b0f738 [Steve Loughran] SPARK-8064: thrift server startup to fail fast on any exception in the main thread 13abaf1 [Steve Loughran] SPARK-8064 Hive compatibilty tests sin sync with explain/show output from Hive 1.2.1 d14d5ea [Steve Loughran] SPARK-8064: DATE is now a predicate; you can't use it as a field in select ops 26eef1c [Steve Loughran] SPARK-8064: HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT 3d64523 [Steve Loughran] SPARK-8064 improve diagns on uknown token; fix scalastyle failure d0360f6 [Steve Loughran] SPARK-8064: delicate merge in of the branch vanzin/hive-1.1 1126e5a [Steve Loughran] SPARK-8064: name of unrecognized file format wasn't appearing in error text 8cb09c4 [Steve Loughran] SPARK-8064: test resilience/assertion improvements. Independent of the rest of the work; can be backported to earlier versions dec12cb [Steve Loughran] SPARK-8064: when a CLI suite test fails include the full output text in the raised exception; this ensures that the stdout/stderr is included in jenkins reports, so it becomes possible to diagnose the cause. 463a670 [Steve Loughran] SPARK-8064 run-tests.py adds a hadoop-2.6 profile, and changes info messages to say "w/Hive 1.2.1" in console output 2531099 [Steve Loughran] SPARK-8064 successful attempt to get rid of pentaho as a transitive dependency of hive-exec 1d59100 [Steve Loughran] SPARK-8064 (unsuccessful) attempt to get rid of pentaho as a transitive dependency of hive-exec 75733fc [Steve Loughran] SPARK-8064 change thrift binary startup message to "Starting ThriftBinaryCLIService on port" 3ebc279 [Steve Loughran] SPARK-8064 move strings used to check for http/bin thrift services up into constants c80979d [Steve Loughran] SPARK-8064: SparkSQLCLIDriver drops remote mode support. CLISuite Tests pass instead of timing out: undetected regression? 
27e8370 [Steve Loughran] SPARK-8064 fix some style & IDE warnings 00e50d6 [Steve Loughran] SPARK-8064 stop excluding hive shims from dependency (commented out , for now) cb4f142 [Steve Loughran] SPARK-8054 cut pentaho dependency from calcite f7aa9cb [Steve Loughran] SPARK-8064 everything compiles with some commenting and moving of classes into a hive package 6c310b4 [Steve Loughran] SPARK-8064 subclass Hive ServerOptionsProcessor to make it public again f61a675 [Steve Loughran] SPARK-8064 thrift server switched to Hive 1.2.1, though it doesn't compile everywhere 4890b9d [Steve Loughran] SPARK-8064, build against Hive 1.2.1 --- core/pom.xml | 20 - dev/run-tests.py | 7 +- pom.xml | 654 +++++++++- sbin/spark-daemon.sh | 2 +- sql/catalyst/pom.xml | 1 - .../parquet/ParquetCompatibilityTest.scala | 13 +- sql/hive-thriftserver/pom.xml | 22 +- .../HiveServerServerOptionsProcessor.scala | 37 + .../hive/thriftserver/HiveThriftServer2.scala | 27 +- .../SparkExecuteStatementOperation.scala | 9 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 56 +- .../thriftserver/SparkSQLCLIService.scala | 13 +- .../thriftserver/SparkSQLSessionManager.scala | 11 +- .../sql/hive/thriftserver/CliSuite.scala | 75 +- .../HiveThriftServer2Suites.scala | 40 +- .../execution/HiveCompatibilitySuite.scala | 29 +- sql/hive/pom.xml | 92 +- .../apache/spark/sql/hive/HiveContext.scala | 114 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 5 +- .../org/apache/spark/sql/hive/HiveQl.scala | 97 +- .../org/apache/spark/sql/hive/HiveShim.scala | 15 +- .../sql/hive/client/ClientInterface.scala | 4 + .../spark/sql/hive/client/ClientWrapper.scala | 5 +- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../hive/client/IsolatedClientLoader.scala | 2 +- .../spark/sql/hive/client/package.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../hive/execution/ScriptTransformation.scala | 6 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 2 +- .../spark/sql/hive/hiveWriterContainers.scala | 2 +- .../spark/sql/hive/orc/OrcFilters.scala | 6 +- .../apache/spark/sql/hive/test/TestHive.scala | 36 +- .../apache/spark/sql/hive/test/Complex.java | 1139 +++++++++++++++++ .../hive/JavaMetastoreDataSourcesSuite.java | 6 +- ...perator-0-ee7f6a60a9792041b85b18cda56429bf | 1 + ..._string-1-db089ff46f9826c7883198adacdfad59 | 6 +- ...tar_by-5-41d474f5e6d7c61c36f74b4bec4e9e44} | 0 ...e_alter-3-2a91d52719cf4552ebeb867204552a26 | 2 +- ...b_table-4-b585371b624cbab2616a49f553a870a0 | 2 +- ...limited-1-2a91d52719cf4552ebeb867204552a26 | 2 +- ...e_serde-1-2a91d52719cf4552ebeb867204552a26 | 2 +- ...nctions-0-45a7762c39f1b0f26f076220e2764043 | 21 + ...perties-1-be4adb893c7f946ebd76a648ce3cc1ae | 2 +- ...date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 | 4 +- ...ate_sub-1-7efeb74367835ade71e5e42b22f8ced4 | 4 +- ...atediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 | 2 +- ...udf_day-0-c4c503756384ff1220222d84fd25e756 | 2 +- .../udf_day-1-87168babe1110fe4c38269843414ca4 | 11 +- ...ofmonth-0-7b2caf942528656555cf19c261a18502 | 2 +- ...ofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 | 11 +- .../udf_if-0-b7ffa85b5785cccef2af1b285348cc2c | 2 +- .../udf_if-1-30cf7f51f92b5684e556deff3032d49a | 2 +- .../udf_if-1-b7ffa85b5785cccef2af1b285348cc2c | 2 +- .../udf_if-2-30cf7f51f92b5684e556deff3032d49a | 2 +- ..._minute-0-9a38997c1f41f4afe00faa0abc471aee | 2 +- ..._minute-1-16995573ac4f4a1b047ad6ee88699e48 | 8 +- ...f_month-0-9a38997c1f41f4afe00faa0abc471aee | 2 +- ...f_month-1-16995573ac4f4a1b047ad6ee88699e48 | 8 +- ...udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 | 2 +- 
..._stddev-1-18e1d598820013453fad45852e1a303d | 2 +- ...union3-0-99620f72f0282904846a596ca5b3e46c} | 0 ...union3-2-90ca96ea59fd45cf0af8c020ae77c908} | 0 ...union3-3-72b149ccaef751bcfe55d5ca37cb5fd7} | 0 .../clientpositive/parenthesis_star_by.q | 2 +- .../src/test/queries/clientpositive/union3.q | 11 +- .../sql/hive/ClasspathDependenciesSuite.scala | 110 ++ .../spark/sql/hive/HiveSparkSubmitSuite.scala | 29 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 7 +- .../hive/ParquetHiveCompatibilitySuite.scala | 9 + .../spark/sql/hive/StatisticsSuite.scala | 3 + .../spark/sql/hive/client/VersionsSuite.scala | 6 +- .../sql/hive/execution/HiveQuerySuite.scala | 89 +- .../sql/hive/execution/PruningSuite.scala | 8 +- .../sql/hive/execution/SQLQuerySuite.scala | 140 +- .../hive/orc/OrcHadoopFsRelationSuite.scala | 8 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 3 +- .../apache/spark/sql/hive/parquetSuites.scala | 327 ++--- yarn/pom.xml | 10 - .../spark/deploy/yarn/YarnClusterSuite.scala | 24 +- 79 files changed, 2861 insertions(+), 584 deletions(-) create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java create mode 100644 sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf rename sql/hive/src/test/resources/golden/{parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 => parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44} (100%) rename sql/hive/src/test/resources/golden/{union3-0-6a8a35102de1b0b88c6721a704eb174d => union3-0-99620f72f0282904846a596ca5b3e46c} (100%) rename sql/hive/src/test/resources/golden/{union3-2-2a1dcd937f117f1955a169592b96d5f9 => union3-2-90ca96ea59fd45cf0af8c020ae77c908} (100%) rename sql/hive/src/test/resources/golden/{union3-3-8fc63f8edb2969a63cd4485f1867ba97 => union3-3-72b149ccaef751bcfe55d5ca37cb5fd7} (100%) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 202678779150b..0e53a79fd2235 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -46,30 +46,10 @@ com.twitter chill_${scala.binary.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - com.twitter chill-java - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - org.apache.hadoop diff --git a/dev/run-tests.py b/dev/run-tests.py index b6d181418f027..d1852b95bb292 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -273,6 +273,7 @@ def get_hadoop_profiles(hadoop_version): "hadoop2.0": ["-Phadoop-1", "-Dhadoop.version=2.0.0-mr1-cdh4.1.1"], "hadoop2.2": ["-Pyarn", "-Phadoop-2.2"], "hadoop2.3": ["-Pyarn", "-Phadoop-2.3", "-Dhadoop.version=2.3.0"], + "hadoop2.6": ["-Pyarn", "-Phadoop-2.6"], } if hadoop_version in sbt_maven_hadoop_profiles: @@ -289,7 +290,7 @@ def build_spark_maven(hadoop_version): mvn_goals = ["clean", "package", "-DskipTests"] profiles_and_goals = build_profiles + mvn_goals - print("[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments: ", + print("[info] Building Spark (w/Hive 1.2.1) using Maven with these arguments: ", " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) @@ -305,14 +306,14 @@ def build_spark_sbt(hadoop_version): "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals - print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", + print("[info] Building Spark (w/Hive 1.2.1) using SBT with these arguments: 
", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) def build_apache_spark(build_tool, hadoop_version): - """Will build Spark against Hive v0.13.1 given the passed in build tool (either `sbt` or + """Will build Spark against Hive v1.2.1 given the passed in build tool (either `sbt` or `maven`). Defaults to using `sbt`.""" set_title_and_block("Building Spark", "BLOCK_BUILD") diff --git a/pom.xml b/pom.xml index be0dac953abf7..a958cec867eae 100644 --- a/pom.xml +++ b/pom.xml @@ -134,11 +134,12 @@ 2.4.0 org.spark-project.hive - 0.13.1a + 1.2.1.spark - 0.13.1 + 1.2.1 10.10.1.1 1.7.0 + 1.6.0 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 @@ -151,7 +152,10 @@ 0.7.1 1.9.16 1.2.1 + 4.3.2 + + 3.1 3.4.1 2.10.4 2.10 @@ -161,6 +165,23 @@ 2.4.4 1.1.1.7 1.1.2 + 1.2.0-incubating + 1.10 + + 2.6 + + 3.3.2 + 3.2.10 + 2.7.8 + 1.9 + 2.5 + 3.5.2 + 1.3.9 + 0.9.2 + + + false + ${java.home} spring-releases Spring Release Repository https://repo.spring.io/libs-release - true + false false @@ -402,12 +431,17 @@ org.apache.commons commons-lang3 - 3.3.2 + ${commons-lang3.version} + + + org.apache.commons + commons-lang + ${commons-lang2.version} commons-codec commons-codec - 1.10 + ${commons-codec.version} org.apache.commons @@ -422,7 +456,12 @@ com.google.code.findbugs jsr305 - 1.3.9 + ${jsr305.version} + + + commons-httpclient + commons-httpclient + ${httpclient.classic.version} org.apache.httpcomponents @@ -439,6 +478,16 @@ selenium-java 2.42.2 test + + + com.google.guava + guava + + + io.netty + netty + + @@ -624,15 +673,26 @@ com.sun.jersey jersey-server - 1.9 + ${jersey.version} ${hadoop.deps.scope} com.sun.jersey jersey-core - 1.9 + ${jersey.version} ${hadoop.deps.scope} + + com.sun.jersey + jersey-json + ${jersey.version} + + + stax + stax-api + + + org.scala-lang scala-compiler @@ -1022,58 +1082,499 @@ hive-beeline ${hive.version} ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-jdbc + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.thrift + libthrift + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + ${hive.group} hive-cli ${hive.version} ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-jdbc + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.thrift + libthrift + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + ${hive.group} - hive-exec + hive-common ${hive.version} ${hive.deps.scope} + + ${hive.group} + hive-shims + + + org.apache.ant + ant + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + commons-logging commons-logging + + + + + ${hive.group} + hive-exec + + ${hive.version} + ${hive.deps.scope} + + + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-shims + + + ${hive.group} + hive-ant + + + + ${hive.group} + spark-client + + + + + ant + ant + + + org.apache.ant + ant + com.esotericsoftware.kryo kryo + + commons-codec + commons-codec + + + commons-httpclient + commons-httpclient + org.apache.avro avro-mapred + + + org.apache.calcite + calcite-core + + + org.apache.curator + apache-curator + + + org.apache.curator + curator-client + + + org.apache.curator + curator-framework + + + 
org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + ${hive.group} hive-jdbc ${hive.version} - ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-common + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.apache.curator + curator-framework + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + + ${hive.group} hive-metastore ${hive.version} ${hive.deps.scope} + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-shims + + + org.apache.thrift + libfb303 + + + org.apache.thrift + libthrift + + + com.google.guava + guava + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + ${hive.group} hive-serde ${hive.version} ${hive.deps.scope} + + ${hive.group} + hive-common + + + ${hive.group} + hive-shims + + + commons-codec + commons-codec + + + com.google.code.findbugs + jsr305 + + + org.apache.avro + avro + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + commons-logging commons-logging + + + + + ${hive.group} + hive-service + ${hive.version} + ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-shims + + + commons-codec + commons-codec + + + org.apache.curator + curator-framework + + + org.apache.curator + curator-recipes + + + org.apache.thrift + libfb303 + + + org.apache.thrift + libthrift + + + + + + + ${hive.group} + hive-shims + ${hive.version} + ${hive.deps.scope} + + + com.google.guava + guava + + + org.apache.hadoop + hadoop-yarn-server-resourcemanager + + + org.apache.curator + curator-framework + + + org.apache.thrift + libthrift + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + commons-logging - commons-logging-api + commons-logging @@ -1095,6 +1596,12 @@ ${parquet.version} ${parquet.test.deps.scope} + + com.twitter + parquet-hadoop-bundle + ${hive.parquet.version} + runtime + org.apache.flume flume-ng-core @@ -1135,6 +1642,125 @@ + + org.apache.calcite + calcite-core + ${calcite.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.google.guava + guava + + + com.google.code.findbugs + jsr305 + + + org.codehaus.janino + janino + + + + org.hsqldb + hsqldb + + + org.pentaho + pentaho-aggdesigner-algorithm + + + + + org.apache.calcite + calcite-avatica + ${calcite.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + + + org.codehaus.janino + janino + ${janino.version} + + + joda-time + joda-time + ${joda.version} + + + org.jodd + jodd-core + ${jodd.version} + + + org.datanucleus + datanucleus-core + ${datanucleus-core.version} + + + org.apache.thrift + libthrift + ${libthrift.version} + + + org.apache.httpcomponents + 
httpclient + + + org.apache.httpcomponents + httpcore + + + org.slf4j + slf4j-api + + + + + org.apache.thrift + libfb303 + ${libthrift.version} + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.slf4j + slf4j-api + + + @@ -1271,6 +1897,8 @@ false true true + + src false @@ -1305,6 +1933,8 @@ false true true + + __not_used__ diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index de762acc8fa0e..0fbe795822fbf 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -29,7 +29,7 @@ # SPARK_NICENESS The scheduling priority for daemons. Defaults to 0. ## -usage="Usage: spark-daemon.sh [--config ] (start|stop|status) " +usage="Usage: spark-daemon.sh [--config ] (start|stop|submit|status) " # if no args specified, show usage if [ $# -le 1 ]; then diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index f4b1cc3a4ffe7..75ab575dfde83 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -66,7 +66,6 @@ org.codehaus.janino janino - 2.7.8 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala index b4cdfd9e98f6f..57478931cd509 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala @@ -31,6 +31,14 @@ import org.apache.spark.util.Utils abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { protected var parquetStore: File = _ + /** + * Optional path to a staging subdirectory which may be created during query processing + * (Hive does this). + * Parquet files under this directory will be ignored in [[readParquetSchema()]] + * @return an optional staging directory to ignore when scanning for parquet files. 
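
[Editor's note, illustration only, not part of the patch] The stagingDir hook documented above (declared just below) exists so that readParquetSchema can skip Hive's staging subdirectories, in addition to underscore-prefixed metadata files, when it lists a directory. A small standalone sketch of that listing filter using only the Hadoop FileSystem API; the helper name listDataFiles is hypothetical:

    // Illustration only: list the "real" data files in a directory, skipping
    // underscore-prefixed entries (e.g. _metadata) and an optional staging subdirectory.
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.{FileStatus, Path}

    def listDataFiles(dir: String, conf: Configuration, stagingDir: Option[String]): Seq[FileStatus] = {
      val path = new Path(dir)
      val fs = path.getFileSystem(conf)
      fs.listStatus(path).toSeq.filterNot { status =>
        val name = status.getPath.getName
        name.startsWith("_") || stagingDir.exists(s => name.startsWith(s))
      }
    }
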
+ */ + protected def stagingDir: Option[String] = None + override protected def beforeAll(): Unit = { parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") parquetStore.delete() @@ -43,7 +51,10 @@ abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with def readParquetSchema(path: String): MessageType = { val fsPath = new Path(path) val fs = fsPath.getFileSystem(configuration) - val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot(_.getPath.getName.startsWith("_")) + val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot { status => + status.getPath.getName.startsWith("_") || + stagingDir.map(status.getPath.getName.startsWith).getOrElse(false) + } val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) footers.head.getParquetMetadata.getFileMetaData.getSchema } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 73e6ccdb1eaf8..2dfbcb2425a37 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -60,21 +60,31 @@ ${hive.group} hive-jdbc + + ${hive.group} + hive-service + ${hive.group} hive-beeline + + com.sun.jersey + jersey-core + + + com.sun.jersey + jersey-json + + + com.sun.jersey + jersey-server + org.seleniumhq.selenium selenium-java test - - - io.netty - netty - - diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala new file mode 100644 index 0000000000000..2228f651e2387 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hive.service.server + +import org.apache.hive.service.server.HiveServer2.{StartOptionExecutor, ServerOptionsProcessor} + +/** + * Class to upgrade a package-private class to public, and + * implement a `process()` operation consistent with + * the behavior of older Hive versions + * @param serverName name of the hive server + */ +private[apache] class HiveServerServerOptionsProcessor(serverName: String) + extends ServerOptionsProcessor(serverName) { + + def process(args: Array[String]): Boolean = { + // A parse failure automatically triggers a system exit + val response = super.parse(args) + val executor = response.getServerOptionsExecutor() + // return true if the parsed option was to start the service + executor.isInstanceOf[StartOptionExecutor] + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index b7db80d93f852..9c047347cb58d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.hive.thriftserver +import java.util.Locale +import java.util.concurrent.atomic.AtomicBoolean + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -24,7 +27,7 @@ import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} -import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} +import org.apache.hive.service.server.{HiveServerServerOptionsProcessor, HiveServer2} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} @@ -65,7 +68,7 @@ object HiveThriftServer2 extends Logging { } def main(args: Array[String]) { - val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + val optionsProcessor = new HiveServerServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { System.exit(-1) } @@ -241,9 +244,12 @@ object HiveThriftServer2 extends Logging { private[hive] class HiveThriftServer2(hiveContext: HiveContext) extends HiveServer2 with ReflectedCompositeService { + // state is tracked internally so that the server only attempts to shut down if it successfully + // started, and then once only. 
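
[Editor's note, illustration only, not part of the patch] The guard described in the comment above (the started flag declared just below) is a small but important pattern: stop() must only do real work if start() actually succeeded, and must do it at most once even if several shutdown paths fire. A generic standalone sketch of the same guard; the class name GuardedService is hypothetical:

    // Illustration only: stop a service only if it was successfully started, and only once.
    import java.util.concurrent.atomic.AtomicBoolean

    class GuardedService {
      private val started = new AtomicBoolean(false)

      def start(): Unit = {
        // ... bring the service up; only mark it started if that succeeds ...
        started.set(true)
      }

      def stop(): Unit = {
        // getAndSet(false) makes shutdown idempotent: the first caller that observes the
        // flag as true performs the teardown; later or premature calls become no-ops.
        if (started.getAndSet(false)) {
          // ... tear the service down ...
        }
      }
    }
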
+ private val started = new AtomicBoolean(false) override def init(hiveConf: HiveConf) { - val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + val sparkSqlCliService = new SparkSQLCLIService(this, hiveContext) setSuperField(this, "cliService", sparkSqlCliService) addService(sparkSqlCliService) @@ -259,8 +265,19 @@ private[hive] class HiveThriftServer2(hiveContext: HiveContext) } private def isHTTPTransportMode(hiveConf: HiveConf): Boolean = { - val transportMode: String = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) - transportMode.equalsIgnoreCase("http") + val transportMode = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) + transportMode.toLowerCase(Locale.ENGLISH).equals("http") + } + + + override def start(): Unit = { + super.start() + started.set(true) } + override def stop(): Unit = { + if (started.getAndSet(false)) { + super.stop() + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e8758887ff3a2..833bf62d47d07 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -32,8 +32,7 @@ import org.apache.hive.service.cli._ import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.shims.ShimLoader -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.hive.shims.Utils import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession @@ -146,7 +145,7 @@ private[hive] class SparkExecuteStatementOperation( } else { val parentSessionState = SessionState.get() val hiveConf = getConfigForOperation() - val sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + val sparkServiceUGI = Utils.getUGI() val sessionHive = getCurrentHive() val currentSqlSession = hiveContext.currentSession @@ -174,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation( } try { - ShimLoader.getHadoopShims().doAs(sparkServiceUGI, doAsAction) + sparkServiceUGI.doAs(doAsAction) } catch { case e: Exception => setOperationException(new HiveSQLException(e)) @@ -201,7 +200,7 @@ private[hive] class SparkExecuteStatementOperation( } } - private def runInternal(): Unit = { + override def runInternal(): Unit = { statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index f66a17b20915f..d3886142b388d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.hive.thriftserver import scala.collection.JavaConversions._ import java.io._ -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, Locale} -import jline.{ConsoleReader, History} +import 
jline.console.ConsoleReader +import jline.console.history.FileHistory import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory @@ -40,6 +41,10 @@ import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils +/** + * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver + * has dropped its support. + */ private[hive] object SparkSQLCLIDriver extends Logging { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') @@ -111,16 +116,9 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Clean up after we exit Utils.addShutdownHook { () => SparkSQLEnv.stop() } + val remoteMode = isRemoteMode(sessionState) // "-h" option has been passed, so connect to Hive thrift server. - if (sessionState.getHost != null) { - sessionState.connect() - if (sessionState.isRemoteMode) { - prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt - continuedPrompt = "".padTo(prompt.length, ' ') - } - } - - if (!sessionState.isRemoteMode) { + if (!remoteMode) { // Hadoop-20 and above - we need to augment classpath using hiveconf // components. // See also: code in ExecDriver.java @@ -131,6 +129,9 @@ private[hive] object SparkSQLCLIDriver extends Logging { } conf.setClassLoader(loader) Thread.currentThread().setContextClassLoader(loader) + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") } val cli = new SparkSQLCLIDriver @@ -171,14 +172,14 @@ private[hive] object SparkSQLCLIDriver extends Logging { val reader = new ConsoleReader() reader.setBellEnabled(false) // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) - CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + CliDriver.getCommandCompleter.foreach((e) => reader.addCompleter(e)) val historyDirectory = System.getProperty("user.home") try { if (new File(historyDirectory).exists()) { val historyFile = historyDirectory + File.separator + ".hivehistory" - reader.setHistory(new History(new File(historyFile))) + reader.setHistory(new FileHistory(new File(historyFile))) } else { logWarning("WARNING: Directory for Hive history file: " + historyDirectory + " does not exist. History will not be available during this session.") @@ -190,10 +191,14 @@ private[hive] object SparkSQLCLIDriver extends Logging { logWarning(e.getMessage) } + // TODO: missing +/* val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") clientTransportTSocketField.setAccessible(true) transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] +*/ + transport = null var ret = 0 var prefix = "" @@ -230,6 +235,13 @@ private[hive] object SparkSQLCLIDriver extends Logging { System.exit(ret) } + + + def isRemoteMode(state: CliSessionState): Boolean = { + // sessionState.isRemoteMode + state.isHiveServerQuery + } + } private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { @@ -239,25 +251,33 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) + private val isRemoteMode = { + SparkSQLCLIDriver.isRemoteMode(sessionState) + } + private val conf: Configuration = if (sessionState != null) sessionState.getConf else new Configuration() // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver // because the Hive unit tests do not go through the main() code path. 
- if (!sessionState.isRemoteMode) { + if (!isRemoteMode) { SparkSQLEnv.init() + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") } override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() + val cmd_lower = cmd_trimmed.toLowerCase(Locale.ENGLISH) val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() - if (cmd_trimmed.toLowerCase.equals("quit") || - cmd_trimmed.toLowerCase.equals("exit") || - tokens(0).equalsIgnoreCase("source") || + if (cmd_lower.equals("quit") || + cmd_lower.equals("exit") || + tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || cmd_trimmed.startsWith("!") || tokens(0).toLowerCase.equals("list") || - sessionState.isRemoteMode) { + isRemoteMode) { val start = System.currentTimeMillis() super.processCmd(cmd) val end = System.currentTimeMillis() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 41f647d5f8c5a..644165acf70a7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -23,11 +23,12 @@ import javax.security.auth.login.LoginException import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.hive.shims.Utils import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ +import org.apache.hive.service.server.HiveServer2 import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext @@ -35,22 +36,22 @@ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import scala.collection.JavaConversions._ -private[hive] class SparkSQLCLIService(hiveContext: HiveContext) - extends CLIService +private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, hiveContext: HiveContext) + extends CLIService(hiveServer) with ReflectedCompositeService { override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) - val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, hiveContext) setSuperField(this, "sessionManager", sparkSqlSessionManager) addService(sparkSqlSessionManager) var sparkServiceUGI: UserGroupInformation = null - if (ShimLoader.getHadoopShims.isSecurityEnabled) { + if (UserGroupInformation.isSecurityEnabled) { try { HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + sparkServiceUGI = Utils.getUGI() setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 2d5ee68002286..92ac0ec3fca29 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -25,14 +25,15 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.SessionHandle import org.apache.hive.service.cli.session.SessionManager import org.apache.hive.service.cli.thrift.TProtocolVersion +import org.apache.hive.service.server.HiveServer2 import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager +private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: HiveContext) + extends SessionManager(hiveServer) with ReflectedCompositeService { private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) @@ -55,12 +56,14 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) protocol: TProtocolVersion, username: String, passwd: String, + ipAddress: String, sessionConf: java.util.Map[String, String], withImpersonation: Boolean, delegationToken: String): SessionHandle = { hiveContext.openSession() - val sessionHandle = super.openSession( - protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val sessionHandle = + super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation, + delegationToken) val session = super.getSession(sessionHandle) HiveThriftServer2.listener.onSessionCreated( session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index df80d04b40801..121b3e077f71f 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} +import scala.util.Failure import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.BeforeAndAfter @@ -37,31 +38,46 @@ import org.apache.spark.util.Utils class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() + val scratchDirPath = Utils.createTempDir() before { - warehousePath.delete() - metastorePath.delete() + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() } after { - warehousePath.delete() - metastorePath.delete() + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() } + /** + * Run a CLI operation and expect all the queries and expected answers to be returned. + * @param timeout maximum time for the commands to complete + * @param extraArgs any extra arguments + * @param errorResponses a sequence of strings whose presence in the stdout of the forked process + * is taken as an immediate error condition. That is: if a line beginning + * with one of these strings is found, fail the test immediately. 
+ * The default value is `Seq("Error:")` + * + * @param queriesAndExpectedAnswers one or more tupes of query + answer + */ def runCliWithin( timeout: FiniteDuration, - extraArgs: Seq[String] = Seq.empty)( + extraArgs: Seq[String] = Seq.empty, + errorResponses: Seq[String] = Seq("Error:"))( queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) - val command = { + val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" s"""$cliScript | --master local | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.SCRATCHDIR}=$scratchDirPath """.stripMargin.split("\\s+").toSeq ++ extraArgs } @@ -81,6 +97,12 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { if (next == expectedAnswers.size) { foundAllExpectedAnswers.trySuccess(()) } + } else { + errorResponses.foreach( r => { + if (line.startsWith(r)) { + foundAllExpectedAnswers.tryFailure( + new RuntimeException(s"Failed with error line '$line'")) + }}) } } @@ -88,16 +110,44 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val process = (Process(command, None) #< queryStream).run( ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) + // catch the output value + class exitCodeCatcher extends Runnable { + var exitValue = 0 + + override def run(): Unit = { + try { + exitValue = process.exitValue() + } catch { + case rte: RuntimeException => + // ignored as it will get triggered when the process gets destroyed + logDebug("Ignoring exception while waiting for exit code", rte) + } + if (exitValue != 0) { + // process exited: fail fast + foundAllExpectedAnswers.tryFailure( + new RuntimeException(s"Failed with exit code $exitValue")) + } + } + } + // spin off the code catche thread. No attempt is made to kill this + // as it will exit once the launched process terminates. + val codeCatcherThread = new Thread(new exitCodeCatcher()) + codeCatcherThread.start() + try { - Await.result(foundAllExpectedAnswers.future, timeout) + Await.ready(foundAllExpectedAnswers.future, timeout) + foundAllExpectedAnswers.future.value match { + case Some(Failure(t)) => throw t + case _ => + } } catch { case cause: Throwable => - logError( + val message = s""" |======================= |CliSuite failure output |======================= |Spark SQL CLI command line: ${command.mkString(" ")} - | + |Exception: $cause |Executed query $next "${queries(next)}", |But failed to capture expected output "${expectedAnswers(next)}" within $timeout. 
| @@ -105,8 +155,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { |=========================== |End CliSuite failure output |=========================== - """.stripMargin, cause) - throw cause + """.stripMargin + logError(message, cause) + fail(message, cause) } finally { process.destroy() } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 39b31523e07cb..8374629b5d45a 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL -import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable.ArrayBuffer @@ -492,7 +491,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl new File(s"$tempLog4jConf/log4j.properties"), UTF_8) - tempLog4jConf + File.pathSeparator + sys.props("java.class.path") + tempLog4jConf // + File.pathSeparator + sys.props("java.class.path") } s"""$startScript @@ -508,6 +507,20 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl """.stripMargin.split("\\s+").toSeq } + /** + * String to scan for when looking for the the thrift binary endpoint running. + * This can change across Hive versions. + */ + val THRIFT_BINARY_SERVICE_LIVE = "Starting ThriftBinaryCLIService on port" + + /** + * String to scan for when looking for the the thrift HTTP endpoint running. + * This can change across Hive versions. + */ + val THRIFT_HTTP_SERVICE_LIVE = "Started ThriftHttpCLIService in http" + + val SERVER_STARTUP_TIMEOUT = 1.minute + private def startThriftServer(port: Int, attempt: Int) = { warehousePath = Utils.createTempDir() warehousePath.delete() @@ -545,23 +558,26 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl // Ensures that the following "tail" command won't fail. logPath.createNewFile() + val successLines = Seq(THRIFT_BINARY_SERVICE_LIVE, THRIFT_HTTP_SERVICE_LIVE) + val failureLines = Seq("HiveServer2 is stopped", "Exception in thread", "Error:") logTailingProcess = // Using "-n +0" to make sure all lines in the log file are checked. Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger( (line: String) => { diagnosisBuffer += line - - if (line.contains("ThriftBinaryCLIService listening on") || - line.contains("Started ThriftHttpCLIService in http")) { - serverStarted.trySuccess(()) - } else if (line.contains("HiveServer2 is stopped")) { - // This log line appears when the server fails to start and terminates gracefully (e.g. - // because of port contention). 
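
[Editor's note, illustration only, not part of the patch] This hunk replaces the two hard-coded startup checks with the successLines / failureLines markers defined a few lines above; the tailing callback (continued just below) scans every log line against both sets and completes a Promise either way. A condensed standalone sketch of that pattern, with the marker strings mirroring those used in the suite:

    // Illustration only: complete a Promise from log lines, succeeding on the first
    // "service is live" marker and failing fast on any known error marker.
    import scala.concurrent.Promise

    val serverStarted = Promise[Unit]()
    val successLines = Seq(
      "Starting ThriftBinaryCLIService on port",
      "Started ThriftHttpCLIService in http")
    val failureLines = Seq("HiveServer2 is stopped", "Exception in thread", "Error:")

    def onLogLine(line: String): Unit = {
      if (successLines.exists(m => line.contains(m))) {
        serverStarted.trySuccess(())
      }
      if (failureLines.exists(m => line.contains(m))) {
        serverStarted.tryFailure(new RuntimeException(s"Failed with output '$line'"))
      }
    }

    // The suite wires this into a ProcessLogger on the tailed log and then blocks with
    // Await.result(serverStarted.future, SERVER_STARTUP_TIMEOUT).
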
- serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2")) - } + successLines.foreach(r => { + if (line.contains(r)) { + serverStarted.trySuccess(()) + } + }) + failureLines.foreach(r => { + if (line.contains(r)) { + serverStarted.tryFailure(new RuntimeException(s"Failed with output '$line'")) + } + }) })) - Await.result(serverStarted.future, 2.minute) + Await.result(serverStarted.future, SERVER_STARTUP_TIMEOUT) } private def stopThriftServer(): Unit = { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 53d5b22b527b2..c46a4a4b0be54 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -267,7 +267,34 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "date_udf", // Unlike Hive, we do support log base in (0, 1.0], therefore disable this - "udf7" + "udf7", + + // Trivial changes to DDL output + "compute_stats_empty_table", + "compute_stats_long", + "create_view_translate", + "show_create_table_serde", + "show_tblproperties", + + // Odd changes to output + "merge4", + + // Thift is broken... + "inputddl8", + + // Hive changed ordering of ddl: + "varchar_union1", + + // Parser changes in Hive 1.2 + "input25", + "input26", + + // Uses invalid table name + "innerjoin", + + // classpath problems + "compute_stats.*", + "udf_bitmap_.*" ) /** diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index b00f320318be0..be1607476e254 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -36,6 +36,11 @@ + + + com.twitter + parquet-hadoop-bundle + org.apache.spark spark-core_${scala.binary.version} @@ -53,32 +58,42 @@ spark-sql_${scala.binary.version} ${project.version} + - org.codehaus.jackson - jackson-mapper-asl + ${hive.group} + hive-exec + ${hive.group} - hive-serde + hive-metastore + org.apache.avro @@ -91,6 +106,55 @@ avro-mapred ${avro.mapred.classifier} + + commons-httpclient + commons-httpclient + + + org.apache.calcite + calcite-avatica + + + org.apache.calcite + calcite-core + + + org.apache.httpcomponents + httpclient + + + org.codehaus.jackson + jackson-mapper-asl + + + + commons-codec + commons-codec + + + joda-time + joda-time + + + org.jodd + jodd-core + + + com.google.code.findbugs + jsr305 + + + org.datanucleus + datanucleus-core + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 110f51a305861..567d7fa12ff14 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -20,15 +20,18 @@ package org.apache.spark.sql.hive import java.io.File import java.net.{URL, URLClassLoader} import java.sql.Timestamp +import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import scala.language.implicitConversions +import scala.concurrent.duration._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.HiveDecimal import 
org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState @@ -164,6 +167,16 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { } SessionState.setCurrentSessionState(executionHive.state) + /** + * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. + * - allow SQL11 keywords to be used as identifiers + */ + private[sql] def defaultOverides() = { + setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") + } + + defaultOverides() + /** * The copy of the Hive client that is used to retrieve metadata from the Hive MetaStore. * The version of the Hive client that is used here must match the metastore that is configured @@ -252,6 +265,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { } protected[sql] override def parseSql(sql: String): LogicalPlan = { + var state = SessionState.get() + if (state == null) { + SessionState.setCurrentSessionState(tlSession.get().asInstanceOf[SQLSession].sessionState) + } super.parseSql(substitutor.substitute(hiveconf, sql)) } @@ -298,10 +315,21 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // Can we use fs.getContentSummary in future? // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use // countFileSize to count the table size. + val stagingDir = metadataHive.getConf(HiveConf.ConfVars.STAGINGDIR.varname, + HiveConf.ConfVars.STAGINGDIR.defaultStrVal) + def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) val size = if (fileStatus.isDir) { - fs.listStatus(path).map(status => calculateTableSize(fs, status.getPath)).sum + fs.listStatus(path) + .map { status => + if (!status.getPath().getName().startsWith(stagingDir)) { + calculateTableSize(fs, status.getPath) + } else { + 0L + } + } + .sum } else { fileStatus.getLen } @@ -398,7 +426,58 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { } /** Overridden by child classes that need to set configuration before the client init. */ - protected def configure(): Map[String, String] = Map.empty + protected def configure(): Map[String, String] = { + // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch + // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- + // compatibility when users are trying to connecting to a Hive metastore of lower version, + // because these options are expected to be integral values in lower versions of Hive. + // + // Here we enumerate all time `ConfVar`s and convert their values to numeric strings according + // to their output time units. 
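+ //
+ // For example (illustrative value only): if hive.metastore.client.socket.timeout is set to
+ // "600s", then hiveconf.getTimeVar(ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT, TimeUnit.SECONDS)
+ // returns 600, and the map below forwards it to the metastore client as the plain string "600".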
+ Seq( + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_LIFETIME -> TimeUnit.SECONDS, + ConfVars.HMSHANDLERINTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_EVENT_DB_LISTENER_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_CLEAN_FREQ -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_EXPIRY_DURATION -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_WRITER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_READER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVES_AUTO_PROGRESS_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOG_INCREMENTAL_PLAN_PROGRESS_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_STATS_JDBC_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_STATS_RETRIES_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_LOCK_SLEEP_BETWEEN_RETRIES -> TimeUnit.SECONDS, + ConfVars.HIVE_ZOOKEEPER_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_ZOOKEEPER_CONNECTION_BASESLEEPTIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_TXN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_WORKER_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CHECK_INTERVAL -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CLEANER_RUN_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_MAX_IDLE_TIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_MAX_AGE -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_OPERATION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SERVER_READ_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOCALIZE_RESOURCE_WAIT_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_CLIENT_FUTURE_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_JOB_MONITOR_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_RPC_CLIENT_CONNECT_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_RPC_CLIENT_HANDSHAKE_TIMEOUT -> TimeUnit.MILLISECONDS + ).map { case (confVar, unit) => + confVar.varname -> hiveconf.getTimeVar(confVar, unit).toString + }.toMap + } protected[hive] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { @@ -515,19 +594,23 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { private[hive] object HiveContext { /** The version of hive used internally by Spark SQL. */ - val hiveExecutionVersion: String = "0.13.1" + val hiveExecutionVersion: String = "1.2.1" val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", defaultValue = Some("builtin"), - doc = "Location of the jars that should be used to instantiate the HiveMetastoreClient. This" + - " property can be one of three options: " + - "1. 
\"builtin\" Use Hive 0.13.1, which is bundled with the Spark assembly jar when " + - "-Phive is enabled. When this option is chosen, " + - "spark.sql.hive.metastore.version must be either 0.13.1 or not defined. " + - "2. \"maven\" Use Hive jars of specified version downloaded from Maven repositories." + - "3. A classpath in the standard format for both Hive and Hadoop.") - + doc = s""" + | Location of the jars that should be used to instantiate the HiveMetastoreClient. + | This property can be one of three options: " + | 1. "builtin" + | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly jar when + | -Phive is enabled. When this option is chosen, + | spark.sql.hive.metastore.version must be either + | ${hiveExecutionVersion} or not defined. + | 2. "maven" + | Use Hive jars of specified version downloaded from Maven repositories. + | 3. A classpath in the standard format for both Hive and Hadoop. + """.stripMargin) val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet", defaultValue = Some(true), doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + @@ -566,17 +649,18 @@ private[hive] object HiveContext { /** Constructs a configuration for hive, where the metastore is located in a temp directory. */ def newTemporaryConfiguration(): Map[String, String] = { val tempDir = Utils.createTempDir() - val localMetastore = new File(tempDir, "metastore").getAbsolutePath + val localMetastore = new File(tempDir, "metastore") val propMap: HashMap[String, String] = HashMap() // We have to mask all properties in hive-site.xml that relates to metastore data source // as we used a local metastore here. HiveConf.ConfVars.values().foreach { confvar => if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo")) { - propMap.put(confvar.varname, confvar.defaultVal) + propMap.put(confvar.varname, confvar.getDefaultExpr()) } } - propMap.put("javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$localMetastore;create=true") + propMap.put(HiveConf.ConfVars.METASTOREWAREHOUSE.varname, localMetastore.toURI.toString) + propMap.put(HiveConf.ConfVars.METASTORECONNECTURLKEY.varname, + s"jdbc:derby:;databaseName=${localMetastore.getAbsolutePath};create=true") propMap.put("datanucleus.rdbms.datastoreAdapterClassName", "org.datanucleus.store.rdbms.adapter.DerbyAdapter") propMap.toMap diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index a8c9b4fa71b99..16c186627f6cc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -649,11 +649,12 @@ private[hive] case class MetastoreRelation table.outputFormat.foreach(sd.setOutputFormat) val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo - sd.setSerdeInfo(serdeInfo) table.serde.foreach(serdeInfo.setSerializationLib) + sd.setSerdeInfo(serdeInfo) + val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) new Table(tTable) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index e6df64d2642bc..e2fdfc6163a00 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.sql.Date +import java.util.Locale import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.serde.serdeConstants @@ -80,6 +81,7 @@ private[hive] object HiveQl extends Logging { "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", + "TOK_ALTERTABLE", "TOK_ALTERTABLE_ADDCOLS", "TOK_ALTERTABLE_ADDPARTS", "TOK_ALTERTABLE_ALTERPARTS", @@ -94,6 +96,7 @@ private[hive] object HiveQl extends Logging { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", + "TOK_ALTERVIEW", "TOK_ALTERVIEW_ADDPARTS", "TOK_ALTERVIEW_AS", "TOK_ALTERVIEW_DROPPARTS", @@ -248,7 +251,7 @@ private[hive] object HiveQl extends Logging { * Otherwise, there will be Null pointer exception, * when retrieving properties form HiveConf. */ - val hContext = new Context(hiveConf) + val hContext = new Context(SessionState.get().getConf()) val node = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, hContext)) hContext.clear() node @@ -577,12 +580,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_TABLESKEWED", // Skewed by "TOK_TABLEROWFORMAT", "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", // For file formats not natively supported by Hive. - "TOK_TBLSEQUENCEFILE", // Stored as SequenceFile - "TOK_TBLTEXTFILE", // Stored as TextFile - "TOK_TBLRCFILE", // Stored as RCFile - "TOK_TBLORCFILE", // Stored as ORC File - "TOK_TBLPARQUETFILE", // Stored as PARQUET + "TOK_FILEFORMAT_GENERIC", "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat "TOK_STORAGEHANDLER", // Storage handler "TOK_TABLELOCATION", @@ -706,36 +704,51 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C tableDesc = tableDesc.copy(serdeProperties = tableDesc.serdeProperties ++ serdeParams) } case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => - throw new SemanticException( - "Unrecognized file format in STORED AS clause:${child.getText}") + child.getText().toLowerCase(Locale.ENGLISH) match { + case "orc" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } - case Token("TOK_TBLRCFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - } + case "parquet" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } - case Token("TOK_TBLORCFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.serde.isEmpty) { - 
tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } + case "rcfile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + } - case Token("TOK_TBLPARQUETFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + case "textfile" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + + case "sequencefile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + case _ => + throw new SemanticException( + s"Unrecognized file format in STORED AS clause: ${child.getText}") } case Token("TOK_TABLESERIALIZER", @@ -751,7 +764,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TABLEPROPERTIES", list :: Nil) => tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) - case list @ Token("TOK_TABLEFILEFORMAT", _) => + case list @ Token("TOK_TABLEFILEFORMAT", children) => tableDesc = tableDesc.copy( inputFormat = Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), @@ -889,7 +902,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => val serdeProps = propsClause.map { case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (name, value) + (BaseSemanticAnalyzer.unescapeSQLString(name), + BaseSemanticAnalyzer.unescapeSQLString(value)) } (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) @@ -1037,10 +1051,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // return With plan if there is CTE cteRelations.map(With(query, _)).getOrElse(query) - case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) + // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT + case Token("TOK_UNIONALL", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for $node:\n ${dumpTree(a).toString} ") } val allJoinTokens = "(TOK_.*JOIN)".r @@ -1251,7 +1266,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for ${a.getName}:" + + s"\n ${dumpTree(a).toString} ") } protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { @@ -1274,7 +1290,8 @@ 
https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_HINTLIST", _) => None case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for ${a.getName }:" + + s"\n ${dumpTree(a).toString } ") } protected val escapedIdentifier = "`([^`]+)`".r diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index a357bb39ca7fd..267074f3ad102 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive import java.io.{InputStream, OutputStream} import java.rmi.server.UID +import org.apache.avro.Schema + /* Implicit conversions */ import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -33,7 +35,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} import org.apache.hadoop.hive.serde2.ColumnProjectionUtils -import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector import org.apache.hadoop.io.Writable @@ -82,10 +84,19 @@ private[hive] object HiveShim { * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that * is needed to initialize before serialization. */ - def prepareWritable(w: Writable): Writable = { + def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = { w match { case w: AvroGenericRecordWritable => w.setRecordReaderID(new UID()) + // In Hive 1.1, the record's schema may need to be initialized manually or a NPE will + // be thrown. + if (w.getFileSchema() == null) { + serDeProps + .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName()) + .foreach { kv => + w.setFileSchema(new Schema.Parser().parse(kv._2)) + } + } case _ => } w diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index d834b4e83e043..a82e152dcda2c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -87,6 +87,10 @@ private[hive] case class HiveTable( * shared classes. */ private[hive] trait ClientInterface { + + /** Returns the configuration for the given key in the current session. */ + def getConf(key: String, defaultValue: String): String + /** * Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will * result in one string. 
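The getConf method added to ClientInterface above is what lets HiveContext read Hive settings
through the isolated metastore client, as in the staging-directory lookup used by the table-size
calculation earlier in this patch. A minimal sketch of the call pattern, assuming client is any
ClientInterface implementation (for example the ClientWrapper below); stagingDirOf is a
hypothetical helper introduced only for illustration:

    import org.apache.hadoop.hive.conf.HiveConf

    // Read hive.exec.stagingdir through the isolated client, falling back to
    // Hive's compiled-in default when the key is unset in the current session.
    def stagingDirOf(client: ClientInterface): String =
      client.getConf(HiveConf.ConfVars.STAGINGDIR.varname,
        HiveConf.ConfVars.STAGINGDIR.defaultStrVal)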
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 6e0912da5862d..dc372be0e5a37 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -38,7 +38,6 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} - /** * A class that wraps the HiveClient and converts its responses to externally visible classes. * Note that this class is typically loaded with an internal classloader for each instantiation, @@ -115,6 +114,10 @@ private[hive] class ClientWrapper( /** Returns the configuration for the current session. */ def conf: HiveConf = SessionState.get().getConf + override def getConf(key: String, defaultValue: String): String = { + conf.get(key, defaultValue) + } + // TODO: should be a def?s // When we create this val client, the HiveConf of it (conf) is the one associated with state. @GuardedBy("this") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 956997e5f9dce..6e826ce552204 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -512,7 +512,7 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, - 0: JLong) + 0L: JLong) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 97fb98199991b..f58bc7d7a0af4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -55,7 +55,7 @@ private[hive] object IsolatedClientLoader { case "14" | "0.14" | "0.14.0" => hive.v14 case "1.0" | "1.0.0" => hive.v1_0 case "1.1" | "1.1.0" => hive.v1_1 - case "1.2" | "1.2.0" => hive.v1_2 + case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 } private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index b48082fe4b363..0503691a44249 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -56,7 +56,7 @@ package object client { "net.hydromatic:linq4j", "net.hydromatic:quidem")) - case object v1_2 extends HiveVersion("1.2.0", + case object v1_2 extends HiveVersion("1.2.1", exclusions = Seq("eigenbase:eigenbase-properties", "org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 40a6a32156687..12c667e6e92da 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -129,7 +129,7 @@ case class InsertIntoHiveTable( // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = hiveContext.getExternalTmpPath(tableLocation.toUri) + val tmpLocation = hiveContext.getExternalTmpPath(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = sc.hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 7e3342cc84c0e..fbb86406f40cb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -247,7 +247,7 @@ private class ScriptTransformationWriterThread( } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) } } outputStream.close() @@ -345,9 +345,7 @@ case class HiveScriptIOSchema ( val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - var propsMap = serdeProps.map(kv => { - (kv._1.split("'")(1), kv._2.split("'")(1)) - }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) val properties = new Properties() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index abe5c69003130..8a86a87368f29 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -249,7 +249,7 @@ private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { // Get the class of this function. // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1. 
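+    // Note: with Hive 1.2, WindowFunctionInfo exposes getFunctionClass() directly, so the
+    // getfInfo() indirection described above is no longer needed.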
- val functionClass = windowFunctionInfo.getfInfo().getFunctionClass + val functionClass = windowFunctionInfo.getFunctionClass() val newChildren = // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit // input parameters and requires implicit parameters, which diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 8850e060d2a73..684ea1d137b49 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -171,7 +171,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( import SparkHiveDynamicPartitionWriterContainer._ private val defaultPartName = jobConf.get( - ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) + ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultStrVal) @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index ddd5d24717add..86142e5d66f37 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.ql.io.sarg.SearchArgument +import org.apache.hadoop.hive.ql.io.sarg.{SearchArgumentFactory, SearchArgument} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.hadoop.hive.serde2.io.DateWritable @@ -33,13 +33,13 @@ import org.apache.spark.sql.sources._ private[orc] object OrcFilters extends Logging { def createFilter(expr: Array[Filter]): Option[SearchArgument] = { expr.reduceOption(And).flatMap { conjunction => - val builder = SearchArgument.FACTORY.newBuilder() + val builder = SearchArgumentFactory.newBuilder() buildSearchArgument(conjunction, builder).map(_.build()) } } private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { - def newBuilder = SearchArgument.FACTORY.newBuilder() + def newBuilder = SearchArgumentFactory.newBuilder() def isSearchableLiteral(value: Any): Boolean = value match { // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. 
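With Hive 1.2 the ORC SearchArgument builder is obtained from SearchArgumentFactory.newBuilder()
instead of the SearchArgument.FACTORY field used with Hive 0.13, which is what the OrcFilters
change above adapts to. A minimal sketch of the new entry point; the column name and literal are
illustrative only, and roughly correspond to what buildSearchArgument emits for a single
LessThan filter:

    import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory

    // Build the equivalent of `col < 10` with the Hive 1.2 builder API.
    val sarg = SearchArgumentFactory.newBuilder()
      .startAnd()
      .lessThan("col", 10L)
      .end()
      .build()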
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7bbdef90cd6b9..8d0bf46e8fad7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -20,29 +20,25 @@ package org.apache.spark.sql.hive.test import java.io.File import java.util.{Set => JavaSet} -import org.apache.hadoop.hive.conf.HiveConf +import scala.collection.mutable +import scala.language.implicitConversions + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} -import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe -import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.sql.SQLConf import org.apache.spark.util.Utils import org.apache.spark.{SparkConf, SparkContext} -import scala.collection.mutable -import scala.language.implicitConversions - /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -83,15 +79,25 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveconf.set("hive.plan.serialization.format", "javaXML") - lazy val warehousePath = Utils.createTempDir() + lazy val warehousePath = Utils.createTempDir(namePrefix = "warehouse-") + + lazy val scratchDirPath = { + val dir = Utils.createTempDir(namePrefix = "scratch-") + dir.delete() + dir + } private lazy val temporaryConfig = newTemporaryConfiguration() /** Sets up the system initially or after a RESET command */ - protected override def configure(): Map[String, String] = - temporaryConfig ++ Map( - ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toString, - ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true") + protected override def configure(): Map[String, String] = { + super.configure() ++ temporaryConfig ++ Map( + ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toURI.toString, + ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", + ConfVars.SCRATCHDIR.varname -> scratchDirPath.toURI.toString, + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1" + ) + } val testTempDir = Utils.createTempDir() @@ -244,7 +250,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { }), TestTable("src_thrift", () => { import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer - import org.apache.hadoop.hive.serde2.thrift.test.Complex import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} import org.apache.thrift.protocol.TBinaryProtocol @@ -253,7 +258,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { |CREATE TABLE src_thrift(fake INT) |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}' |WITH SERDEPROPERTIES( - | 'serialization.class'='${classOf[Complex].getName}', + | 'serialization.class'='org.apache.spark.sql.hive.test.Complex', | 
'serialization.format'='${classOf[TBinaryProtocol].getName}' |) |STORED AS @@ -437,6 +442,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { case (k, v) => metadataHive.runSqlHive(s"SET $k=$v") } + defaultOverides() runSqlHive("USE default") diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java new file mode 100644 index 0000000000000..e010112bb9327 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -0,0 +1,1139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive.test; + +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; + +import org.apache.hadoop.hive.serde2.thrift.test.IntString; +import org.apache.thrift.scheme.TupleScheme; +import org.apache.thrift.protocol.TTupleProtocol; +import org.apache.thrift.EncodingUtils; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.EnumMap; +import java.util.EnumSet; +import java.util.Collections; +import java.util.BitSet; + +/** + * This is a fork of Hive 0.13's org/apache/hadoop/hive/serde2/thrift/test/Complex.java, which + * does not contain union fields that are not supported by Spark SQL. 
+ */ + +@SuppressWarnings({"ALL", "unchecked"}) +public class Complex implements org.apache.thrift.TBase, java.io.Serializable, Cloneable { + private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Complex"); + + private static final org.apache.thrift.protocol.TField AINT_FIELD_DESC = new org.apache.thrift.protocol.TField("aint", org.apache.thrift.protocol.TType.I32, (short)1); + private static final org.apache.thrift.protocol.TField A_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("aString", org.apache.thrift.protocol.TType.STRING, (short)2); + private static final org.apache.thrift.protocol.TField LINT_FIELD_DESC = new org.apache.thrift.protocol.TField("lint", org.apache.thrift.protocol.TType.LIST, (short)3); + private static final org.apache.thrift.protocol.TField L_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("lString", org.apache.thrift.protocol.TType.LIST, (short)4); + private static final org.apache.thrift.protocol.TField LINT_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("lintString", org.apache.thrift.protocol.TType.LIST, (short)5); + private static final org.apache.thrift.protocol.TField M_STRING_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("mStringString", org.apache.thrift.protocol.TType.MAP, (short)6); + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new ComplexStandardSchemeFactory()); + schemes.put(TupleScheme.class, new ComplexTupleSchemeFactory()); + } + + private int aint; // required + private String aString; // required + private List lint; // required + private List lString; // required + private List lintString; // required + private Map mStringString; // required + + /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ + public enum _Fields implements org.apache.thrift.TFieldIdEnum { + AINT((short)1, "aint"), + A_STRING((short)2, "aString"), + LINT((short)3, "lint"), + L_STRING((short)4, "lString"), + LINT_STRING((short)5, "lintString"), + M_STRING_STRING((short)6, "mStringString"); + + private static final Map byName = new HashMap(); + + static { + for (_Fields field : EnumSet.allOf(_Fields.class)) { + byName.put(field.getFieldName(), field); + } + } + + /** + * Find the _Fields constant that matches fieldId, or null if its not found. + */ + public static _Fields findByThriftId(int fieldId) { + switch(fieldId) { + case 1: // AINT + return AINT; + case 2: // A_STRING + return A_STRING; + case 3: // LINT + return LINT; + case 4: // L_STRING + return L_STRING; + case 5: // LINT_STRING + return LINT_STRING; + case 6: // M_STRING_STRING + return M_STRING_STRING; + default: + return null; + } + } + + /** + * Find the _Fields constant that matches fieldId, throwing an exception + * if it is not found. + */ + public static _Fields findByThriftIdOrThrow(int fieldId) { + _Fields fields = findByThriftId(fieldId); + if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); + return fields; + } + + /** + * Find the _Fields constant that matches name, or null if its not found. 
+ */ + public static _Fields findByName(String name) { + return byName.get(name); + } + + private final short _thriftId; + private final String _fieldName; + + _Fields(short thriftId, String fieldName) { + _thriftId = thriftId; + _fieldName = fieldName; + } + + public short getThriftFieldId() { + return _thriftId; + } + + public String getFieldName() { + return _fieldName; + } + } + + // isset id assignments + private static final int __AINT_ISSET_ID = 0; + private byte __isset_bitfield = 0; + public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; + static { + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + tmpMap.put(_Fields.AINT, new org.apache.thrift.meta_data.FieldMetaData("aint", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); + tmpMap.put(_Fields.A_STRING, new org.apache.thrift.meta_data.FieldMetaData("aString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + tmpMap.put(_Fields.LINT, new org.apache.thrift.meta_data.FieldMetaData("lint", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32)))); + tmpMap.put(_Fields.L_STRING, new org.apache.thrift.meta_data.FieldMetaData("lString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + tmpMap.put(_Fields.LINT_STRING, new org.apache.thrift.meta_data.FieldMetaData("lintString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.StructMetaData(org.apache.thrift.protocol.TType.STRUCT, IntString.class)))); + tmpMap.put(_Fields.M_STRING_STRING, new org.apache.thrift.meta_data.FieldMetaData("mStringString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.MapMetaData(org.apache.thrift.protocol.TType.MAP, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING), + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + metaDataMap = Collections.unmodifiableMap(tmpMap); + org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(Complex.class, metaDataMap); + } + + public Complex() { + } + + public Complex( + int aint, + String aString, + List lint, + List lString, + List lintString, + Map mStringString) + { + this(); + this.aint = aint; + setAintIsSet(true); + this.aString = aString; + this.lint = lint; + this.lString = lString; + this.lintString = lintString; + this.mStringString = mStringString; + } + + /** + * Performs a deep copy on other. 
+ */ + public Complex(Complex other) { + __isset_bitfield = other.__isset_bitfield; + this.aint = other.aint; + if (other.isSetAString()) { + this.aString = other.aString; + } + if (other.isSetLint()) { + List __this__lint = new ArrayList(); + for (Integer other_element : other.lint) { + __this__lint.add(other_element); + } + this.lint = __this__lint; + } + if (other.isSetLString()) { + List __this__lString = new ArrayList(); + for (String other_element : other.lString) { + __this__lString.add(other_element); + } + this.lString = __this__lString; + } + if (other.isSetLintString()) { + List __this__lintString = new ArrayList(); + for (IntString other_element : other.lintString) { + __this__lintString.add(new IntString(other_element)); + } + this.lintString = __this__lintString; + } + if (other.isSetMStringString()) { + Map __this__mStringString = new HashMap(); + for (Map.Entry other_element : other.mStringString.entrySet()) { + + String other_element_key = other_element.getKey(); + String other_element_value = other_element.getValue(); + + String __this__mStringString_copy_key = other_element_key; + + String __this__mStringString_copy_value = other_element_value; + + __this__mStringString.put(__this__mStringString_copy_key, __this__mStringString_copy_value); + } + this.mStringString = __this__mStringString; + } + } + + public Complex deepCopy() { + return new Complex(this); + } + + @Override + public void clear() { + setAintIsSet(false); + this.aint = 0; + this.aString = null; + this.lint = null; + this.lString = null; + this.lintString = null; + this.mStringString = null; + } + + public int getAint() { + return this.aint; + } + + public void setAint(int aint) { + this.aint = aint; + setAintIsSet(true); + } + + public void unsetAint() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __AINT_ISSET_ID); + } + + /** Returns true if field aint is set (has been assigned a value) and false otherwise */ + public boolean isSetAint() { + return EncodingUtils.testBit(__isset_bitfield, __AINT_ISSET_ID); + } + + public void setAintIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __AINT_ISSET_ID, value); + } + + public String getAString() { + return this.aString; + } + + public void setAString(String aString) { + this.aString = aString; + } + + public void unsetAString() { + this.aString = null; + } + + /** Returns true if field aString is set (has been assigned a value) and false otherwise */ + public boolean isSetAString() { + return this.aString != null; + } + + public void setAStringIsSet(boolean value) { + if (!value) { + this.aString = null; + } + } + + public int getLintSize() { + return (this.lint == null) ? 0 : this.lint.size(); + } + + public java.util.Iterator getLintIterator() { + return (this.lint == null) ? null : this.lint.iterator(); + } + + public void addToLint(int elem) { + if (this.lint == null) { + this.lint = new ArrayList<>(); + } + this.lint.add(elem); + } + + public List getLint() { + return this.lint; + } + + public void setLint(List lint) { + this.lint = lint; + } + + public void unsetLint() { + this.lint = null; + } + + /** Returns true if field lint is set (has been assigned a value) and false otherwise */ + public boolean isSetLint() { + return this.lint != null; + } + + public void setLintIsSet(boolean value) { + if (!value) { + this.lint = null; + } + } + + public int getLStringSize() { + return (this.lString == null) ? 
0 : this.lString.size(); + } + + public java.util.Iterator getLStringIterator() { + return (this.lString == null) ? null : this.lString.iterator(); + } + + public void addToLString(String elem) { + if (this.lString == null) { + this.lString = new ArrayList(); + } + this.lString.add(elem); + } + + public List getLString() { + return this.lString; + } + + public void setLString(List lString) { + this.lString = lString; + } + + public void unsetLString() { + this.lString = null; + } + + /** Returns true if field lString is set (has been assigned a value) and false otherwise */ + public boolean isSetLString() { + return this.lString != null; + } + + public void setLStringIsSet(boolean value) { + if (!value) { + this.lString = null; + } + } + + public int getLintStringSize() { + return (this.lintString == null) ? 0 : this.lintString.size(); + } + + public java.util.Iterator getLintStringIterator() { + return (this.lintString == null) ? null : this.lintString.iterator(); + } + + public void addToLintString(IntString elem) { + if (this.lintString == null) { + this.lintString = new ArrayList<>(); + } + this.lintString.add(elem); + } + + public List getLintString() { + return this.lintString; + } + + public void setLintString(List lintString) { + this.lintString = lintString; + } + + public void unsetLintString() { + this.lintString = null; + } + + /** Returns true if field lintString is set (has been assigned a value) and false otherwise */ + public boolean isSetLintString() { + return this.lintString != null; + } + + public void setLintStringIsSet(boolean value) { + if (!value) { + this.lintString = null; + } + } + + public int getMStringStringSize() { + return (this.mStringString == null) ? 0 : this.mStringString.size(); + } + + public void putToMStringString(String key, String val) { + if (this.mStringString == null) { + this.mStringString = new HashMap(); + } + this.mStringString.put(key, val); + } + + public Map getMStringString() { + return this.mStringString; + } + + public void setMStringString(Map mStringString) { + this.mStringString = mStringString; + } + + public void unsetMStringString() { + this.mStringString = null; + } + + /** Returns true if field mStringString is set (has been assigned a value) and false otherwise */ + public boolean isSetMStringString() { + return this.mStringString != null; + } + + public void setMStringStringIsSet(boolean value) { + if (!value) { + this.mStringString = null; + } + } + + public void setFieldValue(_Fields field, Object value) { + switch (field) { + case AINT: + if (value == null) { + unsetAint(); + } else { + setAint((Integer)value); + } + break; + + case A_STRING: + if (value == null) { + unsetAString(); + } else { + setAString((String)value); + } + break; + + case LINT: + if (value == null) { + unsetLint(); + } else { + setLint((List)value); + } + break; + + case L_STRING: + if (value == null) { + unsetLString(); + } else { + setLString((List)value); + } + break; + + case LINT_STRING: + if (value == null) { + unsetLintString(); + } else { + setLintString((List)value); + } + break; + + case M_STRING_STRING: + if (value == null) { + unsetMStringString(); + } else { + setMStringString((Map)value); + } + break; + + } + } + + public Object getFieldValue(_Fields field) { + switch (field) { + case AINT: + return Integer.valueOf(getAint()); + + case A_STRING: + return getAString(); + + case LINT: + return getLint(); + + case L_STRING: + return getLString(); + + case LINT_STRING: + return getLintString(); + + case M_STRING_STRING: + return 
getMStringString(); + + } + throw new IllegalStateException(); + } + + /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ + public boolean isSet(_Fields field) { + if (field == null) { + throw new IllegalArgumentException(); + } + + switch (field) { + case AINT: + return isSetAint(); + case A_STRING: + return isSetAString(); + case LINT: + return isSetLint(); + case L_STRING: + return isSetLString(); + case LINT_STRING: + return isSetLintString(); + case M_STRING_STRING: + return isSetMStringString(); + } + throw new IllegalStateException(); + } + + @Override + public boolean equals(Object that) { + if (that == null) + return false; + if (that instanceof Complex) + return this.equals((Complex)that); + return false; + } + + public boolean equals(Complex that) { + if (that == null) + return false; + + boolean this_present_aint = true; + boolean that_present_aint = true; + if (this_present_aint || that_present_aint) { + if (!(this_present_aint && that_present_aint)) + return false; + if (this.aint != that.aint) + return false; + } + + boolean this_present_aString = true && this.isSetAString(); + boolean that_present_aString = true && that.isSetAString(); + if (this_present_aString || that_present_aString) { + if (!(this_present_aString && that_present_aString)) + return false; + if (!this.aString.equals(that.aString)) + return false; + } + + boolean this_present_lint = true && this.isSetLint(); + boolean that_present_lint = true && that.isSetLint(); + if (this_present_lint || that_present_lint) { + if (!(this_present_lint && that_present_lint)) + return false; + if (!this.lint.equals(that.lint)) + return false; + } + + boolean this_present_lString = true && this.isSetLString(); + boolean that_present_lString = true && that.isSetLString(); + if (this_present_lString || that_present_lString) { + if (!(this_present_lString && that_present_lString)) + return false; + if (!this.lString.equals(that.lString)) + return false; + } + + boolean this_present_lintString = true && this.isSetLintString(); + boolean that_present_lintString = true && that.isSetLintString(); + if (this_present_lintString || that_present_lintString) { + if (!(this_present_lintString && that_present_lintString)) + return false; + if (!this.lintString.equals(that.lintString)) + return false; + } + + boolean this_present_mStringString = true && this.isSetMStringString(); + boolean that_present_mStringString = true && that.isSetMStringString(); + if (this_present_mStringString || that_present_mStringString) { + if (!(this_present_mStringString && that_present_mStringString)) + return false; + if (!this.mStringString.equals(that.mStringString)) + return false; + } + + return true; + } + + @Override + public int hashCode() { + HashCodeBuilder builder = new HashCodeBuilder(); + + boolean present_aint = true; + builder.append(present_aint); + if (present_aint) + builder.append(aint); + + boolean present_aString = true && (isSetAString()); + builder.append(present_aString); + if (present_aString) + builder.append(aString); + + boolean present_lint = true && (isSetLint()); + builder.append(present_lint); + if (present_lint) + builder.append(lint); + + boolean present_lString = true && (isSetLString()); + builder.append(present_lString); + if (present_lString) + builder.append(lString); + + boolean present_lintString = true && (isSetLintString()); + builder.append(present_lintString); + if (present_lintString) + builder.append(lintString); + + boolean present_mStringString = true 
&& (isSetMStringString()); + builder.append(present_mStringString); + if (present_mStringString) + builder.append(mStringString); + + return builder.toHashCode(); + } + + public int compareTo(Complex other) { + if (!getClass().equals(other.getClass())) { + return getClass().getName().compareTo(other.getClass().getName()); + } + + int lastComparison = 0; + Complex typedOther = (Complex)other; + + lastComparison = Boolean.valueOf(isSetAint()).compareTo(typedOther.isSetAint()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetAint()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.aint, typedOther.aint); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetAString()).compareTo(typedOther.isSetAString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetAString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.aString, typedOther.aString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLint()).compareTo(typedOther.isSetLint()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLint()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lint, typedOther.lint); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLString()).compareTo(typedOther.isSetLString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lString, typedOther.lString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLintString()).compareTo(typedOther.isSetLintString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLintString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lintString, typedOther.lintString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMStringString()).compareTo(typedOther.isSetMStringString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMStringString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.mStringString, typedOther.mStringString); + if (lastComparison != 0) { + return lastComparison; + } + } + return 0; + } + + public _Fields fieldForId(int fieldId) { + return _Fields.findByThriftId(fieldId); + } + + public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("Complex("); + boolean first = true; + + sb.append("aint:"); + sb.append(this.aint); + first = false; + if (!first) sb.append(", "); + sb.append("aString:"); + if (this.aString == null) { + sb.append("null"); + } else { + sb.append(this.aString); + } + first = false; + if (!first) sb.append(", "); + sb.append("lint:"); + if (this.lint == null) { + sb.append("null"); + } else { + sb.append(this.lint); + } + first = false; + if (!first) sb.append(", "); + sb.append("lString:"); + if (this.lString == null) { + sb.append("null"); + } else { + sb.append(this.lString); + } + first = false; + if (!first) sb.append(", "); + 
sb.append("lintString:"); + if (this.lintString == null) { + sb.append("null"); + } else { + sb.append(this.lintString); + } + first = false; + if (!first) sb.append(", "); + sb.append("mStringString:"); + if (this.mStringString == null) { + sb.append("null"); + } else { + sb.append(this.mStringString); + } + first = false; + sb.append(")"); + return sb.toString(); + } + + public void validate() throws org.apache.thrift.TException { + // check for required fields + // check for sub-struct validity + } + + private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { + try { + write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + try { + // it doesn't seem like you should have to do this, but java serialization is wacky, and doesn't call the default constructor. + __isset_bitfield = 0; + read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private static class ComplexStandardSchemeFactory implements SchemeFactory { + public ComplexStandardScheme getScheme() { + return new ComplexStandardScheme(); + } + } + + private static class ComplexStandardScheme extends StandardScheme { + + public void read(org.apache.thrift.protocol.TProtocol iprot, Complex struct) throws org.apache.thrift.TException { + org.apache.thrift.protocol.TField schemeField; + iprot.readStructBegin(); + while (true) + { + schemeField = iprot.readFieldBegin(); + if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { + break; + } + switch (schemeField.id) { + case 1: // AINT + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.aint = iprot.readI32(); + struct.setAintIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 2: // A_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.aString = iprot.readString(); + struct.setAStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 3: // LINT + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list0 = iprot.readListBegin(); + struct.lint = new ArrayList(_list0.size); + for (int _i1 = 0; _i1 < _list0.size; ++_i1) + { + int _elem2; // required + _elem2 = iprot.readI32(); + struct.lint.add(_elem2); + } + iprot.readListEnd(); + } + struct.setLintIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 4: // L_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list3 = iprot.readListBegin(); + struct.lString = new ArrayList(_list3.size); + for (int _i4 = 0; _i4 < _list3.size; ++_i4) + { + String _elem5; // required + _elem5 = iprot.readString(); + struct.lString.add(_elem5); + } + iprot.readListEnd(); + } + struct.setLStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 5: // LINT_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list6 = 
iprot.readListBegin(); + struct.lintString = new ArrayList(_list6.size); + for (int _i7 = 0; _i7 < _list6.size; ++_i7) + { + IntString _elem8; // required + _elem8 = new IntString(); + _elem8.read(iprot); + struct.lintString.add(_elem8); + } + iprot.readListEnd(); + } + struct.setLintStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 6: // M_STRING_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.MAP) { + { + org.apache.thrift.protocol.TMap _map9 = iprot.readMapBegin(); + struct.mStringString = new HashMap(2*_map9.size); + for (int _i10 = 0; _i10 < _map9.size; ++_i10) + { + String _key11; // required + String _val12; // required + _key11 = iprot.readString(); + _val12 = iprot.readString(); + struct.mStringString.put(_key11, _val12); + } + iprot.readMapEnd(); + } + struct.setMStringStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + default: + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + struct.validate(); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot, Complex struct) throws org.apache.thrift.TException { + struct.validate(); + + oprot.writeStructBegin(STRUCT_DESC); + oprot.writeFieldBegin(AINT_FIELD_DESC); + oprot.writeI32(struct.aint); + oprot.writeFieldEnd(); + if (struct.aString != null) { + oprot.writeFieldBegin(A_STRING_FIELD_DESC); + oprot.writeString(struct.aString); + oprot.writeFieldEnd(); + } + if (struct.lint != null) { + oprot.writeFieldBegin(LINT_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, struct.lint.size())); + for (int _iter13 : struct.lint) + { + oprot.writeI32(_iter13); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.lString != null) { + oprot.writeFieldBegin(L_STRING_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, struct.lString.size())); + for (String _iter14 : struct.lString) + { + oprot.writeString(_iter14); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.lintString != null) { + oprot.writeFieldBegin(LINT_STRING_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, struct.lintString.size())); + for (IntString _iter15 : struct.lintString) + { + _iter15.write(oprot); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.mStringString != null) { + oprot.writeFieldBegin(M_STRING_STRING_FIELD_DESC); + { + oprot.writeMapBegin(new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, struct.mStringString.size())); + for (Map.Entry _iter16 : struct.mStringString.entrySet()) + { + oprot.writeString(_iter16.getKey()); + oprot.writeString(_iter16.getValue()); + } + oprot.writeMapEnd(); + } + oprot.writeFieldEnd(); + } + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + } + + private static class ComplexTupleSchemeFactory implements SchemeFactory { + public ComplexTupleScheme getScheme() { + return new ComplexTupleScheme(); + } + } + + private static class ComplexTupleScheme extends TupleScheme { + + @Override + public void write(org.apache.thrift.protocol.TProtocol prot, Complex struct) throws org.apache.thrift.TException { + TTupleProtocol oprot = (TTupleProtocol) prot; + BitSet 
optionals = new BitSet(); + if (struct.isSetAint()) { + optionals.set(0); + } + if (struct.isSetAString()) { + optionals.set(1); + } + if (struct.isSetLint()) { + optionals.set(2); + } + if (struct.isSetLString()) { + optionals.set(3); + } + if (struct.isSetLintString()) { + optionals.set(4); + } + if (struct.isSetMStringString()) { + optionals.set(5); + } + oprot.writeBitSet(optionals, 6); + if (struct.isSetAint()) { + oprot.writeI32(struct.aint); + } + if (struct.isSetAString()) { + oprot.writeString(struct.aString); + } + if (struct.isSetLint()) { + { + oprot.writeI32(struct.lint.size()); + for (int _iter17 : struct.lint) + { + oprot.writeI32(_iter17); + } + } + } + if (struct.isSetLString()) { + { + oprot.writeI32(struct.lString.size()); + for (String _iter18 : struct.lString) + { + oprot.writeString(_iter18); + } + } + } + if (struct.isSetLintString()) { + { + oprot.writeI32(struct.lintString.size()); + for (IntString _iter19 : struct.lintString) + { + _iter19.write(oprot); + } + } + } + if (struct.isSetMStringString()) { + { + oprot.writeI32(struct.mStringString.size()); + for (Map.Entry _iter20 : struct.mStringString.entrySet()) + { + oprot.writeString(_iter20.getKey()); + oprot.writeString(_iter20.getValue()); + } + } + } + } + + @Override + public void read(org.apache.thrift.protocol.TProtocol prot, Complex struct) throws org.apache.thrift.TException { + TTupleProtocol iprot = (TTupleProtocol) prot; + BitSet incoming = iprot.readBitSet(6); + if (incoming.get(0)) { + struct.aint = iprot.readI32(); + struct.setAintIsSet(true); + } + if (incoming.get(1)) { + struct.aString = iprot.readString(); + struct.setAStringIsSet(true); + } + if (incoming.get(2)) { + { + org.apache.thrift.protocol.TList _list21 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, iprot.readI32()); + struct.lint = new ArrayList(_list21.size); + for (int _i22 = 0; _i22 < _list21.size; ++_i22) + { + int _elem23; // required + _elem23 = iprot.readI32(); + struct.lint.add(_elem23); + } + } + struct.setLintIsSet(true); + } + if (incoming.get(3)) { + { + org.apache.thrift.protocol.TList _list24 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.lString = new ArrayList(_list24.size); + for (int _i25 = 0; _i25 < _list24.size; ++_i25) + { + String _elem26; // required + _elem26 = iprot.readString(); + struct.lString.add(_elem26); + } + } + struct.setLStringIsSet(true); + } + if (incoming.get(4)) { + { + org.apache.thrift.protocol.TList _list27 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, iprot.readI32()); + struct.lintString = new ArrayList(_list27.size); + for (int _i28 = 0; _i28 < _list27.size; ++_i28) + { + IntString _elem29; // required + _elem29 = new IntString(); + _elem29.read(iprot); + struct.lintString.add(_elem29); + } + } + struct.setLintStringIsSet(true); + } + if (incoming.get(5)) { + { + org.apache.thrift.protocol.TMap _map30 = new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.mStringString = new HashMap(2*_map30.size); + for (int _i31 = 0; _i31 < _map30.size; ++_i31) + { + String _key32; // required + String _val33; // required + _key32 = iprot.readString(); + _val33 = iprot.readString(); + struct.mStringString.put(_key32, _val33); + } + } + struct.setMStringStringIsSet(true); + } + } + } + +} + diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java 
b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 64d1ce92931eb..15c2c3deb0d83 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -90,8 +90,10 @@ public void setUp() throws IOException { @After public void tearDown() throws IOException { // Clean up tables. - sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); - sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + if (sqlContext != null) { + sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); + sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + } } @Test diff --git a/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf b/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 b/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 index d35bf9093ca9c..2383bef940973 100644 --- a/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 +++ b/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 @@ -15,9 +15,9 @@ my_enum_structlist_map map from deserializer my_structlist array>> from deserializer my_enumlist array from deserializer -my_stringset struct<> from deserializer -my_enumset struct<> from deserializer -my_structset struct<> from deserializer +my_stringset array from deserializer +my_enumset array from deserializer +my_structset array>> from deserializer optionals struct<> from deserializer b string diff --git a/sql/hive/src/test/resources/golden/parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 b/sql/hive/src/test/resources/golden/parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44 similarity index 100% rename from sql/hive/src/test/resources/golden/parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 rename to sql/hive/src/test/resources/golden/parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44 diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 index 501bb6ab32f25..7bb2c0ab43984 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` smallint, `value` float) COMMENT 'temporary table' diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 index 90f8415a1c6be..3cc1a57ee3a47 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_feng.tmp_showcrt`( +CREATE TABLE `tmp_feng.tmp_showcrt`( `key` string, `value` int) ROW FORMAT SERDE diff --git 
a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 index 4ee22e5230316..b51c71a71f91c 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` int, `value` string, `newvalue` bigint) diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 index 6fda2570b53f1..29189e1d860a4 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` int, `value` string, `newvalue` bigint) diff --git a/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 b/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 index 3049cd6243ad8..1b283db3e7744 100644 --- a/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 +++ b/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 @@ -17,6 +17,7 @@ ^ abs acos +add_months and array array_contains @@ -29,6 +30,7 @@ base64 between bin case +cbrt ceil ceiling coalesce @@ -47,7 +49,11 @@ covar_samp create_union cume_dist current_database +current_date +current_timestamp +current_user date_add +date_format date_sub datediff day @@ -65,6 +71,7 @@ ewah_bitmap_empty ewah_bitmap_or exp explode +factorial field find_in_set first_value @@ -73,6 +80,7 @@ format_number from_unixtime from_utc_timestamp get_json_object +greatest hash hex histogram_numeric @@ -81,6 +89,7 @@ if in in_file index +initcap inline instr isnotnull @@ -88,10 +97,13 @@ isnull java_method json_tuple lag +last_day last_value lcase lead +least length +levenshtein like ln locate @@ -109,11 +121,15 @@ max min minute month +months_between named_struct negative +next_day ngrams noop +noopstreaming noopwithmap +noopwithmapstreaming not ntile nvl @@ -147,10 +163,14 @@ rpad rtrim second sentences +shiftleft +shiftright +shiftrightunsigned sign sin size sort_array +soundex space split sqrt @@ -170,6 +190,7 @@ to_unix_timestamp to_utc_timestamp translate trim +trunc ucase unbase64 unhex diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae index 0f6cc6f44f1f7..fdf701f962800 100644 --- a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae +++ b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae @@ -1 +1 @@ -Table tmpfoo does not have property: bar +Table default.tmpfoo does not have property: bar diff --git a/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 b/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 index 3c91e138d7bd5..d8ec084f0b2b0 100644 --- a/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 +++ 
b/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 @@ -1,5 +1,5 @@ date_add(start_date, num_days) - Returns the date that is num_days after start_date. start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. num_days is a number. The time part of start_date is ignored. Example: - > SELECT date_add('2009-30-07', 1) FROM src LIMIT 1; - '2009-31-07' + > SELECT date_add('2009-07-30', 1) FROM src LIMIT 1; + '2009-07-31' diff --git a/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 b/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 index 29d663f35c586..169c500036255 100644 --- a/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 +++ b/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 @@ -1,5 +1,5 @@ date_sub(start_date, num_days) - Returns the date that is num_days before start_date. start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. num_days is a number. The time part of start_date is ignored. Example: - > SELECT date_sub('2009-30-07', 1) FROM src LIMIT 1; - '2009-29-07' + > SELECT date_sub('2009-07-30', 1) FROM src LIMIT 1; + '2009-07-29' diff --git a/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 b/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 index 7ccaee7ad3bd4..42197f7ad3e51 100644 --- a/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 +++ b/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 @@ -1,5 +1,5 @@ datediff(date1, date2) - Returns the number of days between date1 and date2 date1 and date2 are strings in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. The time parts are ignored.If date1 is earlier than date2, the result is negative. Example: - > SELECT datediff('2009-30-07', '2009-31-07') FROM src LIMIT 1; + > SELECT datediff('2009-07-30', '2009-07-31') FROM src LIMIT 1; 1 diff --git a/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 b/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 index d4017178b4e6b..09703d10eab7a 100644 --- a/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 +++ b/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 @@ -1 +1 @@ -day(date) - Returns the date of the month of date +day(param) - Returns the day of the month of date/timestamp, or day component of interval diff --git a/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 b/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 index 6135aafa50860..7c0ec1dc3be59 100644 --- a/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 +++ b/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 @@ -1,6 +1,9 @@ -day(date) - Returns the date of the month of date +day(param) - Returns the day of the month of date/timestamp, or day component of interval Synonyms: dayofmonth -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. -Example: - > SELECT day('2009-30-07', 1) FROM src LIMIT 1; +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. +2. A date value +3. A timestamp value +4. 
A day-time interval valueExample: + > SELECT day('2009-07-30') FROM src LIMIT 1; 30 diff --git a/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 b/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 index 47a7018d9d5ac..c37eb0ec2e969 100644 --- a/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 +++ b/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 @@ -1 +1 @@ -dayofmonth(date) - Returns the date of the month of date +dayofmonth(param) - Returns the day of the month of date/timestamp, or day component of interval diff --git a/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 b/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 index d9490e20a3b6d..9e931f649914b 100644 --- a/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 +++ b/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 @@ -1,6 +1,9 @@ -dayofmonth(date) - Returns the date of the month of date +dayofmonth(param) - Returns the day of the month of date/timestamp, or day component of interval Synonyms: day -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. -Example: - > SELECT dayofmonth('2009-30-07', 1) FROM src LIMIT 1; +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. +2. A date value +3. A timestamp value +4. A day-time interval valueExample: + > SELECT dayofmonth('2009-07-30') FROM src LIMIT 1; 30 diff --git a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c index 2cf0d9d61882e..ce583fe81ff68 100644 --- a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c +++ b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a index 2cf0d9d61882e..ce583fe81ff68 100644 --- a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a +++ b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c index 2cf0d9d61882e..ce583fe81ff68 100644 --- a/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c +++ b/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. 
IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a index 2cf0d9d61882e..ce583fe81ff68 100644 --- a/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a +++ b/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee b/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee index 231e4f382566d..06650592f8d3c 100644 --- a/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee +++ b/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee @@ -1 +1 @@ -minute(date) - Returns the minute of date +minute(param) - Returns the minute component of the string/timestamp/interval diff --git a/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 b/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 index ea842ea174ae4..08ddc19b84d82 100644 --- a/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 +++ b/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 @@ -1,6 +1,8 @@ -minute(date) - Returns the minute of date -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. -Example: +minute(param) - Returns the minute component of the string/timestamp/interval +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. +2. A timestamp value +3. A day-time interval valueExample: > SELECT minute('2009-07-30 12:58:59') FROM src LIMIT 1; 58 > SELECT minute('12:58:59') FROM src LIMIT 1; diff --git a/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee b/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee index 231e4f382566d..06650592f8d3c 100644 --- a/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee +++ b/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee @@ -1 +1 @@ -minute(date) - Returns the minute of date +minute(param) - Returns the minute component of the string/timestamp/interval diff --git a/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 b/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 index ea842ea174ae4..08ddc19b84d82 100644 --- a/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 +++ b/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 @@ -1,6 +1,8 @@ -minute(date) - Returns the minute of date -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. -Example: +minute(param) - Returns the minute component of the string/timestamp/interval +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. +2. A timestamp value +3. 
A day-time interval valueExample: > SELECT minute('2009-07-30 12:58:59') FROM src LIMIT 1; 58 > SELECT minute('12:58:59') FROM src LIMIT 1; diff --git a/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 b/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 index d54ebfbd6fb1a..a529b107ff216 100644 --- a/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 +++ b/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 @@ -1,2 +1,2 @@ std(x) - Returns the standard deviation of a set of numbers -Synonyms: stddev_pop, stddev +Synonyms: stddev, stddev_pop diff --git a/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d b/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d index 5f674788180e8..ac3176a382547 100644 --- a/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d +++ b/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d @@ -1,2 +1,2 @@ stddev(x) - Returns the standard deviation of a set of numbers -Synonyms: stddev_pop, std +Synonyms: std, stddev_pop diff --git a/sql/hive/src/test/resources/golden/union3-0-6a8a35102de1b0b88c6721a704eb174d b/sql/hive/src/test/resources/golden/union3-0-99620f72f0282904846a596ca5b3e46c similarity index 100% rename from sql/hive/src/test/resources/golden/union3-0-6a8a35102de1b0b88c6721a704eb174d rename to sql/hive/src/test/resources/golden/union3-0-99620f72f0282904846a596ca5b3e46c diff --git a/sql/hive/src/test/resources/golden/union3-2-2a1dcd937f117f1955a169592b96d5f9 b/sql/hive/src/test/resources/golden/union3-2-90ca96ea59fd45cf0af8c020ae77c908 similarity index 100% rename from sql/hive/src/test/resources/golden/union3-2-2a1dcd937f117f1955a169592b96d5f9 rename to sql/hive/src/test/resources/golden/union3-2-90ca96ea59fd45cf0af8c020ae77c908 diff --git a/sql/hive/src/test/resources/golden/union3-3-8fc63f8edb2969a63cd4485f1867ba97 b/sql/hive/src/test/resources/golden/union3-3-72b149ccaef751bcfe55d5ca37cb5fd7 similarity index 100% rename from sql/hive/src/test/resources/golden/union3-3-8fc63f8edb2969a63cd4485f1867ba97 rename to sql/hive/src/test/resources/golden/union3-3-72b149ccaef751bcfe55d5ca37cb5fd7 diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q index 9e036c1a91d3b..e911fbf2d2c5c 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q @@ -5,6 +5,6 @@ SELECT * FROM (SELECT key, value FROM src DISTRIBUTE BY key, value)t ORDER BY ke SELECT key, value FROM src CLUSTER BY (key, value); -SELECT key, value FROM src ORDER BY (key ASC, value ASC); +SELECT key, value FROM src ORDER BY key ASC, value ASC; SELECT key, value FROM src SORT BY (key, value); SELECT * FROM (SELECT key, value FROM src DISTRIBUTE BY (key, value))t ORDER BY key, value; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q index b26a2e2799f7a..a989800cbf851 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q @@ -1,42 +1,41 @@ +-- SORT_QUERY_RESULTS explain SELECT * FROM ( SELECT 1 AS id FROM (SELECT * 
FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 2 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 3 AS id FROM (SELECT * FROM src LIMIT 1) s2 UNION ALL SELECT 4 AS id FROM (SELECT * FROM src LIMIT 1) s2 + CLUSTER BY id ) a; CREATE TABLE union_out (id int); -insert overwrite table union_out +insert overwrite table union_out SELECT * FROM ( SELECT 1 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 2 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 3 AS id FROM (SELECT * FROM src LIMIT 1) s2 UNION ALL SELECT 4 AS id FROM (SELECT * FROM src LIMIT 1) s2 + CLUSTER BY id ) a; -select * from union_out cluster by id; +select * from union_out; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala new file mode 100644 index 0000000000000..34b2edb44b033 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.net.URL + +import org.apache.spark.SparkFunSuite + +/** + * Verify that some classes load and that others are not found on the classpath. + * + * + * This is used to detect classpath and shading conflict, especially between + * Spark's required Kryo version and that which can be found in some Hive versions. + */ +class ClasspathDependenciesSuite extends SparkFunSuite { + private val classloader = this.getClass.getClassLoader + + private def assertLoads(classname: String): Unit = { + val resourceURL: URL = Option(findResource(classname)).getOrElse { + fail(s"Class $classname not found as ${resourceName(classname)}") + } + + logInfo(s"Class $classname at $resourceURL") + classloader.loadClass(classname) + } + + private def assertLoads(classes: String*): Unit = { + classes.foreach(assertLoads) + } + + private def findResource(classname: String): URL = { + val resource = resourceName(classname) + classloader.getResource(resource) + } + + private def resourceName(classname: String): String = { + classname.replace(".", "/") + ".class" + } + + private def assertClassNotFound(classname: String): Unit = { + Option(findResource(classname)).foreach { resourceURL => + fail(s"Class $classname found at $resourceURL") + } + + intercept[ClassNotFoundException] { + classloader.loadClass(classname) + } + } + + private def assertClassNotFound(classes: String*): Unit = { + classes.foreach(assertClassNotFound) + } + + private val KRYO = "com.esotericsoftware.kryo.Kryo" + + private val SPARK_HIVE = "org.apache.hive." + private val SPARK_SHADED = "org.spark-project.hive.shaded." 
+ + test("shaded Protobuf") { + assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException") + } + + test("hive-common") { + assertLoads("org.apache.hadoop.hive.conf.HiveConf") + } + + test("hive-exec") { + assertLoads("org.apache.hadoop.hive.ql.CommandNeedRetryException") + } + + private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy" + + test("unshaded kryo") { + assertLoads(KRYO, STD_INSTANTIATOR) + } + + test("Forbidden Dependencies") { + assertClassNotFound( + SPARK_HIVE + KRYO, + SPARK_SHADED + KRYO, + "org.apache.hive." + KRYO, + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, + SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, + "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR + ) + } + + test("parquet-hadoop-bundle") { + assertLoads( + "parquet.hadoop.ParquetOutputFormat", + "parquet.hadoop.ParquetInputFormat" + ) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 72b35959a491b..b8d41065d3f02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql.hive import java.io.File +import scala.collection.mutable.ArrayBuffer import scala.sys.process.{ProcessLogger, Process} +import org.scalatest.exceptions.TestFailedDueToTimeoutException + import org.apache.spark._ import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -84,23 +87,39 @@ class HiveSparkSubmitSuite // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val history = ArrayBuffer.empty[String] + val commands = Seq("./bin/spark-submit") ++ args + val commandLine = commands.mkString("'", "' '", "'") val process = Process( - Seq("./bin/spark-submit") ++ args, + commands, new File(sparkHome), "SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome ).run(ProcessLogger( // scalastyle:off println - (line: String) => { println(s"out> $line") }, - (line: String) => { println(s"err> $line") } + (line: String) => { println(s"stdout> $line"); history += s"out> $line"}, + (line: String) => { println(s"stderr> $line"); history += s"err> $line" } // scalastyle:on println )) try { - val exitCode = failAfter(180 seconds) { process.exitValue() } + val exitCode = failAfter(180.seconds) { process.exitValue() } if (exitCode != 0) { - fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + // include logs in output. Note that logging is async and may not have completed + // at the time this exception is raised + Thread.sleep(1000) + val historyLog = history.mkString("\n") + fail(s"$commandLine returned with exit code $exitCode." + + s" See the log4j logs for more detail." + + s"\n$historyLog") } + } catch { + case to: TestFailedDueToTimeoutException => + val historyLog = history.mkString("\n") + fail(s"Timeout of $commandLine" + + s" See the log4j logs for more detail." 
+ + s"\n$historyLog", to) + case t: Throwable => throw t } finally { // Ensure we still kill the process in case it timed out process.destroy() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 508695919e9a7..d33e81227db88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter import org.apache.spark.sql.execution.QueryExecutionException @@ -113,6 +114,8 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4203:random partition directory order") { sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Utils.createTempDir() + val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) + sql( s""" |CREATE TABLE table_with_partition(c1 string) @@ -145,7 +148,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { """.stripMargin) def listFolders(path: File, acc: List[String]): List[List[String]] = { val dir = path.listFiles() - val folders = dir.filter(_.isDirectory).toList + val folders = dir.filter { e => e.isDirectory && !e.getName().startsWith(stagingDir) }.toList if (folders.isEmpty) { List(acc.reverse) } else { @@ -158,7 +161,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) - assert(listFolders(tmpDir, List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) sql("DROP TABLE table_with_partition") sql("DROP TABLE tmp_table") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index bb5f1febe9ad4..f00d3754c364a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import org.apache.hadoop.hive.conf.HiveConf + import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.parquet.ParquetCompatibilityTest import org.apache.spark.sql.{Row, SQLConf, SQLContext} @@ -26,6 +28,13 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { override val sqlContext: SQLContext = TestHive + /** + * Set the staging directory (and hence path to ignore Parquet files under) + * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. 
+ */ + override val stagingDir: Option[String] = + Some(new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR)) + override protected def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index bc72b0172a467..e4fec7e2c8a2a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -54,6 +54,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { } } + // Ensure session state is initialized. + ctx.parseSql("use default") + assertAnalyzeCommand( "ANALYZE TABLE Table1 COMPUTE STATISTICS", classOf[HiveNativeCommand]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 3eb127e23d486..f0bb77092c0cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.client import java.io.File +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly @@ -48,7 +49,9 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion, + buildConf(), + ivyPath).client val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -91,6 +94,7 @@ class VersionsSuite extends SparkFunSuite with Logging { versions.foreach { version => test(s"$version: create client") { client = null + System.gc() // Hack to avoid SEGV on some JVM versions. 
client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 11a843becce69..a7cfac51cc097 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -52,14 +52,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) - sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") - // The function source code can be found at: - // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( - """ - |CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin) } override def afterAll() { @@ -69,15 +61,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("DROP TEMPORARY FUNCTION udtf_count2") } - createQueryTest("Test UDTF.close in Lateral Views", - """ - |SELECT key, cc - |FROM src LATERAL VIEW udtf_count2(value) dd AS cc - """.stripMargin, false) // false mean we have to keep the temp function in registry - - createQueryTest("Test UDTF.close in SELECT", - "SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) table", false) - test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => sql("USE default") @@ -176,8 +159,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("! operator", """ |SELECT a FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 2 AS a FROM src LIMIT 1) table + | SELECT 1 AS a UNION ALL SELECT 2 AS a) t |WHERE !(a>1) """.stripMargin) @@ -229,71 +211,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |FROM src LIMIT 1; """.stripMargin) - createQueryTest("count distinct 0 values", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 'a' AS a FROM src LIMIT 0) table - """.stripMargin) - - createQueryTest("count distinct 1 value strings", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 'a' AS a FROM src LIMIT 1 UNION ALL - | SELECT 'b' AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 2 AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values including null", - """ - |SELECT COUNT(DISTINCT a, 1) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value + null", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 1L AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values long", - 
""" - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 2L AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value + null long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - createQueryTest("null case", "SELECT case when(true) then 1 else null end FROM src LIMIT 1") @@ -674,7 +591,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql( """ |SELECT a FROM ( - | SELECT 1 AS a FROM src LIMIT 1 ) table + | SELECT 1 AS a FROM src LIMIT 1 ) t |WHERE abs(20141202) is not null """.stripMargin).collect() } @@ -987,7 +904,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { .zip(parts) .map { case (k, v) => if (v == "NULL") { - s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" + s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultStrVal}" } else { s"$k=$v" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index e83a7dc77e329..3bf8f3ac20480 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -82,16 +82,16 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { Seq.empty) createPruningTest("Column pruning - non-trivial top project with aliases", - "SELECT c1 * 2 AS double FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", - Seq("double"), + "SELECT c1 * 2 AS dbl FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", + Seq("dbl"), Seq("key"), Seq.empty) // Partition pruning tests createPruningTest("Partition pruning - non-partitioned, non-trivial project", - "SELECT key * 2 AS double FROM src WHERE value IS NOT NULL", - Seq("double"), + "SELECT key * 2 AS dbl FROM src WHERE value IS NOT NULL", + Seq("dbl"), Seq("key", "value"), Seq.empty) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c4923d83e48f3..95c1da6e9796c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -67,6 +67,25 @@ class MyDialect extends DefaultParserDialect class SQLQuerySuite extends QueryTest with SQLTestUtils { override def sqlContext: SQLContext = TestHive + test("UDTF") { + sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") @@ -264,47 +283,51 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { setConf(HiveContext.CONVERT_CTAS, true) - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src 
ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - var message = intercept[AnalysisException] { + try { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("ctas1 already exists")) - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - // Specifying database name for query can be converted to data source write path - // is not allowed right now. - message = intercept[AnalysisException] { - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert( - message.contains("Cannot specify database name in a CTAS statement"), - "When spark.sql.hive.convertCTAS is true, we should not allow " + - "database name specified.") - - sql("CREATE TABLE ctas1 stored as textfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - sql( - "CREATE TABLE ctas1 stored as sequencefile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + var message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("ctas1 already exists")) + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + // Specifying database name for query can be converted to data source write path + // is not allowed right now. 
+ message = intercept[AnalysisException] { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert( + message.contains("Cannot specify database name in a CTAS statement"), + "When spark.sql.hive.convertCTAS is true, we should not allow " + + "database name specified.") + + sql("CREATE TABLE ctas1 stored as textfile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + } finally { + setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("DROP TABLE IF EXISTS ctas1") + } } test("SQL Dialect Switching") { @@ -670,22 +693,25 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val originalConf = convertCTAS setConf(HiveContext.CONVERT_CTAS, false) - sql("CREATE TABLE explodeTest (key bigInt)") - table("explodeTest").queryExecution.analyzed match { - case metastoreRelation: MetastoreRelation => // OK - case _ => - fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") - } + try { + sql("CREATE TABLE explodeTest (key bigInt)") + table("explodeTest").queryExecution.analyzed match { + case metastoreRelation: MetastoreRelation => // OK + case _ => + fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") + } - sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") - checkAnswer( - sql("SELECT key from explodeTest"), - (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) - ) + sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") + checkAnswer( + sql("SELECT key from explodeTest"), + (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) + ) - sql("DROP TABLE explodeTest") - dropTempTable("data") - setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("DROP TABLE explodeTest") + dropTempTable("data") + } finally { + setConf(HiveContext.CONVERT_CTAS, originalConf) + } } test("sanity test for SPARK-6618") { @@ -1058,12 +1084,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { val df = TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) - df.toDF("id", "date").registerTempTable("test_SPARK8588") + df.toDF("id", "datef").registerTempTable("test_SPARK8588") checkAnswer( TestHive.sql( """ - |select id, concat(year(date)) - |from test_SPARK8588 where concat(year(date), ' year') in ('2015 year', '2014 year') + |select id, concat(year(datef)) + |from test_SPARK8588 where concat(year(datef), ' year') in ('2015 year', '2014 year') """.stripMargin), Row(1, "2014") :: Row(2, "2015") :: Nil ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index af3f468aaa5e9..deec0048d24b8 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -48,11 +48,9 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchemaWithPartition.json))) + read.options(Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index d463e8fd626f9..a46ca9a2c9706 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -31,7 +31,6 @@ import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag - // The data where the partitioning key exists only in the directory structure. case class OrcParData(intField: Int, stringField: String) @@ -40,7 +39,7 @@ case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: St // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { - val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index f56fb96c52d37..c4bc60086f6e1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -60,7 +60,14 @@ case class ParquetDataWithKeyAndComplexTypes( class ParquetMetastoreSuite extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() - + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") sql(s""" create external table partitioned_parquet ( @@ -172,14 +179,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } override def afterAll(): Unit = { - sql("DROP TABLE partitioned_parquet") - sql("DROP TABLE partitioned_parquet_with_key") - sql("DROP TABLE partitioned_parquet_with_complextypes") - sql("DROP TABLE partitioned_parquet_with_key_and_complextypes") - sql("DROP TABLE normal_parquet") - sql("DROP TABLE IF EXISTS jt") - sql("DROP TABLE IF EXISTS jt_array") - sql("DROP TABLE IF EXISTS test_parquet") + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } @@ -203,6 +210,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } test("insert into an empty parquet table") { + dropTables("test_insert_parquet") sql( """ 
|create table test_insert_parquet @@ -228,7 +236,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), Row(3, "str3") :: Row(4, "str4") :: Nil ) - sql("DROP TABLE IF EXISTS test_insert_parquet") + dropTables("test_insert_parquet") // Create it again. sql( @@ -255,118 +263,118 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql(s"SELECT intField, stringField FROM test_insert_parquet"), (1 to 10).map(i => Row(i, s"str$i")) ++ (1 to 4).map(i => Row(i, s"str$i")) ) - sql("DROP TABLE IF EXISTS test_insert_parquet") + dropTables("test_insert_parquet") } test("scan a parquet table created through a CTAS statement") { - sql( - """ - |create table test_parquet_ctas ROW FORMAT - |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - |AS select * from jt - """.stripMargin) + withTable("test_parquet_ctas") { + sql( + """ + |create table test_parquet_ctas ROW FORMAT + |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |AS select * from jt + """.stripMargin) - checkAnswer( - sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), - Seq(Row(1, "str1")) - ) + checkAnswer( + sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), + Seq(Row(1, "str1")) + ) - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: ParquetRelation) => // OK - case _ => fail( - "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation].getCanonicalName}") + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(_: ParquetRelation) => // OK + case _ => fail( + "test_parquet_ctas should be converted to " + + s"${classOf[ParquetRelation].getCanonicalName }") + } } - - sql("DROP TABLE IF EXISTS test_parquet_ctas") } test("MetastoreRelation in InsertIntoTable will be converted") { - sql( - """ - |create table test_insert_parquet - |( - | intField INT - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("test_insert_parquet") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + df.queryExecution.executedPlan match { + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. 
" + + s"However, found a ${o.toString} ") + } - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") - df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK - case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + - s"However, found a ${o.toString} ") + checkAnswer( + sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + sql("SELECT a FROM jt WHERE jt.a > 5").collect() + ) } - - checkAnswer( - sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), - sql("SELECT a FROM jt WHERE jt.a > 5").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") } test("MetastoreRelation in InsertIntoHiveTable will be converted") { - sql( - """ - |create table test_insert_parquet - |( - | int_array array - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("test_insert_parquet") { + sql( + """ + |create table test_insert_parquet + |( + | int_array array + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + df.queryExecution.executedPlan match { + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"However, found a ${o.toString} ") + } - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") - df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK - case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." 
+ - s"However, found a ${o.toString} ") + checkAnswer( + sql("SELECT int_array FROM test_insert_parquet"), + sql("SELECT a FROM jt_array").collect() + ) } - - checkAnswer( - sql("SELECT int_array FROM test_insert_parquet"), - sql("SELECT a FROM jt_array").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") } test("SPARK-6450 regression test") { - sql( - """CREATE TABLE IF NOT EXISTS ms_convert (key INT) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("ms_convert") { + sql( + """CREATE TABLE IF NOT EXISTS ms_convert (key INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // This shouldn't throw AnalysisException + val analyzed = sql( + """SELECT key FROM ms_convert + |UNION ALL + |SELECT key FROM ms_convert + """.stripMargin).queryExecution.analyzed - // This shouldn't throw AnalysisException - val analyzed = sql( - """SELECT key FROM ms_convert - |UNION ALL - |SELECT key FROM ms_convert - """.stripMargin).queryExecution.analyzed - - assertResult(2) { - analyzed.collect { - case r @ LogicalRelation(_: ParquetRelation) => r - }.size + assertResult(2) { + analyzed.collect { + case r@LogicalRelation(_: ParquetRelation) => r + }.size + } } - - sql("DROP TABLE ms_convert") } def collectParquetRelation(df: DataFrame): ParquetRelation = { @@ -379,42 +387,42 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") { - sql( - s"""CREATE TABLE nonPartitioned ( - | key INT, - | value STRING - |) - |STORED AS PARQUET - """.stripMargin) - - // First lookup fills the cache - val r1 = collectParquetRelation(table("nonPartitioned")) - // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("nonPartitioned")) - // They should be the same instance - assert(r1 eq r2) - - sql("DROP TABLE nonPartitioned") + withTable("nonPartitioned") { + sql( + s"""CREATE TABLE nonPartitioned ( + | key INT, + | value STRING + |) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectParquetRelation(table("nonPartitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("nonPartitioned")) + // They should be the same instance + assert(r1 eq r2) + } } test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") { - sql( - s"""CREATE TABLE partitioned ( - | key INT, - | value STRING - |) - |PARTITIONED BY (part INT) - |STORED AS PARQUET + withTable("partitioned") { + sql( + s"""CREATE TABLE partitioned ( + | key INT, + | value STRING + |) + |PARTITIONED BY (part INT) + |STORED AS PARQUET """.stripMargin) - // First lookup fills the cache - val r1 = collectParquetRelation(table("partitioned")) - // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("partitioned")) - // They should be the same instance - assert(r1 eq r2) - - sql("DROP TABLE partitioned") + // First lookup fills the cache + val r1 = collectParquetRelation(table("partitioned")) + // Second lookup should reuse the cache + val r2 = 
collectParquetRelation(table("partitioned")) + // They should be the same instance + assert(r1 eq r2) + } } test("Caching converted data source Parquet Relations") { @@ -430,8 +438,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } - sql("DROP TABLE IF EXISTS test_insert_parquet") - sql("DROP TABLE IF EXISTS test_parquet_partitioned_cache_test") + dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") sql( """ @@ -479,7 +486,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | intField INT, | stringField STRING |) - |PARTITIONED BY (date string) + |PARTITIONED BY (`date` string) |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' |STORED AS | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' @@ -491,7 +498,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test - |PARTITION (date='2015-04-01') + |PARTITION (`date`='2015-04-01') |select a, b from jt """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. @@ -500,7 +507,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test - |PARTITION (date='2015-04-02') + |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) @@ -510,7 +517,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( - sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), + sql("select STRINGField, `date`, intField from test_parquet_partitioned_cache_test"), sql( """ |select b, '2015-04-01', a FROM jt @@ -521,8 +528,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { invalidateTable("test_parquet_partitioned_cache_test") assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) - sql("DROP TABLE test_insert_parquet") - sql("DROP TABLE test_parquet_partitioned_cache_test") + dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } } @@ -532,6 +538,11 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { class ParquetSourceSuite extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet") sql( s""" create temporary table partitioned_parquet @@ -635,22 +646,22 @@ class ParquetSourceSuite extends ParquetPartitioningTest { StructField("a", arrayType1, nullable = true) :: Nil) assert(df.schema === expectedSchema1) - df.write.format("parquet").saveAsTable("alwaysNullable") + withTable("alwaysNullable") { + df.write.format("parquet").saveAsTable("alwaysNullable") - val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) - val arrayType2 = ArrayType(IntegerType, containsNull = true) - val expectedSchema2 = - StructType( - StructField("m", mapType2, nullable = true) :: - StructField("a", arrayType2, nullable = true) :: Nil) + val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) + val arrayType2 = ArrayType(IntegerType, containsNull = true) + val expectedSchema2 = + StructType( + StructField("m", mapType2, nullable = true) :: + StructField("a", arrayType2, nullable = 
true) :: Nil) - assert(table("alwaysNullable").schema === expectedSchema2) - - checkAnswer( - sql("SELECT m, a FROM alwaysNullable"), - Row(Map(2 -> 3), Seq(4, 5, 6))) + assert(table("alwaysNullable").schema === expectedSchema2) - sql("DROP TABLE alwaysNullable") + checkAnswer( + sql("SELECT m, a FROM alwaysNullable"), + Row(Map(2 -> 3), Seq(4, 5, 6))) + } } test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { @@ -738,6 +749,16 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with partitionedTableDirWithKeyAndComplexTypes.delete() } + /** + * Drop named tables if they exist + * @param tableNames tables to drop + */ + def dropTables(tableNames: String*): Unit = { + tableNames.foreach { name => + sql(s"DROP TABLE IF EXISTS $name") + } + } + Seq( "partitioned_parquet", "partitioned_parquet_with_key", diff --git a/yarn/pom.xml b/yarn/pom.xml index 2aeed98285aa8..49360c48256ea 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -30,7 +30,6 @@ Spark Project YARN yarn - 1.9 @@ -125,25 +124,16 @@ com.sun.jersey jersey-core - ${jersey.version} test com.sun.jersey jersey-json - ${jersey.version} test - - - stax - stax-api - - com.sun.jersey jersey-server - ${jersey.version} test diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 547863d9a0739..eb6e1fd370620 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -384,19 +384,29 @@ private object YarnClusterDriver extends Logging with Matchers { } -private object YarnClasspathTest { +private object YarnClasspathTest extends Logging { + + var exitCode = 0 + + def error(m: String, ex: Throwable = null): Unit = { + logError(m, ex) + // scalastyle:off println + System.out.println(m) + if (ex != null) { + ex.printStackTrace(System.out) + } + // scalastyle:on println + } def main(args: Array[String]): Unit = { if (args.length != 2) { - // scalastyle:off println - System.err.println( + error( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClasspathTest [driver result file] [executor result file] """.stripMargin) // scalastyle:on println - System.exit(1) } readResource(args(0)) @@ -406,6 +416,7 @@ private object YarnClasspathTest { } finally { sc.stop() } + System.exit(exitCode) } private def readResource(resultPath: String): Unit = { @@ -415,6 +426,11 @@ private object YarnClasspathTest { val resource = ccl.getResourceAsStream("test.resource") val bytes = ByteStreams.toByteArray(resource) result = new String(bytes, 0, bytes.length, UTF_8) + } catch { + case t: Throwable => + error(s"loading test.resource to $resultPath", t) + // set the exit code if not yet set + exitCode = 2 } finally { Files.write(result, new File(resultPath), UTF_8) } From 13675c742a71cbdc8324701c3694775ce1dd5c62 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 3 Aug 2015 16:44:25 -0700 Subject: [PATCH 111/340] [SPARK-8874] [ML] Add missing methods in Word2Vec Add missing methods 1. getVectors 2. 
findSynonyms to W2Vec scala and python API mengxr Author: MechCoder Closes #7263 from MechCoder/missing_methods_w2vec and squashes the following commits: 149d5ca [MechCoder] minor doc 69d91b7 [MechCoder] [SPARK-8874] [ML] Add missing methods in Word2Vec --- .../apache/spark/ml/feature/Word2Vec.scala | 38 +++++++++++- .../spark/ml/feature/Word2VecSuite.scala | 62 +++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 6ea6590956300..b4f46cef798dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -18,15 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types._ /** @@ -146,6 +148,40 @@ class Word2VecModel private[ml] ( wordVectors: feature.Word2VecModel) extends Model[Word2VecModel] with Word2VecBase { + + /** + * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and + * and the vector the DenseVector that it is mapped to. + */ + val getVectors: DataFrame = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble))) + sc.parallelize(wordVec.toSeq).toDF("word", "vector") + } + + /** + * Find "num" number of words closest in similarity to the given word. + * Returns a dataframe with the words and the cosine similarities between the + * synonyms and the given word. + */ + def findSynonyms(word: String, num: Int): DataFrame = { + findSynonyms(wordVectors.transform(word), num) + } + + /** + * Find "num" number of words closest to similarity to the given vector representation + * of the word. Returns a dataframe with the words and the cosine similarities between the + * synonyms and the given word vector. 
+ */ + def findSynonyms(word: Vector, num: Int): DataFrame = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + } + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index aa6ce533fd885..adcda0e623b25 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -67,5 +67,67 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") } } + + test("getVectors") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + + val codes = Map( + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) + ) + val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) } + + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val realVectors = model.getVectors.sort("word").select("vector").map { + case Row(v: Vector) => v + }.collect() + + realVectors.zip(expectedVectors).foreach { + case (real, expected) => + assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.") + } + } + + test("findSynonyms") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644) + val (synonyms, similarity) = model.findSynonyms("a", 2).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + assert(synonyms.toArray === Array("b", "c")) + expectedSimilarity.zip(similarity).map { + case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) + } + + } } From 7abaaad5b169520fbf7299808b2bafde089a16a2 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Aug 2015 17:00:59 -0700 Subject: [PATCH 112/340] Add a prerequisites section for building docs This puts all the install commands that need to be run in one section instead of being spread over many paragraphs cc rxin Author: Shivaram Venkataraman Closes #7912 from shivaram/docs-setup-readme and squashes the following commits: cf7a204 [Shivaram Venkataraman] Add a prerequisites section for building docs --- docs/README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/README.md b/docs/README.md index d7652e921f7df..50209896f986c 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,6 +8,16 @@ Read on to learn more about viewing documentation in plain text (i.e., markdown) documentation 
yourself. Why build it yourself? So that you have the docs that corresponds to whichever version of Spark you currently have checked out of revision control. +## Prerequisites +The Spark documenation build uses a number of tools to build HTML docs and API docs in Scala, Python +and R. To get started you can run the following commands + + $ sudo gem install jekyll + $ sudo gem install jekyll-redirect-from + $ sudo pip install Pygments + $ Rscript -e 'install.packages(c("knitr", "devtools"), repos="http://cran.stat.ucla.edu/")' + + ## Generating the Documentation HTML We include the Spark documentation as part of the source (as opposed to using a hosted wiki, such as From b79b4f5f2251ed7efeec1f4b26e45a8ea6b85a6a Mon Sep 17 00:00:00 2001 From: Matthew Brandyberry Date: Mon, 3 Aug 2015 17:36:56 -0700 Subject: [PATCH 113/340] [SPARK-9483] Fix UTF8String.getPrefix for big-endian. Previous code assumed little-endian. Author: Matthew Brandyberry Closes #7902 from mtbrandy/SPARK-9483 and squashes the following commits: ec31df8 [Matthew Brandyberry] [SPARK-9483] Changes from review comments. 17d54c6 [Matthew Brandyberry] [SPARK-9483] Fix UTF8String.getPrefix for big-endian. --- .../apache/spark/unsafe/types/UTF8String.java | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f6c9b87778f8f..d80bd57bd2048 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -20,6 +20,7 @@ import javax.annotation.Nonnull; import java.io.Serializable; import java.io.UnsupportedEncodingException; +import java.nio.ByteOrder; import java.util.Arrays; import org.apache.spark.unsafe.PlatformDependent; @@ -53,6 +54,8 @@ public final class UTF8String implements Comparable, Serializable { 5, 5, 5, 5, 6, 6}; + private static ByteOrder byteOrder = ByteOrder.nativeOrder(); + public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); /** @@ -175,18 +178,35 @@ public long getPrefix() { // If size is greater than 4, assume we have at least 8 bytes of data to fetch. // After getting the data, we use a mask to mask out data that is not part of the string. 
long p; - if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); - } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); - p = p & ((1L << numBytes * 8) - 1); - } else if (numBytes > 0) { - p = (long) PlatformDependent.UNSAFE.getInt(base, offset); - p = p & ((1L << numBytes * 8) - 1); + long mask = 0; + if (byteOrder == ByteOrder.LITTLE_ENDIAN) { + if (numBytes >= 8) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + } else if (numBytes > 4) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } + p = java.lang.Long.reverseBytes(p); } else { - p = 0; + // byteOrder == ByteOrder.BIG_ENDIAN + if (numBytes >= 8) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + } else if (numBytes > 4) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = ((long) PlatformDependent.UNSAFE.getInt(base, offset)) << 32; + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } } - p = java.lang.Long.reverseBytes(p); + p &= ~mask; return p; } From 1633d0a2612d94151f620c919425026150e69ae1 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 3 Aug 2015 17:42:03 -0700 Subject: [PATCH 114/340] [SPARK-9263] Added flags to exclude dependencies when using --packages While the functionality is there to exclude packages, there are no flags that allow users to exclude dependencies, in case of dependency conflicts. We should provide users with a flag to add dependency exclusions in case the packages are not resolved properly (or not available due to licensing). The flag I added was --packages-exclude, but I'm open on renaming it. I also added property flags in case people would like to use a conf file to provide dependencies, which is possible if there is a long list of dependencies or exclusions. 
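A hedged usage sketch of the feature described above (the package and exclusion coordinates below are made up, not taken from this patch). Note that the flag as it appears in the diff below is spelled `--exclude-packages` and takes `groupId:artifactId` pairs, while `--packages` takes `groupId:artifactId:version`:

```
# Hypothetical coordinates: resolve a package but drop one conflicting transitive dependency.
./bin/spark-submit \
  --packages com.example:mylib:0.1 \
  --exclude-packages com.example.dep:conflicting-dep \
  --class com.example.Main \
  app.jar
```

Per the SparkSubmitArguments change in this patch, the same values can also be supplied through the `spark.jars.packages` and `spark.jars.excludes` properties in a properties file.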
cc andrewor14 vanzin pwendell Author: Burak Yavuz Closes #7599 from brkyvz/packages-exclusions and squashes the following commits: 636f410 [Burak Yavuz] addressed nits 6e54ede [Burak Yavuz] is this the culprit b5e508e [Burak Yavuz] Merge branch 'master' of github.com:apache/spark into packages-exclusions 154f5db [Burak Yavuz] addressed initial comments 1536d7a [Burak Yavuz] Added flags to exclude packages using --packages-exclude --- .../org/apache/spark/deploy/SparkSubmit.scala | 29 +++++++++--------- .../spark/deploy/SparkSubmitArguments.scala | 11 +++++++ .../spark/deploy/SparkSubmitUtilsSuite.scala | 30 +++++++++++++++++++ .../launcher/SparkSubmitOptionParser.java | 2 ++ 4 files changed, 57 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0b39ee8fe3ba0..31185c8e77def 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -24,6 +24,7 @@ import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path import org.apache.hadoop.security.UserGroupInformation import org.apache.ivy.Ivy @@ -37,6 +38,7 @@ import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} + import org.apache.spark.api.r.RUtils import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ @@ -275,21 +277,18 @@ object SparkSubmit { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code - val resolvedMavenCoordinates = - SparkSubmitUtils.resolveMavenCoordinates( - args.packages, Option(args.repositories), Option(args.ivyRepoPath)) - if (!resolvedMavenCoordinates.trim.isEmpty) { - if (args.jars == null || args.jars.trim.isEmpty) { - args.jars = resolvedMavenCoordinates + val exclusions: Seq[String] = + if (!StringUtils.isBlank(args.packagesExclusions)) { + args.packagesExclusions.split(",") } else { - args.jars += s",$resolvedMavenCoordinates" + Nil } + val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, + Some(args.repositories), Some(args.ivyRepoPath), exclusions = exclusions) + if (!StringUtils.isBlank(resolvedMavenCoordinates)) { + args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) if (args.isPython) { - if (args.pyFiles == null || args.pyFiles.trim.isEmpty) { - args.pyFiles = resolvedMavenCoordinates - } else { - args.pyFiles += s",$resolvedMavenCoordinates" - } + args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) } } @@ -736,7 +735,7 @@ object SparkSubmit { * no files, into a single comma-separated string. 
*/ private def mergeFileLists(lists: String*): String = { - val merged = lists.filter(_ != null) + val merged = lists.filterNot(StringUtils.isBlank) .flatMap(_.split(",")) .mkString(",") if (merged == "") null else merged @@ -938,7 +937,7 @@ private[spark] object SparkSubmitUtils { // are supplied to spark-submit val alternateIvyCache = ivyPath.getOrElse("") val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { + if (alternateIvyCache == null || alternateIvyCache.trim.isEmpty) { new File(ivySettings.getDefaultIvyUserDir, "jars") } else { ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) @@ -1010,7 +1009,7 @@ private[spark] object SparkSubmitUtils { } } - private def createExclusion( + private[deploy] def createExclusion( coords: String, ivySettings: IvySettings, ivyConfName: String): ExcludeRule = { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index b3710073e330c..44852ce4e84ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null @@ -172,6 +173,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull + packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull + packagesExclusions = Option(packagesExclusions) + .orElse(sparkProperties.get("spark.jars.excludes")).orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) @@ -299,6 +303,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | childArgs [${childArgs.mkString(" ")}] | jars $jars | packages $packages + | packagesExclusions $packagesExclusions | repositories $repositories | verbose $verbose | @@ -391,6 +396,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case PACKAGES => packages = value + case PACKAGES_EXCLUDE => + packagesExclusions = value + case REPOSITORIES => repositories = value @@ -482,6 +490,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | maven repo, then maven central and any additional remote | repositories given by --repositories. The format for the | coordinates should be groupId:artifactId:version. + | --exclude-packages Comma-separated list of groupId:artifactId, to exclude while + | resolving the dependencies provided in --packages to avoid + | dependency conflicts. | --repositories Comma-separated list of additional remote repositories to | search for the maven coordinates given with --packages. 
| --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 01ece1a10f46d..63c346c1b8908 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -95,6 +95,25 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(md.getDependencies.length === 2) } + test("excludes works correctly") { + val md = SparkSubmitUtils.getModuleDescriptor + val excludes = Seq("a:b", "c:d") + excludes.foreach { e => + md.addExcludeRule(SparkSubmitUtils.createExclusion(e + ":*", new IvySettings, "default")) + } + val rules = md.getAllExcludeRules + assert(rules.length === 2) + val rule1 = rules(0).getId.getModuleId + assert(rule1.getOrganisation === "a") + assert(rule1.getName === "b") + val rule2 = rules(1).getId.getModuleId + assert(rule2.getOrganisation === "c") + assert(rule2.getName === "d") + intercept[IllegalArgumentException] { + SparkSubmitUtils.createExclusion("e:f:g:h", new IvySettings, "default") + } + } + test("ivy path works correctly") { val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") @@ -168,4 +187,15 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") } } + + test("exclude dependencies end to end") { + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = "my.great.dep:mydep:0.5" + IvyTestUtils.withRepository(main, Some(dep), None) { repo => + val files = SparkSubmitUtils.resolveMavenCoordinates(main.toString, + Some(repo), None, Seq("my.great.dep:mydep"), isTest = true) + assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") + assert(files.indexOf("my.great.dep") < 0, "Returned excluded artifact") + } + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index b88bba883ac65..5779eb3fc0f78 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -51,6 +51,7 @@ class SparkSubmitOptionParser { protected final String MASTER = "--master"; protected final String NAME = "--name"; protected final String PACKAGES = "--packages"; + protected final String PACKAGES_EXCLUDE = "--exclude-packages"; protected final String PROPERTIES_FILE = "--properties-file"; protected final String PROXY_USER = "--proxy-user"; protected final String PY_FILES = "--py-files"; @@ -105,6 +106,7 @@ class SparkSubmitOptionParser { { NAME }, { NUM_EXECUTORS }, { PACKAGES }, + { PACKAGES_EXCLUDE }, { PRINCIPAL }, { PROPERTIES_FILE }, { PROXY_USER }, From 3b0e44490aebfba30afc147e4a34a63439d985c6 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Mon, 3 Aug 2015 18:20:40 -0700 Subject: [PATCH 115/340] [SPARK-8416] highlight and topping the executor threads in thread dumping page https://issues.apache.org/jira/browse/SPARK-8416 To facilitate debugging, I made this patch with three changes: * render the executor-thread and non executor-thread entries with different background colors * put the executor threads on the top of the list * sort the threads alphabetically 
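A minimal, self-contained sketch of the ordering described in the bullets above, operating on plain strings rather than the page's thread-trace objects; the thread names are illustrative only:

```scala
// Illustrative thread names, not actual dump output.
val names = Seq("qtp-monitor-1", "Executor task launch worker-0", "dispatcher-event-loop-3")

val sorted = names.sortWith { (a, b) =>
  val aExec = if (a.contains("Executor task launch")) 1 else 0
  val bExec = if (b.contains("Executor task launch")) 1 else 0
  // Executor task-launch threads sort first; ties fall back to case-insensitive alphabetical order.
  if (aExec == bExec) a.toLowerCase < b.toLowerCase else aExec > bExec
}
// sorted == Seq("Executor task launch worker-0", "dispatcher-event-loop-3", "qtp-monitor-1")
```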
Author: CodingCat Closes #7808 from CodingCat/SPARK-8416 and squashes the following commits: 34fc708 [CodingCat] fix className d7b79dd [CodingCat] lowercase threadName d032882 [CodingCat] sort alphabetically and change the css class name f0513b1 [CodingCat] change the color & group threads by name 2da6e06 [CodingCat] small fix 3fc9f36 [CodingCat] define classes in webui.css 8ee125e [CodingCat] highlight and put on top the executor threads in thread dumping page --- .../org/apache/spark/ui/static/webui.css | 8 +++++++ .../ui/exec/ExecutorThreadDumpPage.scala | 24 ++++++++++++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 648cd1b104802..04f3070d25b4a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -224,3 +224,11 @@ span.additional-metric-title { a.expandbutton { cursor: pointer; } + +.executor-thread { + background: #E6E6E6; +} + +.non-executor-thread { + background: #FAFAFA; +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index f0ae95bb8c812..b0a2cb4aa4d4b 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -49,11 +49,29 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) val content = maybeThreadDump.map { threadDump => - val dumpRows = threadDump.map { thread => + val dumpRows = threadDump.sortWith { + case (threadTrace1, threadTrace2) => { + val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 + val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 + if (v1 == v2) { + threadTrace1.threadName.toLowerCase < threadTrace2.threadName.toLowerCase + } else { + v1 > v2 + } + } + }.map { thread => + val threadName = thread.threadName + val className = "accordion-heading " + { + if (threadName.contains("Executor task launch")) { + "executor-thread" + } else { + "non-executor-thread" + } + }
    + @@ -594,6 +597,9 @@ rowMat = mat.toRowMatrix() # Convert to an IndexedRowMatrix. indexedRowMat = mat.toIndexedRowMatrix() + +# Convert to a BlockMatrix. +blockMat = mat.toBlockMatrix() {% endhighlight %}
    @@ -661,4 +667,39 @@ matA.validate(); BlockMatrix ata = matA.transpose().multiply(matA); {% endhighlight %}
    + +A [`BlockMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.BlockMatrix) +can be created from an `RDD` of sub-matrix blocks, where a sub-matrix block is a +`((blockRowIndex, blockColIndex), sub-matrix)` tuple. + +{% highlight python %} +from pyspark.mllib.linalg import Matrices +from pyspark.mllib.linalg.distributed import BlockMatrix + +# Create an RDD of sub-matrix blocks. +blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + +# Create a BlockMatrix from an RDD of sub-matrix blocks. +mat = BlockMatrix(blocks, 3, 2) + +# Get its size. +m = mat.numRows() # 6 +n = mat.numCols() # 2 + +# Get the blocks as an RDD of sub-matrix blocks. +blocksRDD = mat.blocks + +# Convert to a LocalMatrix. +localMat = mat.toLocalMatrix() + +# Convert to an IndexedRowMatrix. +indexedRowMat = mat.toIndexedRowMatrix() + +# Convert to a CoordinateMatrix. +coordinateMat = mat.toCoordinateMatrix() +{% endhighlight %} +
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index d2b3fae381acb..f585aacd452e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1128,6 +1128,21 @@ private[python] class PythonMLLibAPI extends Serializable { new CoordinateMatrix(entries, numRows, numCols) } + /** + * Wrapper around BlockMatrix constructor. + */ + def createBlockMatrix(blocks: DataFrame, rowsPerBlock: Int, colsPerBlock: Int, + numRows: Long, numCols: Long): BlockMatrix = { + // We use DataFrames for serialization of sub-matrix blocks from + // Python, so map each Row in the DataFrame back to a + // ((blockRowIndex, blockColIndex), sub-matrix) tuple. + val blockTuples = blocks.map { + case Row(Row(blockRowIndex: Long, blockColIndex: Long), subMatrix: Matrix) => + ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix) + } + new BlockMatrix(blockTuples, rowsPerBlock, colsPerBlock, numRows, numCols) + } + /** * Return the rows of an IndexedRowMatrix. */ @@ -1147,6 +1162,16 @@ private[python] class PythonMLLibAPI extends Serializable { val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext) sqlContext.createDataFrame(coordinateMatrix.entries) } + + /** + * Return the sub-matrix blocks of a BlockMatrix. + */ + def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = { + // We use DataFrames for serialization of sub-matrix blocks to + // Python, so return a DataFrame. + val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext) + sqlContext.createDataFrame(blockMatrix.blocks) + } } /** diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 666d833019562..aec407de90aa3 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -28,11 +28,12 @@ from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector +from pyspark.mllib.linalg import _convert_to_vector, Matrix __all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow', - 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix'] + 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix', + 'BlockMatrix'] class DistributedMatrix(object): @@ -322,6 +323,35 @@ def toCoordinateMatrix(self): java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix") return CoordinateMatrix(java_coordinate_matrix) + def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): + """ + Convert this matrix to a BlockMatrix. + + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(6, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows).toBlockMatrix() + + >>> # This IndexedRowMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, and the ensuing + >>> # BlockMatrix will have 7 rows as well. 
+ >>> print(mat.numRows()) + 7 + + >>> print(mat.numCols()) + 3 + """ + java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix", + rowsPerBlock, + colsPerBlock) + return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + class MatrixEntry(object): """ @@ -476,19 +506,18 @@ def toRowMatrix(self): >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toRowMatrix() >>> # This CoordinateMatrix will have 7 effective rows, due to >>> # the highest row index being 6, but the ensuing RowMatrix >>> # will only have 2 rows since there are only entries on 2 >>> # unique rows. - >>> mat = CoordinateMatrix(entries).toRowMatrix() >>> print(mat.numRows()) 2 >>> # This CoordinateMatrix will have 5 columns, due to the >>> # highest column index being 4, and the ensuing RowMatrix >>> # will have 5 columns as well. - >>> mat = CoordinateMatrix(entries).toRowMatrix() >>> print(mat.numCols()) 5 """ @@ -501,33 +530,320 @@ def toIndexedRowMatrix(self): >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix() >>> # This CoordinateMatrix will have 7 effective rows, due to >>> # the highest row index being 6, and the ensuing >>> # IndexedRowMatrix will have 7 rows as well. - >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix() >>> print(mat.numRows()) 7 >>> # This CoordinateMatrix will have 5 columns, due to the >>> # highest column index being 4, and the ensuing >>> # IndexedRowMatrix will have 5 columns as well. - >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix() >>> print(mat.numCols()) 5 """ java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix") return IndexedRowMatrix(java_indexed_row_matrix) + def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): + """ + Convert this matrix to a BlockMatrix. + + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toBlockMatrix() + + >>> # This CoordinateMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, and the ensuing + >>> # BlockMatrix will have 7 rows as well. + >>> print(mat.numRows()) + 7 + + >>> # This CoordinateMatrix will have 5 columns, due to the + >>> # highest column index being 4, and the ensuing + >>> # BlockMatrix will have 5 columns as well. + >>> print(mat.numCols()) + 5 + """ + java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix", + rowsPerBlock, + colsPerBlock) + return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + + +def _convert_to_matrix_block_tuple(block): + if (isinstance(block, tuple) and len(block) == 2 + and isinstance(block[0], tuple) and len(block[0]) == 2 + and isinstance(block[1], Matrix)): + blockRowIndex = int(block[0][0]) + blockColIndex = int(block[0][1]) + subMatrix = block[1] + return ((blockRowIndex, blockColIndex), subMatrix) + else: + raise TypeError("Cannot convert type %s into a sub-matrix block tuple" % type(block)) + + +class BlockMatrix(DistributedMatrix): + """ + .. note:: Experimental + + Represents a distributed matrix in blocks of local matrices. 
+ + :param blocks: An RDD of sub-matrix blocks + ((blockRowIndex, blockColIndex), sub-matrix) that + form this distributed matrix. If multiple blocks + with the same index exist, the results for + operations like add and multiply will be + unpredictable. + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + :param numRows: Number of rows of this matrix. If the supplied + value is less than or equal to zero, the number + of rows will be calculated when `numRows` is + invoked. + :param numCols: Number of columns of this matrix. If the supplied + value is less than or equal to zero, the number + of columns will be calculated when `numCols` is + invoked. + """ + def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0): + """ + Note: This docstring is not shown publicly. + + Create a wrapper over a Java BlockMatrix. + + Publicly, we require that `blocks` be an RDD. However, for + internal usage, `blocks` can also be a Java BlockMatrix + object, in which case we can wrap it directly. This + assists in clean matrix conversions. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + + >>> mat_diff = BlockMatrix(blocks, 3, 2) + >>> (mat_diff._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + False + + >>> mat_same = BlockMatrix(mat._java_matrix_wrapper._java_model, 3, 2) + >>> (mat_same._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + True + """ + if isinstance(blocks, RDD): + blocks = blocks.map(_convert_to_matrix_block_tuple) + # We use DataFrames for serialization of sub-matrix blocks + # from Python, so first convert the RDD to a DataFrame on + # this side. This will convert each sub-matrix block + # tuple to a Row containing the 'blockRowIndex', + # 'blockColIndex', and 'subMatrix' values, which can + # each be easily serialized. We will convert back to + # ((blockRowIndex, blockColIndex), sub-matrix) tuples on + # the Scala side. + java_matrix = callMLlibFunc("createBlockMatrix", blocks.toDF(), + int(rowsPerBlock), int(colsPerBlock), + long(numRows), long(numCols)) + elif (isinstance(blocks, JavaObject) + and blocks.getClass().getSimpleName() == "BlockMatrix"): + java_matrix = blocks + else: + raise TypeError("blocks should be an RDD of sub-matrix blocks as " + "((int, int), matrix) tuples, got %s" % type(blocks)) + + self._java_matrix_wrapper = JavaModelWrapper(java_matrix) + + @property + def blocks(self): + """ + The RDD of sub-matrix blocks + ((blockRowIndex, blockColIndex), sub-matrix) that form this + distributed matrix. + + >>> mat = BlockMatrix( + ... sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]), 3, 2) + >>> blocks = mat.blocks + >>> blocks.first() + ((0, 0), DenseMatrix(3, 2, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 0)) + + """ + # We use DataFrames for serialization of sub-matrix blocks + # from Java, so we first convert the RDD of blocks to a + # DataFrame on the Scala/Java side. Then we map each Row in + # the DataFrame back to a sub-matrix block on this side. 
+ blocks_df = callMLlibFunc("getMatrixBlocks", self._java_matrix_wrapper._java_model) + blocks = blocks_df.map(lambda row: ((row[0][0], row[0][1]), row[1])) + return blocks + + @property + def rowsPerBlock(self): + """ + Number of rows that make up each block. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.rowsPerBlock + 3 + """ + return self._java_matrix_wrapper.call("rowsPerBlock") + + @property + def colsPerBlock(self): + """ + Number of columns that make up each block. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.colsPerBlock + 2 + """ + return self._java_matrix_wrapper.call("colsPerBlock") + + @property + def numRowBlocks(self): + """ + Number of rows of blocks in the BlockMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.numRowBlocks + 2 + """ + return self._java_matrix_wrapper.call("numRowBlocks") + + @property + def numColBlocks(self): + """ + Number of columns of blocks in the BlockMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.numColBlocks + 1 + """ + return self._java_matrix_wrapper.call("numColBlocks") + + def numRows(self): + """ + Get or compute the number of rows. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + + >>> mat = BlockMatrix(blocks, 3, 2) + >>> print(mat.numRows()) + 6 + + >>> mat = BlockMatrix(blocks, 3, 2, 7, 6) + >>> print(mat.numRows()) + 7 + """ + return self._java_matrix_wrapper.call("numRows") + + def numCols(self): + """ + Get or compute the number of cols. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + + >>> mat = BlockMatrix(blocks, 3, 2) + >>> print(mat.numCols()) + 2 + + >>> mat = BlockMatrix(blocks, 3, 2, 7, 6) + >>> print(mat.numCols()) + 6 + """ + return self._java_matrix_wrapper.call("numCols") + + def toLocalMatrix(self): + """ + Collect the distributed matrix on the driver as a DenseMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2).toLocalMatrix() + + >>> # This BlockMatrix will have 6 effective rows, due to + >>> # having two sub-matrix blocks stacked, each with 3 rows. + >>> # The ensuing DenseMatrix will also have 6 rows. + >>> print(mat.numRows) + 6 + + >>> # This BlockMatrix will have 2 effective columns, due to + >>> # having two sub-matrix blocks stacked, each with 2 + >>> # columns. The ensuing DenseMatrix will also have 2 columns. + >>> print(mat.numCols) + 2 + """ + return self._java_matrix_wrapper.call("toLocalMatrix") + + def toIndexedRowMatrix(self): + """ + Convert this matrix to an IndexedRowMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... 
((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2).toIndexedRowMatrix() + + >>> # This BlockMatrix will have 6 effective rows, due to + >>> # having two sub-matrix blocks stacked, each with 3 rows. + >>> # The ensuing IndexedRowMatrix will also have 6 rows. + >>> print(mat.numRows()) + 6 + + >>> # This BlockMatrix will have 2 effective columns, due to + >>> # having two sub-matrix blocks stacked, each with 2 columns. + >>> # The ensuing IndexedRowMatrix will also have 2 columns. + >>> print(mat.numCols()) + 2 + """ + java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix") + return IndexedRowMatrix(java_indexed_row_matrix) + + def toCoordinateMatrix(self): + """ + Convert this matrix to a CoordinateMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(1, 2, [1, 2])), + ... ((1, 0), Matrices.dense(1, 2, [7, 8]))]) + >>> mat = BlockMatrix(blocks, 1, 2).toCoordinateMatrix() + >>> mat.entries.take(3) + [MatrixEntry(0, 0, 1.0), MatrixEntry(0, 1, 2.0), MatrixEntry(1, 0, 7.0)] + """ + java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix") + return CoordinateMatrix(java_coordinate_matrix) + def _test(): import doctest from pyspark import SparkContext from pyspark.sql import SQLContext + from pyspark.mllib.linalg import Matrices import pyspark.mllib.linalg.distributed globs = pyspark.mllib.linalg.distributed.__dict__.copy() globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) globs['sqlContext'] = SQLContext(globs['sc']) + globs['Matrices'] = Matrices (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: From 23d982204bb9ef74d3b788a32ce6608116968719 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 5 Aug 2015 09:01:45 -0700 Subject: [PATCH 163/340] [SPARK-9141] [SQL] Remove project collapsing from DataFrame API Currently we collapse successive projections that are added by `withColumn`. However, this optimization violates the constraint that adding nodes to a plan will never change its analyzed form and thus breaks caching. Instead of doing early optimization, in this PR I just fix some low-hanging slowness in the analyzer. In particular, I add a mechanism for skipping already analyzed subplans, `resolveOperators` and `resolveExpression`. Since trees are generally immutable after construction, it's safe to annotate a plan as already analyzed as any transformation will create a new tree with this bit no longer set. Together these result in a faster analyzer than before, even with added timing instrumentation. 
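A rough sketch of the skip-if-analyzed transform described above, using an invented, simplified `Plan` type rather than the actual Catalyst classes (names and method bodies are for illustration only):

```scala
// Sketch only: a tree transform that leaves already-analyzed sub-trees untouched.
// Because plan trees are immutable once built, any rewrite produces fresh nodes whose
// analyzed flag starts out cleared, so the flag cannot go stale after a transformation.
abstract class Plan {
  private var _analyzed = false
  def setAnalyzed(): Unit = { _analyzed = true }
  def analyzed: Boolean = _analyzed

  def children: Seq[Plan]
  def withNewChildren(newChildren: Seq[Plan]): Plan

  def resolveOperators(rule: PartialFunction[Plan, Plan]): Plan = {
    if (analyzed) {
      this // already analyzed: skip this whole sub-tree
    } else {
      // Bottom-up: rewrite the children first, then apply the rule to this node.
      val afterChildren = withNewChildren(children.map(_.resolveOperators(rule)))
      rule.applyOrElse(afterChildren, identity[Plan])
    }
  }
}
```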
``` Original Code [info] 3430ms [info] 2205ms [info] 1973ms [info] 1982ms [info] 1916ms Without Project Collapsing in DataFrame [info] 44610ms [info] 45977ms [info] 46423ms [info] 46306ms [info] 54723ms With analyzer optimizations [info] 6394ms [info] 4630ms [info] 4388ms [info] 4093ms [info] 4113ms With resolveOperators [info] 2495ms [info] 1380ms [info] 1685ms [info] 1414ms [info] 1240ms ``` Author: Michael Armbrust Closes #7920 from marmbrus/withColumnCache and squashes the following commits: 2145031 [Michael Armbrust] fix hive udfs tests 5a5a525 [Michael Armbrust] remove wrong comment 7a507d5 [Michael Armbrust] style b59d710 [Michael Armbrust] revert small change 1fa5949 [Michael Armbrust] move logic into LogicalPlan, add tests 0e2cb43 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into withColumnCache c926e24 [Michael Armbrust] naming e593a2d [Michael Armbrust] style f5a929e [Michael Armbrust] [SPARK-9141][SQL] Remove project collapsing from DataFrame API 38b1c83 [Michael Armbrust] WIP --- .../sql/catalyst/analysis/Analyzer.scala | 28 ++++---- .../sql/catalyst/analysis/CheckAnalysis.scala | 3 + .../catalyst/analysis/HiveTypeCoercion.scala | 30 ++++---- .../spark/sql/catalyst/plans/QueryPlan.scala | 5 +- .../catalyst/plans/logical/LogicalPlan.scala | 51 ++++++++++++- .../sql/catalyst/rules/RuleExecutor.scala | 22 ++++++ .../spark/sql/catalyst/trees/TreeNode.scala | 64 ++++------------- .../sql/catalyst/plans/LogicalPlanSuite.scala | 72 +++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 6 +- .../spark/sql/execution/SparkPlan.scala | 2 +- .../spark/sql/execution/pythonUDFs.scala | 2 +- .../apache/spark/sql/CachedTableSuite.scala | 20 ++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 12 ---- .../execution/HiveCompatibilitySuite.scala | 5 ++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 5 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 3 +- .../sql/hive/execution/HiveUDFSuite.scala | 8 +-- .../sql/hive/execution/SQLQuerySuite.scala | 2 + 18 files changed, 234 insertions(+), 106 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ca17f3e3d06ff..6de31f42dd30c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -90,7 +90,7 @@ class Analyzer( */ object CTESubstitution extends Rule[LogicalPlan] { // TODO allow subquery to define CTE - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case With(child, relations) => substituteCTE(child, relations) case other => other } @@ -116,7 +116,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. 
case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -140,7 +140,7 @@ class Analyzer( object ResolveAliases extends Rule[LogicalPlan] { private def assignAliases(exprs: Seq[NamedExpression]) = { // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need - // to transform down the whole tree. + // to resolveOperator down the whole tree. exprs.zipWithIndex.map { case (u @ UnresolvedAlias(child), i) => child match { @@ -156,7 +156,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) => Aggregate(groups, assignAliases(aggs), child) @@ -198,7 +198,7 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case a if !a.childrenResolved => a // be sure all of the children are resolved. case a: Cube => GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) @@ -261,7 +261,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => i.copy(table = EliminateSubQueries(getTable(u))) case u: UnresolvedRelation => @@ -274,7 +274,7 @@ class Analyzer( * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -444,7 +444,7 @@ class Analyzer( * remove these attributes after sorting. */ object ResolveSortReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) @@ -519,7 +519,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case q: LogicalPlan => q transformExpressions { case u @ UnresolvedFunction(name, children, isDistinct) => @@ -551,7 +551,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -571,7 +571,7 @@ class Analyzer( * aggregates and then projects them away above the filter. 
*/ object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) if aggregate.resolved && containsAggregate(havingCondition) => @@ -601,7 +601,7 @@ class Analyzer( * [[AnalysisException]] is throw. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: Generate if !p.child.resolved || !p.generator.resolved => p case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -872,6 +872,8 @@ class Analyzer( // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + + // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) @@ -927,7 +929,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: Project => p case f: Filter => f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 187b238045f85..39f554c137c98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -47,6 +47,7 @@ trait CheckAnalysis { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { + case p if p.analyzed => // Skip already analyzed sub-plans case operator: LogicalPlan => operator transformExpressionsUp { @@ -179,5 +180,7 @@ trait CheckAnalysis { } } extendedCheckRules.foreach(_(plan)) + + plan.foreach(_.setAnalyzed()) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 490f3dc07b6ed..970f3c8282c81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -144,7 +144,8 @@ object HiveTypeCoercion { * instances higher in the query tree. */ object PropagateTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // No propagation required for leaf nodes. 
case q: LogicalPlan if q.children.isEmpty => q @@ -225,7 +226,9 @@ object HiveTypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if p.analyzed => p + case u @ Union(left, right) if u.childrenResolved && !u.resolved => val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right) Union(newLeft, newRight) @@ -242,7 +245,7 @@ object HiveTypeCoercion { * Promotes strings that appear in arithmetic expressions. */ object PromoteStrings extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -305,7 +308,7 @@ object HiveTypeCoercion { * Convert all expressions in in() list to the left operator type */ object InConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -372,7 +375,8 @@ object HiveTypeCoercion { ChangeDecimalPrecision(Cast(e, dataType)) } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // fix decimal precision for expressions case q => q.transformExpressions { // Skip nodes whose children have not been resolved yet @@ -466,7 +470,7 @@ object HiveTypeCoercion { )) } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -508,7 +512,7 @@ object HiveTypeCoercion { * truncated version of this number. */ object StringToIntegralCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -521,7 +525,7 @@ object HiveTypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -575,7 +579,7 @@ object HiveTypeCoercion { * converted to fractional types. */ object Division extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.resolved => e @@ -592,7 +596,7 @@ object HiveTypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. 
*/ object CaseWhenCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes) @@ -628,7 +632,7 @@ object HiveTypeCoercion { * Coerces the type of different branches of If statement to a common type. */ object IfCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => @@ -652,7 +656,7 @@ object HiveTypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -669,7 +673,7 @@ object HiveTypeCoercion { * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index c610f70d38437..55286f9f2fc5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} @@ -92,7 +93,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map(recursiveTransform).toArray - if (changed) makeCopy(newArgs) else this + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } /** @@ -124,7 +125,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map(recursiveTransform).toArray - if (changed) makeCopy(newArgs) else this + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } /** Returns the result of running [[transformExpressions]] on this node diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index bedeaf06adf12..9b52f020093f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -22,11 +22,60 
@@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { + private var _analyzed: Boolean = false + + /** + * Marks this plan as already analyzed. This should only be called by CheckAnalysis. + */ + private[catalyst] def setAnalyzed(): Unit = { _analyzed = true } + + /** + * Returns true if this node and its children have already been gone through analysis and + * verification. Note that this is only an optimization used to avoid analyzing trees that + * have already been analyzed, and can be reset by transformations. + */ + def analyzed: Boolean = _analyzed + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order). When `rule` does not apply to a given node, it is left + * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already + * been marked as analyzed. + * + * @param rule the function use to transform this nodes children + */ + def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + val afterRuleOnChildren = transformChildren(rule, (t, r) => t.resolveOperators(r)) + if (this fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[LogicalPlan]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + } + } + } else { + this + } + } + + /** + * Recursively transforms the expressions of a tree, skipping nodes that have already + * been analyzed. + */ + def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { + this resolveOperators { + case p => p.transformExpressions(r) + } + } + /** * Computes [[Statistics]] for this plan. The default implementation assumes the output * cardinality is the product of of all child plan's cardinality, i.e. applies in the case diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 3f9858b0c4a43..8b824511a79da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -21,6 +21,23 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide +import scala.collection.mutable + +object RuleExecutor { + protected val timeMap = new mutable.HashMap[String, Long].withDefault(_ => 0) + + /** Resets statistics about time spent running specific rules */ + def resetTime(): Unit = timeMap.clear() + + /** Dump statistics about time spent running specific rules. 
*/ + def dumpTimeSpent(): String = { + val maxSize = timeMap.keys.map(_.toString.length).max + timeMap.toSeq.sortBy(_._2).reverseMap { case (k, v) => + s"${k.padTo(maxSize, " ").mkString} $v" + }.mkString("\n") + } +} + abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { /** @@ -41,6 +58,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { /** Defines a sequence of rule batches, to be overridden by the implementation. */ protected val batches: Seq[Batch] + /** * Executes the batches of rules defined by the subclass. The batches are executed serially * using the defined execution strategy. Within each batch, rules are also executed serially. @@ -58,7 +76,11 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { while (continue) { curPlan = batch.rules.foldLeft(curPlan) { case (plan, rule) => + val startTime = System.nanoTime() val result = rule(plan) + val runTime = System.nanoTime() - startTime + RuleExecutor.timeMap(rule.ruleName) = RuleExecutor.timeMap(rule.ruleName) + runTime + if (!result.fastEquals(plan)) { logTrace( s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 122e9fc5ed77f..7971e25188e8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -149,7 +149,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * Returns a copy of this node where `f` has been applied to all the nodes children. */ - def mapChildren(f: BaseType => BaseType): this.type = { + def mapChildren(f: BaseType => BaseType): BaseType = { var changed = false val newArgs = productIterator.map { case arg: TreeNode[_] if containsChild(arg) => @@ -170,7 +170,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * Returns a copy of this node with the children replaced. * TODO: Validate somewhere (in debug mode?) that children are ordered correctly. */ - def withNewChildren(newChildren: Seq[BaseType]): this.type = { + def withNewChildren(newChildren: Seq[BaseType]): BaseType = { assert(newChildren.size == children.size, "Incorrect number of children") var changed = false val remainingNewChildren = newChildren.toBuffer @@ -229,9 +229,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // Check if unchanged and then possibly return old copy to avoid gc churn. if (this fastEquals afterRule) { - transformChildrenDown(rule) + transformChildren(rule, (t, r) => t.transformDown(r)) } else { - afterRule.transformChildrenDown(rule) + afterRule.transformChildren(rule, (t, r) => t.transformDown(r)) } } @@ -240,11 +240,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * this node. When `rule` does not apply to a given node it is left unchanged. 
* @param rule the function used to transform this nodes children */ - def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = { + protected def transformChildren( + rule: PartialFunction[BaseType, BaseType], + nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = { var changed = false val newArgs = productIterator.map { case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true newChild @@ -252,7 +254,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arg } case Some(arg: TreeNode[_]) if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true Some(newChild) @@ -263,7 +265,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true newChild @@ -285,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param rule the function use to transform this nodes children */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRuleOnChildren = transformChildrenUp(rule) + val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r)) if (this fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) @@ -297,44 +299,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } - def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = { - var changed = false - val newArgs = productIterator.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case Some(arg: TreeNode[_]) if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { - changed = true - Some(newChild) - } else { - Some(arg) - } - case m: Map[_, _] => m - case d: DataType => d // Avoid unpacking Structs - case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case other => other - } - case nonChild: AnyRef => nonChild - case null => null - }.toArray - if (changed) makeCopy(newArgs) else this - } - /** * Args to the constructor that should be copied, but not transformed. * These are appended to the transformed args automatically by makeCopy @@ -348,7 +312,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * that are not present in the productIterator. * @param newArgs the new product arguments. 
*/ - def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") { val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0) if (ctors.isEmpty) { sys.error(s"No valid constructor for $nodeName") @@ -359,9 +323,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { CurrentOrigin.withOrigin(origin) { // Skip no-arg constructors that are just there for kryo. if (otherCopyArgs.isEmpty) { - defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] + defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType] } else { - defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type] + defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType] } } } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala new file mode 100644 index 0000000000000..797b29f23cbb9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util._ + +/** + * Provides helper methods for comparing plans. 
+ */ +class LogicalPlanSuite extends SparkFunSuite { + private var invocationCount = 0 + private val function: PartialFunction[LogicalPlan, LogicalPlan] = { + case p: Project => + invocationCount += 1 + p + } + + private val testRelation = LocalRelation() + + test("resolveOperator runs on operators") { + invocationCount = 0 + val plan = Project(Nil, testRelation) + plan resolveOperators function + + assert(invocationCount === 1) + } + + test("resolveOperator runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, testRelation)) + plan resolveOperators function + + assert(invocationCount === 2) + } + + test("resolveOperator skips all ready resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, testRelation)) + plan.foreach(_.setAnalyzed()) + plan resolveOperators function + + assert(invocationCount === 0) + } + + test("resolveOperator skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, testRelation) + val plan2 = Project(Nil, plan1) + plan1.foreach(_.setAnalyzed()) + plan2 resolveOperators function + + assert(invocationCount === 1) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index db15711202b77..e57acec59d327 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.util.Properties +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.unsafe.types.UTF8String import scala.language.implicitConversions @@ -54,7 +55,6 @@ private[sql] object DataFrame { } } - /** * :: Experimental :: * A distributed collection of data organized into named columns. @@ -690,9 +690,7 @@ class DataFrame private[sql]( case Column(explode: Explode) => MultiAlias(explode, Nil) case Column(expr: Expression) => Alias(expr, expr.prettyString)() } - // When user continuously call `select`, speed up analysis by collapsing `Project` - import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing - Project(namedExpressions.toSeq, ProjectCollapsing(logicalPlan)) + Project(namedExpressions.toSeq, logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 73b237fffece8..dbc0cefbe2e10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -67,7 +67,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private val prepareCalled = new AtomicBoolean(false) /** Overridden make copy also propogates sqlContext to copied plan. */ - override def makeCopy(newArgs: Array[AnyRef]): this.type = { + override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { SparkPlan.currentContext.set(sqlContext) super.makeCopy(newArgs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index dedc7c4dfb4d1..59f8b079ab333 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -65,7 +65,7 @@ private[spark] case class PythonUDF( * multiple child operators. 
*/ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Skip EvaluatePython nodes. case plan: EvaluatePython => plan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index e9dd7ef226e42..a88df91b1001c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.functions._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) @@ -50,6 +51,25 @@ class CachedTableSuite extends QueryTest { ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } + test("withColumn doesn't invalidate cached dataframe") { + var evalCount = 0 + val myUDF = udf((x: String) => { evalCount += 1; "result" }) + val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) + df.cache() + + df.collect() + assert(evalCount === 1) + + df.collect() + assert(evalCount === 1) + + val df2 = df.withColumn("newColumn", lit(1)) + df2.collect() + + // We should not reevaluate the cached dataframe + assert(evalCount === 1) + } + test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b8f10b00f5690..f9cc6d1f3c250 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -686,18 +686,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { Seq(Row(2, 1, 2), Row(1, 1, 1))) } - test("SPARK-7276: Project collapse for continuous select") { - var df = testData - for (i <- 1 to 5) { - df = df.select($"*") - } - - import org.apache.spark.sql.catalyst.plans.logical.Project - // make sure df have at most two Projects - val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project] - assert(!p.child.isInstanceOf[Project]) - } - test("SPARK-7150 range api") { // numSlice is greater than length val res1 = sqlContext.range(0, 10, 1, 15).select("id") diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index d4fc6c2b6ebc0..ab309e0a1d36b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf @@ -50,6 +51,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, 
true) + RuleExecutor.resetTime() } override def afterAll() { @@ -58,6 +60,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + + // For debugging dump some statistics about how much time was spent in various optimizer rules. + logWarning(RuleExecutor.dumpTimeSpent()) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 16c186627f6cc..6b37af99f4677 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -391,7 +391,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive */ object ParquetConversions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { - if (!plan.resolved) { + if (!plan.resolved || plan.analyzed) { return plan } @@ -418,8 +418,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive (relation, parquetRelation, attributedRewrites) // Read path - case p @ PhysicalOperation(_, _, relation: MetastoreRelation) - if hive.convertMetastoreParquet && + case relation: MetastoreRelation if hive.convertMetastoreParquet && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 8a86a87368f29..7182246e466a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -133,8 +133,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient private lazy val conversionHelper = new ConversionHelper(method, arguments) - @transient - lazy val dataType = javaClassToDataType(method.getReturnType) + val dataType = javaClassToDataType(method.getReturnType) @transient lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 7069afc9f7da2..10f2902e5eef0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -182,7 +182,7 @@ class HiveUDFSuite extends QueryTest { val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToListString(s) FROM inputTable") } - assert(errMsg.getMessage === "List type in java is unsupported because " + + assert(errMsg.getMessage contains "List type in java is unsupported because " + "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") @@ -197,7 +197,7 @@ class HiveUDFSuite extends QueryTest { val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToListInt(s) FROM inputTable") } - assert(errMsg.getMessage === "List 
type in java is unsupported because " + + assert(errMsg.getMessage contains "List type in java is unsupported because " + "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") @@ -213,7 +213,7 @@ class HiveUDFSuite extends QueryTest { val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToStringIntMap(s) FROM inputTable") } - assert(errMsg.getMessage === "Map type in java is unsupported because " + + assert(errMsg.getMessage contains "Map type in java is unsupported because " + "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") @@ -229,7 +229,7 @@ class HiveUDFSuite extends QueryTest { val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToIntIntMap(s) FROM inputTable") } - assert(errMsg.getMessage === "Map type in java is unsupported because " + + assert(errMsg.getMessage contains "Map type in java is unsupported because " + "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index ff9a3694d612e..1dff07a6de8ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -948,6 +948,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { } test("SPARK-7595: Window will cause resolve failed with self join") { + sql("SELECT * FROM src") // Force loading of src table. + checkAnswer(sql( """ |with From 7a969a6967c4ecc0f004b73bff27a75257a94e86 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Wed, 5 Aug 2015 10:16:12 -0700 Subject: [PATCH 164/340] [SPARK-9519] [YARN] Confirm stop sc successfully when application was killed Currently, when we kill an application on YARN, sc.stop() is called from the YARN application state monitor thread; YarnClientSchedulerBackend.stop() then interrupts that thread, which can leave the SparkContext only partially stopped because we are still waiting for executors to exit. Author: linweizhong Closes #7846 from Sephiroth-Lin/SPARK-9519 and squashes the following commits: 1ae736d [linweizhong] Update comments 2e8e365 [linweizhong] Add comment explaining the code ad0e23b [linweizhong] Update 243d2c7 [linweizhong] Confirm stop sc successfully when application was killed --- .../cluster/YarnClientSchedulerBackend.scala | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d97fa2e2151bc..d225061fcd1b4 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -33,7 +33,7 @@ private[spark] class YarnClientSchedulerBackend( private var client: Client = null private var appId: ApplicationId = null - private var monitorThread: Thread = null + private var monitorThread: MonitorThread = null /** * Create a Yarn client to submit an application to the ResourceManager. @@ -131,24 +131,42 @@ private[spark] class YarnClientSchedulerBackend( } } + /** + * We create this class for SPARK-9519.
Basically when we interrupt the monitor thread it's + * because the SparkContext is being shut down(sc.stop() called by user code), but if + * monitorApplication return, it means the Yarn application finished before sc.stop() was called, + * which means we should call sc.stop() here, and we don't allow the monitor to be interrupted + * before SparkContext stops successfully. + */ + private class MonitorThread extends Thread { + private var allowInterrupt = true + + override def run() { + try { + val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + logError(s"Yarn application has already exited with state $state!") + allowInterrupt = false + sc.stop() + } catch { + case e: InterruptedException => logInfo("Interrupting monitor thread") + } + } + + def stopMonitor(): Unit = { + if (allowInterrupt) { + this.interrupt() + } + } + } + /** * Monitor the application state in a separate thread. * If the application has exited for any reason, stop the SparkContext. * This assumes both `client` and `appId` have already been set. */ - private def asyncMonitorApplication(): Thread = { + private def asyncMonitorApplication(): MonitorThread = { assert(client != null && appId != null, "Application has not been submitted yet!") - val t = new Thread { - override def run() { - try { - val (state, _) = client.monitorApplication(appId, logApplicationReport = false) - logError(s"Yarn application has already exited with state $state!") - sc.stop() - } catch { - case e: InterruptedException => logInfo("Interrupting monitor thread") - } - } - } + val t = new MonitorThread t.setName("Yarn application state monitor") t.setDaemon(true) t @@ -160,7 +178,7 @@ private[spark] class YarnClientSchedulerBackend( override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") if (monitorThread != null) { - monitorThread.interrupt() + monitorThread.stopMonitor() } super.stop() client.stop() @@ -174,5 +192,4 @@ private[spark] class YarnClientSchedulerBackend( super.applicationId } } - } From 1f8c364b9c6636f06986f5f80d5a49b7a7772ac3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 5 Aug 2015 11:03:02 -0700 Subject: [PATCH 165/340] [SPARK-9141] [SQL] [MINOR] Fix comments of PR #7920 This is a follow-up of https://github.com/apache/spark/pull/7920 to fix comments. Author: Yin Huai Closes #7964 from yhuai/SPARK-9141-follow-up and squashes the following commits: 4d0ee80 [Yin Huai] Fix comments. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 +-- .../org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6de31f42dd30c..82158e61e3fb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -140,7 +140,7 @@ class Analyzer( object ResolveAliases extends Rule[LogicalPlan] { private def assignAliases(exprs: Seq[NamedExpression]) = { // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need - // to resolveOperator down the whole tree. + // to traverse the whole tree. exprs.zipWithIndex.map { case (u @ UnresolvedAlias(child), i) => child match { @@ -873,7 +873,6 @@ class Analyzer( // "Aggregate with Having clause" will be triggered. 
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 797b29f23cbb9..455a3810c719e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ /** - * Provides helper methods for comparing plans. + * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly + * skips sub-trees that have already been marked as analyzed. */ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 From e1e05873fc75781b6dd3f7fadbfb57824f83054e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 5 Aug 2015 11:38:56 -0700 Subject: [PATCH 166/340] [SPARK-9403] [SQL] Add codegen support in In and InSet This continues tarekauel's work in #7778. Author: Liang-Chi Hsieh Author: Tarek Auel Closes #7893 from viirya/codegen_in and squashes the following commits: 81ff97b [Liang-Chi Hsieh] For comments. 47761c6 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in cf4bf41 [Liang-Chi Hsieh] For comments. f532b3c [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in 446bbcd [Liang-Chi Hsieh] Fix bug. b3d0ab4 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in 4610eff [Liang-Chi Hsieh] Relax the types of references and update optimizer test. 224f18e [Liang-Chi Hsieh] Beef up the test cases for In and InSet to include all primitive data types. 86dc8aa [Liang-Chi Hsieh] Only convert In to InSet when the number of items in set is more than the threshold. b7ded7e [Tarek Auel] [SPARK-9403][SQL] codeGen in / inSet --- .../sql/catalyst/expressions/predicates.scala | 63 +++++++++++++++++-- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../catalyst/expressions/PredicateSuite.scala | 37 ++++++++++- .../catalyst/optimizer/OptimizeInSuite.scala | 14 ++++- .../datasources/DataSourceStrategy.scala | 7 +++ .../spark/sql/ColumnExpressionSuite.scala | 6 ++ 6 files changed, 119 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b69bbabee7e81..68c832d7194d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -97,32 +100,80 @@ case class Not(child: Expression) /** * Evaluates to `true` if `list` contains `value`. 
*/ -case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback { +case class In(value: Expression, list: Seq[Expression]) extends Predicate + with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (list.exists(l => l.dataType != value.dataType)) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be same type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + override def children: Seq[Expression] = value +: list - override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN. override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { val evaluatedValue = value.eval(input) list.exists(e => e.eval(input) == evaluatedValue) } -} + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val valueGen = value.gen(ctx) + val listGen = list.map(_.gen(ctx)) + val listCode = listGen.map(x => + s""" + if (!${ev.primitive}) { + ${x.code} + if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) { + ${ev.primitive} = true; + } + } + """).mkString("\n") + s""" + ${valueGen.code} + boolean ${ev.primitive} = false; + boolean ${ev.isNull} = false; + $listCode + """ + } +} /** * Optimized version of In clause, when all filter values of In clause are * static. */ -case class InSet(child: Expression, hset: Set[Any]) - extends UnaryExpression with Predicate with CodegenFallback { +case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { - override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN. 
override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { hset.contains(child.eval(input)) } + + def getHSet(): Set[Any] = hset + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val setName = classOf[Set[Any]].getName + val InSetName = classOf[InSet].getName + val childGen = child.gen(ctx) + ctx.references += this + val hsetTerm = ctx.freshName("hset") + ctx.addMutableState(setName, hsetTerm, + s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();") + s""" + ${childGen.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive}); + """ + } } case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 29d706dcb39a7..4ab5ac2c61e3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -393,7 +393,7 @@ object ConstantFolding extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => + case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) && list.size > 10 => val hSet = list.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index d7eb13c50b134..7beef71845e43 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType} +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.types._ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -118,6 +119,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) + + val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, + LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) + primitiveTypes.map { t => + val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val inputData = Seq.fill(10) { + val value = dataGen.apply() + value match { + case d: Double if d.isNaN => 0.0d + case f: Float if f.isNaN => 0.0f + case _ => value + } + } + val input = inputData.map(Literal(_)) + checkEvaluation(In(input(0), input.slice(1, 10)), + inputData.slice(1, 10).contains(inputData(0))) + } } test("INSET") { @@ -134,6 +152,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InSet(three, 
hS), false) checkEvaluation(InSet(three, nS), false) checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) + + val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, + LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) + primitiveTypes.map { t => + val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val inputData = Seq.fill(10) { + val value = dataGen.apply() + value match { + case d: Double if d.isNaN => 0.0d + case f: Float if f.isNaN => 0.0f + case _ => value + } + } + val input = inputData.map(Literal(_)) + checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), + inputData.slice(1, 10).contains(inputData(0))) + } } private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 1d433275fed2e..6f7b5b9572e22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - test("OptimizedIn test: In clause optimized to InSet") { + test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") { val originalQuery = testRelation .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) .analyze + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery) + } + + test("OptimizedIn test: In clause optimized to InSet when more than 10 items") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_)))) + .analyze + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2)) + .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet)) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index a43bccbe6927c..e5dc676b87841 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -366,6 +366,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case expressions.InSet(a: Attribute, set) => Some(sources.In(a.name, set.toArray)) + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. 
+ case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => + val hSet = list.map(e => e.eval(EmptyRow)) + Some(sources.In(a.name, hSet.toArray)) + case expressions.IsNull(a: Attribute) => Some(sources.IsNull(a.name)) case expressions.IsNotNull(a: Attribute) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 35ca0b4c7cc21..b351380373259 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -357,6 +357,12 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) checkAnswer(df.filter($"b".in("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + intercept[AnalysisException] { + df2.filter($"a".in($"b")) + } } val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( From eb5b8f4a603e0f289bdaa0a2164cde2cfe4feecb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 5 Aug 2015 12:51:12 -0700 Subject: [PATCH 167/340] Closes #7778 since it is done as #7893. From 5f0fb6466f5e3607f7fca9b2371a73b3deef3fdf Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Aug 2015 14:12:22 -0700 Subject: [PATCH 168/340] [SPARK-9649] Fix flaky test MasterSuite - randomize ports ``` Error Message Failed to bind to: /127.0.0.1:7093: Service 'sparkMaster' failed after 16 retries! Stacktrace java.net.BindException: Failed to bind to: /127.0.0.1:7093: Service 'sparkMaster' failed after 16 retries! at org.jboss.netty.bootstrap.ServerBootstrap.bind(ServerBootstrap.java:272) at akka.remote.transport.netty.NettyTransport$$anonfun$listen$1.apply(NettyTransport.scala:393) at akka.remote.transport.netty.NettyTransport$$anonfun$listen$1.apply(NettyTransport.scala:389) at scala.util.Success$$anonfun$map$1.apply(Try.scala:206) at scala.util.Try$.apply(Try.scala:161) ``` Author: Andrew Or Closes #7968 from andrewor14/fix-master-flaky-test and squashes the following commits: fcc42ef [Andrew Or] Randomize port --- .../org/apache/spark/deploy/master/MasterSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 30780a0da7f8d..ae0e037d822ea 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -93,8 +93,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva publicAddress = "" ) - val (rpcEnv, uiPort, restPort) = - Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf) + val (rpcEnv, _, _) = + Master.startRpcEnvAndEndpoint("127.0.0.1", 0, 0, conf) try { rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) @@ -343,8 +343,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva private def makeMaster(conf: SparkConf = new SparkConf): Master = { val securityMgr = new SecurityManager(conf) - val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 7077, conf, securityMgr) - val master = new Master(rpcEnv, rpcEnv.address, 8080, securityMgr, conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, 
"localhost", 0, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) master } From f9c2a2af1e883b36c5e51b87ef660a1b9ad0f586 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 5 Aug 2015 14:15:57 -0700 Subject: [PATCH 169/340] Closes #7474 since it's marked as won't fix. From dac090d1e9be7dec6c5ebdb2a81105b87e853193 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 5 Aug 2015 15:42:18 -0700 Subject: [PATCH 170/340] [SPARK-9657] Fix return type of getMaxPatternLength mengxr Author: Feynman Liang Closes #7974 from feynmanliang/SPARK-9657 and squashes the following commits: 7ca533f [Feynman Liang] Fix return type of getMaxPatternLength --- .../src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index d5f0c926c69bb..ad6715b52f337 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -82,7 +82,7 @@ class PrefixSpan private ( /** * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider. */ - def getMaxPatternLength: Double = maxPatternLength + def getMaxPatternLength: Int = maxPatternLength /** * Sets maximal pattern length (default: `10`). From 9c878923db6634effed98c99bf24dd263bb7c6ad Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 16:33:42 -0700 Subject: [PATCH 171/340] [SPARK-9054] [SQL] Rename RowOrdering to InterpretedOrdering; use newOrdering in SMJ This patches renames `RowOrdering` to `InterpretedOrdering` and updates SortMergeJoin to use the `SparkPlan` methods for constructing its ordering so that it may benefit from codegen. This is an updated version of #7408. Author: Josh Rosen Closes #7973 from JoshRosen/SPARK-9054 and squashes the following commits: e610655 [Josh Rosen] Add comment RE: Ascending ordering 34b8e0c [Josh Rosen] Import ordering be19a0f [Josh Rosen] [SPARK-9054] [SQL] Rename RowOrdering to InterpretedOrdering; use newOrdering in more places. 
--- .../sql/catalyst/expressions/arithmetic.scala | 4 +-- .../catalyst/expressions/conditionals.scala | 4 +-- .../{RowOrdering.scala => ordering.scala} | 27 ++++++++++--------- .../sql/catalyst/expressions/predicates.scala | 8 +++--- .../spark/sql/catalyst/util/TypeUtils.scala | 4 +-- .../apache/spark/sql/types/StructType.scala | 4 +-- .../expressions/CodeGenerationSuite.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 5 +++- .../spark/sql/execution/SparkPlan.scala | 14 ++++++++-- .../spark/sql/execution/basicOperators.scala | 4 ++- .../sql/execution/joins/SortMergeJoin.scala | 9 ++++--- .../UnsafeKVExternalSorterSuite.scala | 6 ++--- 12 files changed, 55 insertions(+), 36 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{RowOrdering.scala => ordering.scala} (85%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5808e3f66de3c..98464edf4d390 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -320,7 +320,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -374,7 +374,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def eval(input: InternalRow): Any = { val input1 = left.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 961b1d8616801..d51f3d3cef588 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -319,7 +319,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { @@ -374,7 +374,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala similarity index 85% rename from 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 873f5324c573e..6407c73bc97d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ /** * An interpreted row ordering comparator. */ -class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { +class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = this(ordering.map(BindReferences.bindReference(_, inputSchema))) @@ -49,9 +49,9 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { case dt: AtomicType if order.direction == Descending => dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case s: StructType if order.direction == Ascending => - s.ordering.asInstanceOf[Ordering[Any]].compare(left, right) + s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) case s: StructType if order.direction == Descending => - s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case other => throw new IllegalArgumentException(s"Type $other does not support ordered operations") } @@ -65,6 +65,18 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { } } +object InterpretedOrdering { + + /** + * Creates a [[InterpretedOrdering]] for the given schema, in natural ascending order. + */ + def forSchema(dataTypes: Seq[DataType]): InterpretedOrdering = { + new InterpretedOrdering(dataTypes.zipWithIndex.map { + case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) + } +} + object RowOrdering { /** @@ -81,13 +93,4 @@ object RowOrdering { * Returns true iff outputs from the expressions can be ordered. */ def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType)) - - /** - * Creates a [[RowOrdering]] for the given schema, in natural ascending order. 
- */ - def forSchema(dataTypes: Seq[DataType]): RowOrdering = { - new RowOrdering(dataTypes.zipWithIndex.map { - case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) - }) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 68c832d7194d4..fe7dffb815987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -376,7 +376,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso override def symbol: String = "<" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } @@ -388,7 +388,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo override def symbol: String = "<=" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } @@ -400,7 +400,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar override def symbol: String = ">" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } @@ -412,7 +412,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar override def symbol: String = ">=" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 0b41f92c6193c..bcf4d78fb9371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -54,10 +54,10 @@ object TypeUtils { def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] - def getOrdering(t: DataType): Ordering[Any] = { + def getInterpretedOrdering(t: DataType): Ordering[Any] = { t match { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case s: StructType => s.ordering.asInstanceOf[Ordering[Any]] + case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 6928707f7bf6e..9cbc207538d4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi -import 
org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$} /** @@ -301,7 +301,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(newFields) } - private[sql] val ordering = RowOrdering.forSchema(this.fields.map(_.dataType)) + private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } object StructType extends AbstractDataType { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index cc82f7c3f5a73..e310aee221666 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -54,7 +54,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { // GenerateOrdering agrees with RowOrdering. (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => test(s"GenerateOrdering with $dataType") { - val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType)) + val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType)) val genOrdering = GenerateOrdering.generate( BoundReference(0, dataType, nullable = true).asc :: BoundReference(1, dataType, nullable = true).asc :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 05b009d1935bb..6ea5eeedf1bbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -156,7 +156,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row.copy(), null)) } - implicit val ordering = new RowOrdering(sortingExpressions, child.output) + // We need to use an interpreted ordering here because generated orderings cannot be + // serialized and this ordering needs to be created on the driver in order to be passed into + // Spark core code. 
+ implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output) new RangePartitioner(numPartitions, rddForSampling, ascending = true) case SinglePartition => new Partitioner { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index dbc0cefbe2e10..2f29067f5646a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.DataType object SparkPlan { protected[sql] val currentContext = new ThreadLocal[SQLContext]() @@ -309,13 +310,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ throw e } else { log.error("Failed to generate ordering, fallback to interpreted", e) - new RowOrdering(order, inputSchema) + new InterpretedOrdering(order, inputSchema) } } } else { - new RowOrdering(order, inputSchema) + new InterpretedOrdering(order, inputSchema) } } + /** + * Creates a row ordering for the given schema, in natural ascending order. + */ + protected def newNaturalAscendingOrdering(dataTypes: Seq[DataType]): Ordering[InternalRow] = { + val order: Seq[SortOrder] = dataTypes.zipWithIndex.map { + case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + } + newOrdering(order, Seq.empty) + } } private[sql] trait LeafNode extends SparkPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 477170297c2ac..f4677b4ee86bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -212,7 +212,9 @@ case class TakeOrderedAndProject( override def outputPartitioning: Partitioning = SinglePartition - private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) + // We need to use an interpreted ordering here because generated orderings cannot be serialized + // and this ordering needs to be created on the driver in order to be passed into Spark core code. + private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. 
@transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index eb595490fbf28..4ae23c186cf7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -48,9 +48,6 @@ case class SortMergeJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) override def requiredChildOrdering: Seq[Seq[SortOrder]] = @@ -59,8 +56,10 @@ case class SortMergeJoin( @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) + } protected override def doExecute(): RDD[InternalRow] = { val leftResults = left.execute().map(_.copy()) @@ -68,6 +67,8 @@ case class SortMergeJoin( leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => new Iterator[InternalRow] { + // An ordering that can be used to compare keys from both sides. + private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) // Mutable per row objects. 
private[this] val joinRow = new JoinedRow private[this] var leftElement: InternalRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 08156f0e39ce8..a9515a03acf2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -144,8 +144,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { } sorter.cleanupResources() - val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType)) - val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType)) + val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType)) + val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType)) val kvOrdering = new Ordering[(InternalRow, InternalRow)] { override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = { keyOrdering.compare(x._1, y._1) match { From a018b85716fd510ae95a3c66d676bbdb90f8d4e7 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 5 Aug 2015 17:07:55 -0700 Subject: [PATCH 172/340] [SPARK-5895] [ML] Add VectorSlicer - updated Add VectorSlicer transformer to spark.ml, with features specified as either indices or names. Transfers feature attributes for selected features. Updated version of [https://github.com/apache/spark/pull/5731] CC: yinxusen This updates your PR. You'll still be the primary author of this PR. CC: mengxr Author: Xusen Yin Author: Joseph K. Bradley Closes #7972 from jkbradley/yinxusen-SPARK-5895 and squashes the following commits: b16e86e [Joseph K. Bradley] fixed scala style 71c65d2 [Joseph K. Bradley] fix import order 86e9739 [Joseph K. Bradley] cleanups per code review 9d8d6f1 [Joseph K. Bradley] style fix 83bc2e9 [Joseph K. 
Bradley] Updated VectorSlicer 98c6939 [Xusen Yin] fix style error ecbf2d3 [Xusen Yin] change interfaces and params f6be302 [Xusen Yin] Merge branch 'master' into SPARK-5895 e4781f2 [Xusen Yin] fix commit error fd154d7 [Xusen Yin] add test suite of vector slicer 17171f8 [Xusen Yin] fix slicer 9ab9747 [Xusen Yin] add vector slicer aa5a0bf [Xusen Yin] add vector slicer --- .../spark/ml/feature/VectorSlicer.scala | 170 ++++++++++++++++++ .../apache/spark/ml/util/MetadataUtils.scala | 17 ++ .../apache/spark/mllib/linalg/Vectors.scala | 24 +++ .../spark/ml/feature/VectorSlicerSuite.scala | 109 +++++++++++ .../spark/mllib/linalg/VectorsSuite.scala | 7 + 5 files changed, 327 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala new file mode 100644 index 0000000000000..772bebeff214b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * This class takes a feature vector and outputs a new feature vector with a subarray of the + * original features. + * + * The subset of features can be specified with either indices ([[setIndices()]]) + * or names ([[setNames()]]). At least one feature must be selected. Duplicate features + * are not allowed, so there can be no overlap between selected indices and names. + * + * The output vector will order features with the selected indices first (in the order given), + * followed by the selected names (in the order given). + */ +@Experimental +final class VectorSlicer(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("vectorSlicer")) + + /** + * An array of indices to select features from a vector column. + * There can be no overlap with [[names]]. 
+ * @group param + */ + val indices = new IntArrayParam(this, "indices", + "An array of indices to select features from a vector column." + + " There can be no overlap with names.", VectorSlicer.validIndices) + + setDefault(indices -> Array.empty[Int]) + + /** @group getParam */ + def getIndices: Array[Int] = $(indices) + + /** @group setParam */ + def setIndices(value: Array[Int]): this.type = set(indices, value) + + /** + * An array of feature names to select features from a vector column. + * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s. + * There can be no overlap with [[indices]]. + * @group param + */ + val names = new StringArrayParam(this, "names", + "An array of feature names to select features from a vector column." + + " There can be no overlap with indices.", VectorSlicer.validNames) + + setDefault(names -> Array.empty[String]) + + /** @group getParam */ + def getNames: Array[String] = $(names) + + /** @group setParam */ + def setNames(value: Array[String]): this.type = set(names, value) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def validateParams(): Unit = { + require($(indices).length > 0 || $(names).length > 0, + s"VectorSlicer requires that at least one feature be selected.") + } + + override def transform(dataset: DataFrame): DataFrame = { + // Validity checks + transformSchema(dataset.schema) + val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol))) + inputAttr.numAttributes.foreach { numFeatures => + val maxIndex = $(indices).max + require(maxIndex < numFeatures, + s"Selected feature index $maxIndex invalid for only $numFeatures input features.") + } + + // Prepare output attributes + val inds = getSelectedFeatureIndices(dataset.schema) + val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs => + inds.map(index => attrs(index)) + } + val outputAttr = selectedAttrs match { + case Some(attrs) => new AttributeGroup($(outputCol), attrs) + case None => new AttributeGroup($(outputCol), inds.length) + } + + // Select features + val slicer = udf { vec: Vector => + vec match { + case features: DenseVector => Vectors.dense(inds.map(features.apply)) + case features: SparseVector => features.slice(inds) + } + } + dataset.withColumn($(outputCol), + slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata())) + } + + /** Get the feature indices in order: indices, names */ + private def getSelectedFeatureIndices(schema: StructType): Array[Int] = { + val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names)) + val indFeatures = $(indices) + val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length + lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" + + s" sets of features, but they overlap." + + s" indices: ${indFeatures.mkString("[", ",", "]")}." 
+ + s" names: " + + nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]") + require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg) + indFeatures ++ nameFeatures + } + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + + if (schema.fieldNames.contains($(outputCol))) { + throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.") + } + val numFeaturesSelected = $(indices).length + $(names).length + val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected) + val outputFields = schema.fields :+ outputAttr.toStructField() + StructType(outputFields) + } + + override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) +} + +private[feature] object VectorSlicer { + + /** Return true if given feature indices are valid */ + def validIndices(indices: Array[Int]): Boolean = { + if (indices.isEmpty) { + true + } else { + indices.length == indices.distinct.length && indices.forall(_ >= 0) + } + } + + /** Return true if given feature names are valid */ + def validNames(names: Array[String]): Boolean = { + names.forall(_.nonEmpty) && names.length == names.distinct.length + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index 2a1db90f2ca2b..fcb517b5f735e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.util import scala.collection.immutable.HashMap import org.apache.spark.ml.attribute._ +import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.types.StructField @@ -74,4 +75,20 @@ private[spark] object MetadataUtils { } } + /** + * Takes a Vector column and a list of feature names, and returns the corresponding list of + * feature indices in the column, in order. + * @param col Vector column which must have feature names specified via attributes + * @param names List of feature names + */ + def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = { + require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col" + + s" to be Vector type, but it was type ${col.dataType} instead.") + val inputAttr = AttributeGroup.fromStructField(col) + names.map { name => + require(inputAttr.hasAttr(name), + s"getFeatureIndicesFromNames found no feature with name $name in column $col.") + inputAttr.getAttr(name).index.get + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 96d1f48ba2ba3..86c461fa91633 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -766,6 +766,30 @@ class SparseVector( maxIdx } } + + /** + * Create a slice of this vector based on the given indices. + * @param selectedIndices Unsorted list of indices into the vector. + * This does NOT do bound checking. + * @return New SparseVector with values in the order specified by the given indices. + * + * NOTE: The API needs to be discussed before making this public. + * Also, if we have a version assuming indices are sorted, we should optimize it. 
+ */ + private[spark] def slice(selectedIndices: Array[Int]): SparseVector = { + var currentIdx = 0 + val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx => + val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx) + val i_v = if (iIdx >= 0) { + Iterator((currentIdx, this.values(iIdx))) + } else { + Iterator() + } + currentIdx += 1 + i_v + }.unzip + new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray) + } } object SparseVector { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala new file mode 100644 index 0000000000000..a6c2fba8360dd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + val slicer = new VectorSlicer + ParamsSuite.checkParams(slicer) + assert(slicer.getIndices.length === 0) + assert(slicer.getNames.length === 0) + withClue("VectorSlicer should not have any features selected by default") { + intercept[IllegalArgumentException] { + slicer.validateParams() + } + } + } + + test("feature validity checks") { + import VectorSlicer._ + assert(validIndices(Array(0, 1, 8, 2))) + assert(validIndices(Array.empty[Int])) + assert(!validIndices(Array(-1))) + assert(!validIndices(Array(1, 2, 1))) + + assert(validNames(Array("a", "b"))) + assert(validNames(Array.empty[String])) + assert(!validNames(Array("", "b"))) + assert(!validNames(Array("a", "b", "a"))) + } + + test("Test vector slicer") { + val sqlContext = new SQLContext(sc) + + val data = Array( + Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0), + Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3), + Vectors.sparse(5, Seq()) + ) + + // Expected after selecting indices 1, 4 + val expected = Array( + Vectors.sparse(2, Seq((0, 2.3))), + Vectors.dense(2.3, 1.0), + Vectors.dense(0.0, 0.0), + Vectors.dense(-1.1, 3.3), + Vectors.sparse(2, Seq()) + ) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("features", 
attrs.asInstanceOf[Array[Attribute]]) + + val resultAttrs = Array("f1", "f4").map(defaultAttr.withName) + val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]]) + + val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) } + val df = sqlContext.createDataFrame(rdd, + StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField()))) + + val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") + + def validateResults(df: DataFrame): Unit = { + df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) => + assert(vec1 === vec2) + } + val resultMetadata = AttributeGroup.fromStructField(df.schema("result")) + val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected")) + assert(resultMetadata.numAttributes === expectedMetadata.numAttributes) + resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) => + assert(a === b) + } + } + + vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty) + validateResults(vectorSlicer.transform(df)) + + vectorSlicer.setIndices(Array(1)).setNames(Array("f4")) + validateResults(vectorSlicer.transform(df)) + + vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) + validateResults(vectorSlicer.transform(df)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 1c37ea5123e82..6508ddeba4206 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -367,4 +367,11 @@ class VectorsSuite extends SparkFunSuite with Logging { val sv1c = sv1.compressed.asInstanceOf[DenseVector] assert(sv1 === sv1c) } + + test("SparseVector.slice") { + val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4)) + assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2))) + assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) + assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) + } } From 8c320e45b5c9ffd7f0e35c1c7e6b5fc355377ea6 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 5 Aug 2015 17:28:23 -0700 Subject: [PATCH 173/340] [SPARK-6591] [SQL] Python data source load options should auto convert common types into strings JIRA: https://issues.apache.org/jira/browse/SPARK-6591 Author: Yijie Shen Closes #7926 from yjshen/py_dsload_opt and squashes the following commits: b207832 [Yijie Shen] fix style efdf834 [Yijie Shen] resolve comment 7a8f6a2 [Yijie Shen] lowercase 822e769 [Yijie Shen] convert load opts to string --- python/pyspark/sql/readwriter.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index dea8bad79e187..bf6ac084bbbf8 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -24,6 +24,16 @@ __all__ = ["DataFrameReader", "DataFrameWriter"] +def to_str(value): + """ + A wrapper over str(), but convert bool values to lower case string + """ + if isinstance(value, bool): + return str(value).lower() + else: + return str(value) + + class DataFrameReader(object): """ Interface used to load a :class:`DataFrame` from external storage systems @@ -77,7 +87,7 @@ def schema(self, schema): def option(self, key, value): """Adds an input option for the underlying data 
source. """ - self._jreader = self._jreader.option(key, value) + self._jreader = self._jreader.option(key, to_str(value)) return self @since(1.4) @@ -85,7 +95,7 @@ def options(self, **options): """Adds input options for the underlying data source. """ for k in options: - self._jreader = self._jreader.option(k, options[k]) + self._jreader = self._jreader.option(k, to_str(options[k])) return self @since(1.4) @@ -97,7 +107,8 @@ def load(self, path=None, format=None, schema=None, **options): :param schema: optional :class:`StructType` for the input schema. :param options: all other string options - >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned') + >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True, + ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ From 4399b7b0903d830313ab7e69731c11d587ae567c Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 5 Aug 2015 17:58:36 -0700 Subject: [PATCH 174/340] [SPARK-9651] Fix UnsafeExternalSorterSuite. First, it's probably a bad idea to call generated Scala methods from Java. In this case, the method being called wasn't actually "Utils.createTempDir()", but actually the method that returns the first default argument to the actual createTempDir method, which is just the location of java.io.tmpdir; meaning that all tests in the class were using the same temp dir, and thus affecting each other. Second, spillingOccursInResponseToMemoryPressure was not writing enough records to actually cause a spill. Author: Marcelo Vanzin Closes #7970 from vanzin/SPARK-9651 and squashes the following commits: 74d357f [Marcelo Vanzin] Clean up temp dir on test tear down. a64f36a [Marcelo Vanzin] [SPARK-9651] Fix UnsafeExternalSorterSuite. 
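A note on the first problem described above, with a minimal sketch. Scala compiles every default argument into a synthetic zero-argument method named `<method>$default$<n>` that simply returns the default value, and that synthetic method is callable from Java. The object below is hypothetical (it is not Spark's `Utils`; the names are invented for illustration), but it shows why calling a `$default$1()` accessor from Java yields the same pre-existing path on every call instead of creating a directory:

```
// Hypothetical sketch (not Spark code) of how a Scala default argument surfaces in Java.
object TempDirs {
  // The compiler also emits a synthetic method createTempDir$default$1() that only
  // returns System.getProperty("java.io.tmpdir"); it never creates a directory.
  def createTempDir(root: String = System.getProperty("java.io.tmpdir")): java.io.File = {
    val dir = new java.io.File(root, "test-" + java.util.UUID.randomUUID().toString)
    dir.mkdirs()
    dir
  }
}

// Seen from Java:
//   new File(TempDirs.createTempDir$default$1());                // same path every call, nothing created
//   TempDirs.createTempDir(TempDirs.createTempDir$default$1());  // actually creates a fresh directory
```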
--- .../sort/UnsafeExternalSorterSuite.java | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 968185bde78ab..117745f9a9c00 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -101,7 +101,7 @@ public OutputStream apply(OutputStream stream) { public void setUp() { MockitoAnnotations.initMocks(this); sparkConf = new SparkConf(); - tempDir = new File(Utils.createTempDir$default$1()); + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); spillFilesCreated.clear(); taskContext = mock(TaskContext.class); @@ -143,13 +143,18 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th @After public void tearDown() { - long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); - if (shuffleMemoryManager != null) { - long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); - shuffleMemoryManager = null; - assertEquals(0L, leakedShuffleMemory); + try { + long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (shuffleMemoryManager != null) { + long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); + shuffleMemoryManager = null; + assertEquals(0L, leakedShuffleMemory); + } + assertEquals(0, leakedUnsafeMemory); + } finally { + Utils.deleteRecursively(tempDir); + tempDir = null; } - assertEquals(0, leakedUnsafeMemory); } private void assertSpillFilesWereCleanedUp() { @@ -234,7 +239,7 @@ public void testSortingEmptyArrays() throws Exception { public void spillingOccursInResponseToMemoryPressure() throws Exception { shuffleMemoryManager = new ShuffleMemoryManager(pageSizeBytes * 2); final UnsafeExternalSorter sorter = newSorter(); - final int numRecords = 100000; + final int numRecords = (int) pageSizeBytes / 4; for (int i = 0; i <= numRecords; i++) { insertNumber(sorter, numRecords - i); } From 4581badbc8aa7e5a37ba7f7f83cc3860240f5dd3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 5 Aug 2015 19:19:09 -0700 Subject: [PATCH 175/340] [SPARK-9611] [SQL] Fixes a few corner cases when we spill a UnsafeFixedWidthAggregationMap This PR has the following three small fixes. 1. UnsafeKVExternalSorter does not use 0 as the initialSize to create an UnsafeInMemorySorter if its BytesToBytesMap is empty. 2. We will not spill an InMemorySorter if it is empty. 3. We will not add a SpillReader to a SpillMerger if this SpillReader is empty. JIRA: https://issues.apache.org/jira/browse/SPARK-9611 Author: Yin Huai Closes #7948 from yhuai/unsafeEmptyMap and squashes the following commits: 9727abe [Yin Huai] Address Josh's comments. 34b6f76 [Yin Huai] 1. UnsafeKVExternalSorter does not use 0 as the initialSize to create an UnsafeInMemorySorter if its BytesToBytesMap is empty. 2. Do not spill an InMemorySorter if it is empty. 3. Do not add spill to SpillMerger if this spill is empty.
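The third fix above follows a general rule for priority-queue based k-way merges: every entry in the queue must already hold a loaded head record, so an input with no records must not be enqueued at all. The following is a minimal, self-contained Scala sketch of that rule; it is only an illustration under invented names, not Spark's `UnsafeSorterSpillMerger`:

```
import scala.collection.mutable

// Minimal k-way merge sketch: empty inputs are skipped up front, mirroring the idea
// of only adding non-empty spill readers to the merger.
object MergeSketch {
  private case class Head(value: Int, rest: Iterator[Int])

  def mergeSorted(inputs: Seq[Iterator[Int]]): Iterator[Int] = new Iterator[Int] {
    // PriorityQueue is a max-heap, so reverse the ordering to pop the smallest head first.
    private val queue =
      mutable.PriorityQueue.empty[Head](Ordering.by[Head, Int](_.value).reverse)
    // Enqueue only inputs that actually have a first record; an empty iterator has
    // nothing to pre-load and must never appear in the queue.
    inputs.foreach(it => if (it.hasNext) queue.enqueue(Head(it.next(), it)))

    override def hasNext: Boolean = queue.nonEmpty
    override def next(): Int = {
      val Head(value, rest) = queue.dequeue()
      if (rest.hasNext) queue.enqueue(Head(rest.next(), rest))
      value
    }
  }

  def main(args: Array[String]): Unit = {
    val merged = mergeSorted(Seq(Iterator(1, 4), Iterator.empty, Iterator(2, 3)))
    assert(merged.toList == List(1, 2, 3, 4))
  }
}
```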
--- .../unsafe/sort/UnsafeExternalSorter.java | 36 +++--- .../unsafe/sort/UnsafeSorterSpillMerger.java | 12 +- .../sql/execution/UnsafeKVExternalSorter.java | 6 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 108 +++++++++++++++++- 4 files changed, 141 insertions(+), 21 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index e6ddd08e5fa99..8f78fc5a41629 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -191,24 +191,29 @@ public void spill() throws IOException { spillWriters.size(), spillWriters.size() > 1 ? " times" : " time"); - final UnsafeSorterSpillWriter spillWriter = - new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, - inMemSorter.numRecords()); - spillWriters.add(spillWriter); - final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); - while (sortedRecords.hasNext()) { - sortedRecords.loadNext(); - final Object baseObject = sortedRecords.getBaseObject(); - final long baseOffset = sortedRecords.getBaseOffset(); - final int recordLength = sortedRecords.getRecordLength(); - spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + // We only write out contents of the inMemSorter if it is not empty. + if (inMemSorter.numRecords() > 0) { + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + inMemSorter.numRecords()); + spillWriters.add(spillWriter); + final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + final int recordLength = sortedRecords.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + } + spillWriter.close(); } - spillWriter.close(); + final long spillSize = freeMemory(); // Note that this is more-or-less going to be a multiple of the page size, so wasted space in // pages will currently be counted as memory spilled even though that space isn't actually // written to disk. This also counts the space needed to store the sorter's pointer array. 
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + initializeForWriting(); } @@ -505,12 +510,11 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { - spillMerger.addSpill(spillWriter.getReader(blockManager)); + spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager)); } spillWriters.clear(); - if (inMemoryIterator.hasNext()) { - spillMerger.addSpill(inMemoryIterator); - } + spillMerger.addSpillIfNotEmpty(inMemoryIterator); + return spillMerger.getSortedIterator(); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index 8272c2a5be0d1..3874a9f9cbdb6 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -47,11 +47,19 @@ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { priorityQueue = new PriorityQueue(numSpills, comparator); } - public void addSpill(UnsafeSorterIterator spillReader) throws IOException { + /** + * Add an UnsafeSorterIterator to this merger + */ + public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOException { if (spillReader.hasNext()) { + // We only add the spillReader to the priorityQueue if it is not empty. We do this to + // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator + // does not return wrong result because hasNext will returns true + // at least priorityQueue.size() times. If we allow n spillReaders in the + // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator. spillReader.loadNext(); + priorityQueue.add(spillReader); } - priorityQueue.add(spillReader); } public UnsafeSorterIterator getSortedIterator() throws IOException { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 86a563df992d0..6c1cf136d9b81 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -82,8 +82,11 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, pageSizeBytes); } else { // Insert the records into the in-memory sorter. + // We will use the number of elements in the map as the initialSize of the + // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize, + // we will use 1 as its initial size if the map is empty. 
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - taskMemoryManager, recordComparator, prefixComparator, map.numElements()); + taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); final int numKeyFields = keySchema.size(); BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); @@ -214,7 +217,6 @@ public boolean next() throws IOException { // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); int valueLen = recordLen - keyLen - 4; - key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index ef827b0fe9b5b..b513c970ccfe2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,7 +23,7 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.test.TestSQLContext @@ -231,4 +231,110 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { map.free() } + + testWithMemoryLeakDetection("test external sorting with an empty map") { + // Calling this make sure we have block manager and everything else setup. + TestSQLContext + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + // Convert the map into a sorter + val sorter = map.destructAndCreateExternalSorter() + + // Add more keys to the sorter and make sure the results come out sorted. + val additionalKeys = randomStrings(1024) + val keyConverter = UnsafeProjection.create(groupKeySchema) + val valueConverter = UnsafeProjection.create(aggBufferSchema) + + additionalKeys.zipWithIndex.foreach { case (str, i) => + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + val out = new scala.collection.mutable.ArrayBuffer[String] + val iter = sorter.sortedIterator() + while (iter.next()) { + // At here, we also test if copy is correct. + val key = iter.getKey.copy() + val value = iter.getValue.copy() + assert(key.getString(0).length === value.getInt(0)) + out += key.getString(0) + } + + assert(out === (additionalKeys).sorted) + + map.free() + } + + testWithMemoryLeakDetection("test external sorting with empty records") { + // Calling this make sure we have block manager and everything else setup. + TestSQLContext + + // Memory consumption in the beginning of the task. 
+ val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + StructType(Nil), + StructType(Nil), + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + (1 to 10).foreach { i => + val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0)) + assert(buf != null) + } + + // Convert the map into a sorter. Right now, it contains one record. + val sorter = map.destructAndCreateExternalSorter() + + withClue(s"destructAndCreateExternalSorter should release memory used by the map") { + // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. + assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === + initialMemoryConsumption + 4096 * 16) + } + + // Add more keys to the sorter and make sure the results come out sorted. + (1 to 4096).foreach { i => + sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + var count = 0 + val iter = sorter.sortedIterator() + while (iter.next()) { + // At here, we also test if copy is correct. + iter.getKey.copy() + iter.getValue.copy() + count += 1; + } + + // 1 record was from the map and 4096 records were explicitly inserted. + assert(count === 4097) + + map.free() + } } From 119b59053870df7be899bf5c1c0d321406af96f9 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 6 Aug 2015 11:13:44 +0800 Subject: [PATCH 176/340] [SPARK-6923] [SPARK-7550] [SQL] Persists data source relations in Hive compatible format when possible This PR is a fork of PR #5733 authored by chenghao-intel. For committers who's going to merge this PR, please set the author to "Cheng Hao ". ---- When a data source relation meets the following requirements, we persist it in Hive compatible format, so that other systems like Hive can access it: 1. It's a `HadoopFsRelation` 2. It has only one input path 3. It's non-partitioned 4. It's data source provider can be naturally mapped to a Hive builtin SerDe (e.g. 
ORC and Parquet) Author: Cheng Lian Author: Cheng Hao Closes #7967 from liancheng/spark-6923/refactoring-pr-5733 and squashes the following commits: 5175ee6 [Cheng Lian] Fixes an oudated comment 3870166 [Cheng Lian] Fixes build error and comments 864acee [Cheng Lian] Refactors PR #5733 3490cdc [Cheng Hao] update the scaladoc 6f57669 [Cheng Hao] write schema info to hivemetastore for data source --- .../org/apache/spark/sql/DataFrame.scala | 53 +++++-- .../apache/spark/sql/DataFrameWriter.scala | 7 + .../spark/sql/hive/HiveMetastoreCatalog.scala | 146 +++++++++++++++++- .../org/apache/spark/sql/hive/HiveQl.scala | 49 ++---- .../spark/sql/hive/orc/OrcRelation.scala | 6 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 133 ++++++++++++++-- 6 files changed, 324 insertions(+), 70 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index e57acec59d327..405b5a4a9a7f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,9 +20,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.util.Properties -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.unsafe.types.UTF8String - import scala.language.implicitConversions import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -42,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} -import org.apache.spark.sql.json.{JacksonGenerator, JSONRelation} +import org.apache.spark.sql.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -1650,8 +1647,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. */ @@ -1669,8 +1670,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. 
ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ @@ -1689,8 +1694,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. */ @@ -1709,8 +1718,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ @@ -1728,8 +1741,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. @@ -1754,8 +1771,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. 
ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7e3318cefe62c..2a4992db09bc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} +import org.apache.spark.sql.sources.HadoopFsRelation /** @@ -185,6 +186,12 @@ final class DataFrameWriter private[sql](df: DataFrame) { * When `mode` is `Append`, the schema of the [[DataFrame]] need to be * the same as that of the existing table, and format or options will be ignored. * + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6b37af99f4677..1523ebe9d5493 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -18,11 +18,13 @@ package org.apache.spark.sql.hive import scala.collection.JavaConversions._ +import scala.collection.mutable import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.Warehouse import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ @@ -40,9 +42,59 @@ import org.apache.spark.sql.execution.datasources import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} +private[hive] case class HiveSerDe( + inputFormat: Option[String] = None, + outputFormat: Option[String] = None, + serde: Option[String] = None) + +private[hive] object HiveSerDe { + /** + * Get the Hive SerDe information from the data source abbreviation string or classname. 
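To make the saveAsTable behaviour documented above and the SerDe mapping introduced here concrete, a hedged usage sketch (it assumes a HiveContext-backed session with an existing DataFrame `df`; the table names are illustrative and not part of this patch):

// "parquet" maps to a builtin Hive SerDe, so a non-partitioned table written
// from a single input path should be readable from Hive itself.
df.write.mode("overwrite").format("parquet").saveAsTable("events_parquet")

// "json" has no corresponding builtin Hive SerDe, so this table is persisted
// in the Spark SQL specific format and is not readable from plain Hive.
df.write.mode("overwrite").format("json").saveAsTable("events_json")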
+ * + * @param source Currently the source abbreviation can be one of the following: + * SequenceFile, RCFile, ORC, PARQUET, and case insensitive. + * @param hiveConf Hive Conf + * @return HiveSerDe associated with the specified source + */ + def sourceToSerDe(source: String, hiveConf: HiveConf): Option[HiveSerDe] = { + val serdeMap = Map( + "sequencefile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + + "rcfile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), + serde = Option(hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTRCFILESERDE))), + + "orc" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")), + + "parquet" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) + + val key = source.toLowerCase match { + case _ if source.startsWith("org.apache.spark.sql.parquet") => "parquet" + case _ if source.startsWith("org.apache.spark.sql.orc") => "orc" + case _ => source.toLowerCase + } + + serdeMap.get(key) + } +} + private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) extends Catalog with Logging { @@ -164,15 +216,15 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive processDatabaseAndTableName(database, tableIdent.table) } - val tableProperties = new scala.collection.mutable.HashMap[String, String] + val tableProperties = new mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) // Saves optional user specified schema. Serialized JSON schema string may be too long to be // stored into a single metastore SerDe property. In this case, we split the JSON string and // store each part as a separate SerDe property. - if (userSpecifiedSchema.isDefined) { + userSpecifiedSchema.foreach { schema => val threshold = conf.schemaStringLengthThreshold - val schemaJsonString = userSpecifiedSchema.get.json + val schemaJsonString = schema.json // Split the JSON string. val parts = schemaJsonString.grouped(threshold).toSeq tableProperties.put("spark.sql.sources.schema.numParts", parts.size.toString) @@ -194,7 +246,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // The table does not have a specified schema, which means that the schema will be inferred // when we load the table. So, we are not expecting partition columns and we will discover // partitions when we load the table. However, if there are specified partition columns, - // we simplily ignore them and provide a warning message.. + // we simply ignore them and provide a warning message. logWarning( s"The schema and partitions of table $tableIdent will be inferred when it is loaded. 
" + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") @@ -210,7 +262,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive ManagedTable } - client.createTable( + val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf) + val dataSource = ResolvedDataSource( + hive, userSpecifiedSchema, partitionColumns, provider, options) + + def newSparkSQLSpecificMetastoreTable(): HiveTable = { HiveTable( specifiedDatabase = Option(dbName), name = tblName, @@ -218,7 +274,83 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionColumns = metastorePartitionColumns, tableType = tableType, properties = tableProperties.toMap, - serdeProperties = options)) + serdeProperties = options) + } + + def newHiveCompatibleMetastoreTable(relation: HadoopFsRelation, serde: HiveSerDe): HiveTable = { + def schemaToHiveColumn(schema: StructType): Seq[HiveColumn] = { + schema.map { field => + HiveColumn( + name = field.name, + hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), + comment = "") + } + } + + val partitionColumns = schemaToHiveColumn(relation.partitionColumns) + val dataColumns = schemaToHiveColumn(relation.schema).filterNot(partitionColumns.contains) + + HiveTable( + specifiedDatabase = Option(dbName), + name = tblName, + schema = dataColumns, + partitionColumns = partitionColumns, + tableType = tableType, + properties = tableProperties.toMap, + serdeProperties = options, + location = Some(relation.paths.head), + viewText = None, // TODO We need to place the SQL string here. + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde) + } + + // TODO: Support persisting partitioned data source relations in Hive compatible format + val hiveTable = (maybeSerDe, dataSource.relation) match { + case (Some(serde), relation: HadoopFsRelation) + if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + logInfo { + "Persisting data source relation with a single input path into Hive metastore in Hive " + + s"compatible format. Input path: ${relation.paths.head}" + } + newHiveCompatibleMetastoreTable(relation, serde) + + case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => + logWarning { + val paths = relation.paths.mkString(", ") + "Persisting partitioned data source relation into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive. Input path(s): " + + paths.mkString("\n", "\n", "") + } + newSparkSQLSpecificMetastoreTable() + + case (Some(serde), relation: HadoopFsRelation) => + logWarning { + val paths = relation.paths.mkString(", ") + "Persisting data source relation with multiple input paths into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + + paths.mkString("\n", "\n", "") + } + newSparkSQLSpecificMetastoreTable() + + case (Some(serde), _) => + logWarning { + s"Data source relation is not a ${classOf[HadoopFsRelation].getSimpleName}. " + + "Persisting it into Hive metastore in Spark SQL specific format, " + + "which is NOT compatible with Hive." + } + newSparkSQLSpecificMetastoreTable() + + case _ => + logWarning { + s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + + "Persisting data source relation into Hive metastore in Spark SQL specific format, " + + "which is NOT compatible with Hive." 
+ } + newSparkSQLSpecificMetastoreTable() + } + + client.createTable(hiveTable) } def hiveDefaultTableFilePath(tableName: String): String = { @@ -463,7 +595,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case p: LogicalPlan if !p.childrenResolved => p case p: LogicalPlan if p.resolved => p case p @ CreateTableAsSelect(table, child, allowExisting) => - val schema = if (table.schema.size > 0) { + val schema = if (table.schema.nonEmpty) { table.schema } else { child.output.map { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index f43e403ce9a9d..7d7b4b9167306 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -261,8 +262,8 @@ private[hive] object HiveQl extends Logging { /** * Returns the HiveConf */ - private[this] def hiveConf(): HiveConf = { - val ss = SessionState.get() // SessionState is lazy initializaion, it can be null here + private[this] def hiveConf: HiveConf = { + val ss = SessionState.get() // SessionState is lazy initialization, it can be null here if (ss == null) { new HiveConf() } else { @@ -604,38 +605,18 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C serde = None, viewText = None) - // default storage type abbriviation (e.g. RCFile, ORC, PARQUET etc.) + // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) 
val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - // handle the default format for the storage type abbriviation - tableDesc = if ("SequenceFile".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - } else if ("RCFile".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), - serde = Option(hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTRCFILESERDE))) - } else if ("ORC".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } else if ("PARQUET".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), - serde = - Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } else { - tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } + // handle the default format for the storage type abbreviation + val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + } + + hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) + hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) + hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) children.collect { case list @ Token("TOK_TABCOLLIST", _) => @@ -908,7 +889,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, Option(hiveConf().getVar(ConfVars.HIVESCRIPTSERDE)), Nil) + case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil) } val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 6fa599734892b..4a310ff4e9016 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -291,9 +291,11 @@ private[orc] case class OrcTableScan( // Sets requested columns addColumnIds(attributes, relation, conf) - if (inputPaths.nonEmpty) { - FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) + if (inputPaths.isEmpty) { + // the input path probably be pruned, return an empty RDD. 
+ return sqlContext.sparkContext.emptyRDD[InternalRow] } + FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) val inputFormatClass = classOf[OrcInputFormat] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 983c013bcf86a..332c3ec0c28b8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,31 +17,142 @@ package org.apache.spark.sql.hive -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.hive.test.TestHive +import java.io.File -import org.apache.spark.sql.test.ExamplePointUDT +import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, ManagedTable} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.sources.DataSourceTest +import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.{Logging, SparkFunSuite} + class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { test("struct field should accept underscore in sub-column name") { - val metastr = "struct" - - val datatype = HiveMetastoreTypes.toDataType(metastr) - assert(datatype.isInstanceOf[StructType]) + val hiveTypeStr = "struct" + val dateType = HiveMetastoreTypes.toDataType(hiveTypeStr) + assert(dateType.isInstanceOf[StructType]) } test("udt to metastore type conversion") { val udt = new ExamplePointUDT - assert(HiveMetastoreTypes.toMetastoreType(udt) === - HiveMetastoreTypes.toMetastoreType(udt.sqlType)) + assertResult(HiveMetastoreTypes.toMetastoreType(udt.sqlType)) { + HiveMetastoreTypes.toMetastoreType(udt) + } } test("duplicated metastore relations") { - import TestHive.implicits._ - val df = TestHive.sql("SELECT * FROM src") + val df = sql("SELECT * FROM src") logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } } + +class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { + override val sqlContext = TestHive + + private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1) + + Seq( + "parquet" -> ( + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + ), + + "orc" -> ( + "org.apache.hadoop.hive.ql.io.orc.OrcInputFormat", + "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat", + "org.apache.hadoop.hive.ql.io.orc.OrcSerde" + ) + ).foreach { case (provider, (inputFormat, outputFormat, serde)) => + test(s"Persist non-partitioned $provider relation into metastore as managed table") { + withTable("t") { + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .saveAsTable("t") + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(!hiveTable.isPartitioned) + assert(hiveTable.tableType === ManagedTable) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("int", "string")) + + 
checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + } + } + + test(s"Persist non-partitioned $provider relation into metastore as external table") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalFile + + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .option("path", path.toString) + .saveAsTable("t") + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(hiveTable.tableType === ExternalTable) + assert(hiveTable.location.get === path.toURI.toString.stripSuffix(File.separator)) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("int", "string")) + + checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + } + } + } + + test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalPath + + sql( + s"""CREATE TABLE t USING $provider + |OPTIONS (path '$path') + |AS SELECT 1 AS d1, "val_1" AS d2 + """.stripMargin) + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(hiveTable.isPartitioned === false) + assert(hiveTable.tableType === ExternalTable) + assert(hiveTable.partitionColumns.length === 0) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("int", "string")) + + checkAnswer(table("t"), Row(1, "val_1")) + assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + } + } + } + } +} From 9270bd06fd0b16892e3f37213b5bc7813ea11fdd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 5 Aug 2015 21:50:14 -0700 Subject: [PATCH 177/340] [SPARK-9674][SQL] Remove GeneratedAggregate. The new aggregate replaces the old GeneratedAggregate. Author: Reynold Xin Closes #7983 from rxin/remove-generated-agg and squashes the following commits: 8334aae [Reynold Xin] [SPARK-9674][SQL] Remove GeneratedAggregate. --- .../sql/execution/GeneratedAggregate.scala | 352 ------------------ .../spark/sql/execution/SparkStrategies.scala | 34 -- .../org/apache/spark/sql/SQLQuerySuite.scala | 5 +- .../spark/sql/execution/AggregateSuite.scala | 48 --- 4 files changed, 2 insertions(+), 437 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala deleted file mode 100644 index bf4905dc1eef9..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ /dev/null @@ -1,352 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.io.IOException - -import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.trees._ -import org.apache.spark.sql.types._ - -case class AggregateEvaluation( - schema: Seq[Attribute], - initialValues: Seq[Expression], - update: Seq[Expression], - result: Expression) - -/** - * :: DeveloperApi :: - * Alternate version of aggregation that leverages projection and thus code generation. - * Aggregations are converted into a set of projections from a aggregation buffer tuple back onto - * itself. Currently only used for simple aggregations like SUM, COUNT, or AVERAGE are supported. - * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used. - * @param child the input data source. - */ -@DeveloperApi -case class GeneratedAggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - unsafeEnabled: Boolean, - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - if (partial) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - - protected override def doExecute(): RDD[InternalRow] = { - val aggregatesToCompute = aggregateExpressions.flatMap { a => - a.collect { case agg: AggregateExpression1 => agg} - } - - // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite - // (in test "aggregation with codegen"). 
- val computeFunctions = aggregatesToCompute.map { - case c @ Count(expr) => - // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its - // UnscaledValue will be null if and only if x is null; helps with Average on decimals - val toCount = expr match { - case UnscaledValue(e) => e - case _ => expr - } - val currentCount = AttributeReference("currentCount", LongType, nullable = false)() - val initialValue = Literal(0L) - val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) - val result = currentCount - - AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - - case s @ Sum(expr) => - val calcType = - expr.dataType match { - case DecimalType.Fixed(p, s) => - DecimalType.bounded(p + 10, s) - case _ => - expr.dataType - } - - val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal.create(null, calcType) - - // Coalesce avoids double calculation... - // but really, common sub expression elimination would be better.... - val zero = Cast(Literal(0), calcType) - val updateFunction = Coalesce( - Add( - Coalesce(currentSum :: zero :: Nil), - Cast(expr, calcType) - ) :: currentSum :: Nil) - val result = - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(currentSum, s.dataType) - case _ => currentSum - } - - AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - - case m @ Max(expr) => - val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() - val initialValue = Literal.create(null, expr.dataType) - val updateMax = MaxOf(currentMax, expr) - - AggregateEvaluation( - currentMax :: Nil, - initialValue :: Nil, - updateMax :: Nil, - currentMax) - - case m @ Min(expr) => - val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)() - val initialValue = Literal.create(null, expr.dataType) - val updateMin = MinOf(currentMin, expr) - - AggregateEvaluation( - currentMin :: Nil, - initialValue :: Nil, - updateMin :: Nil, - currentMin) - - case CollectHashSet(Seq(expr)) => - val set = - AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() - val initialValue = NewSet(expr.dataType) - val addToSet = AddItemToSet(expr, set) - - AggregateEvaluation( - set :: Nil, - initialValue :: Nil, - addToSet :: Nil, - set) - - case CombineSetsAndCount(inputSet) => - val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType - val set = - AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)() - val initialValue = NewSet(elementType) - val collectSets = CombineSets(set, inputSet) - - AggregateEvaluation( - set :: Nil, - initialValue :: Nil, - collectSets :: Nil, - CountSet(set)) - - case o => sys.error(s"$o can't be codegened.") - } - - val computationSchema = computeFunctions.flatMap(_.schema) - - val resultMap: Map[TreeNodeRef, Expression] = - aggregatesToCompute.zip(computeFunctions).map { - case (agg, func) => new TreeNodeRef(agg) -> func.result - }.toMap - - val namedGroups = groupingExpressions.zipWithIndex.map { - case (ne: NamedExpression, _) => (ne, ne.toAttribute) - case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute) - } - - // The set of expressions that produce the final output given the aggregation buffer and the - // grouping expressions. 
- val resultExpressions = aggregateExpressions.map(_.transform { - case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e)) - case e: Expression => - namedGroups.collectFirst { - case (expr, attr) if expr semanticEquals e => attr - }.getOrElse(e) - }) - - val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) - - val groupKeySchema: StructType = { - val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) => - // This is a dummy field name - StructField(idx.toString, expr.dataType, expr.nullable) - } - StructType(fields) - } - - val schemaSupportsUnsafe: Boolean = { - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupKeySchema) - } - - child.execute().mapPartitions { iter => - // Builds a new custom class for holding the results of aggregation for a group. - val initialValues = computeFunctions.flatMap(_.initialValues) - val newAggregationBuffer = newProjection(initialValues, child.output) - log.info(s"Initial values: ${initialValues.mkString(",")}") - - // A projection that computes the group given an input tuple. - val groupProjection = newProjection(groupingExpressions, child.output) - log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}") - - // A projection that is used to update the aggregate values for a group given a new tuple. - // This projection should be targeted at the current values for the group and then applied - // to a joined row of the current values with the new input row. - val updateExpressions = computeFunctions.flatMap(_.update) - val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output - val updateProjection = newMutableProjection(updateExpressions, updateSchema)() - log.info(s"Update Expressions: ${updateExpressions.mkString(",")}") - - // A projection that produces the final result, given a computation. - val resultProjectionBuilder = - newMutableProjection( - resultExpressions, - namedGroups.map(_._2) ++ computationSchema) - log.info(s"Result Projection: ${resultExpressions.mkString(",")}") - - val joinedRow = new JoinedRow - - if (!iter.hasNext) { - // This is an empty input, so return early so that we do not allocate data structures - // that won't be cleaned up (see SPARK-8357). - if (groupingExpressions.isEmpty) { - // This is a global aggregate, so return an empty aggregation buffer. - val resultProjection = resultProjectionBuilder() - Iterator(resultProjection(newAggregationBuffer(EmptyRow))) - } else { - // This is a grouped aggregate, so return an empty iterator. - Iterator[InternalRow]() - } - } else if (groupingExpressions.isEmpty) { - // TODO: Codegening anything other than the updateProjection is probably over kill. 
- val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] - var currentRow: InternalRow = null - updateProjection.target(buffer) - - while (iter.hasNext) { - currentRow = iter.next() - updateProjection(joinedRow(buffer, currentRow)) - } - - val resultProjection = resultProjectionBuilder() - Iterator(resultProjection(buffer)) - - } else if (unsafeEnabled && schemaSupportsUnsafe) { - assert(iter.hasNext, "There should be at least one row for this path") - log.info("Using Unsafe-based aggregator") - val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") - val taskContext = TaskContext.get() - val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer(EmptyRow), - aggregationBufferSchema, - groupKeySchema, - taskContext.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, - 1024 * 16, // initial capacity - pageSizeBytes, - false // disable tracking of performance metrics - ) - - while (iter.hasNext) { - val currentRow: InternalRow = iter.next() - val groupKey: InternalRow = groupProjection(currentRow) - val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) - if (aggregationBuffer == null) { - throw new IOException("Could not allocate memory to grow aggregation buffer") - } - updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) - } - - // Record memory used in the process - taskContext.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage) - - new Iterator[InternalRow] { - private[this] val mapIterator = aggregationMap.iterator() - private[this] val resultProjection = resultProjectionBuilder() - private[this] var _hasNext = mapIterator.next() - - def hasNext: Boolean = _hasNext - - def next(): InternalRow = { - if (_hasNext) { - val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue)) - _hasNext = mapIterator.next() - if (_hasNext) { - result - } else { - // This is the last element in the iterator, so let's free the buffer. Before we do, - // though, we need to make a defensive copy of the result so that we don't return an - // object that might contain dangling pointers to the freed memory. - val resultCopy = result.copy() - aggregationMap.free() - resultCopy - } - } else { - throw new java.util.NoSuchElementException - } - } - } - } else { - if (unsafeEnabled) { - log.info("Not using Unsafe-based aggregator because it is not supported for this schema") - } - val buffers = new java.util.HashMap[InternalRow, MutableRow]() - - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - val currentGroup = groupProjection(currentRow) - var currentBuffer = buffers.get(currentGroup) - if (currentBuffer == null) { - currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] - buffers.put(currentGroup, currentBuffer) - } - // Target the projection at the current aggregation buffer and then project the updated - // values. 
- updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) - } - - new Iterator[InternalRow] { - private[this] val resultIterator = buffers.entrySet.iterator() - private[this] val resultProjection = resultProjectionBuilder() - - def hasNext: Boolean = resultIterator.hasNext - - def next(): InternalRow = { - val currentGroup = resultIterator.next() - resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) - } - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 952ba7d45c13e..a730ffbb217c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -136,32 +136,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object HashAggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Aggregations that can be performed in two phases, before and after the shuffle. - - // Cases where all aggregates can be codegened. - case PartialAggregation( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child) - if canBeCodeGened( - allAggregates(partialComputation) ++ - allAggregates(rewrittenAggregateExpressions)) && - codegenEnabled && - !canBeConvertedToNewAggregation(plan) => - execution.GeneratedAggregate( - partial = false, - namedGroupingAttributes, - rewrittenAggregateExpressions, - unsafeEnabled, - execution.GeneratedAggregate( - partial = true, - groupingExpressions, - partialComputation, - unsafeEnabled, - planLater(child))) :: Nil - - // Cases where some aggregate can not be codegened case PartialAggregation( namedGroupingAttributes, rewrittenAggregateExpressions, @@ -192,14 +166,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => false } - def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall { - case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true - // The generated set implementation is pretty limited ATM. - case CollectHashSet(exprs) if exprs.size == 1 && - Seq(IntegerType, LongType).contains(exprs.head.dataType) => true - case _ => false - } - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 29dfcf2575227..cef40dd324d9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.SQLTestUtils @@ -263,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = sql(sqlText) // First, check if we have GeneratedAggregate. 
val hasGeneratedAgg = df.queryExecution.executedPlan - .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true } + .collect { case _: aggregate.Aggregate => true } .nonEmpty if (!hasGeneratedAgg) { fail( @@ -1603,7 +1602,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } - test("aggregation with codegen updates peak execution memory") { + ignore("aggregation with codegen updates peak execution memory") { withSQLConf( (SQLConf.CODEGEN_ENABLED.key, "true"), (SQLConf.USE_SQL_AGGREGATE2.key, "false")) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala deleted file mode 100644 index 20def6bef0c17..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test.TestSQLContext - -class AggregateSuite extends SparkPlanTest { - - test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") { - val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED) - val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED) - try { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true) - TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true) - val df = Seq.empty[(Int, Int)].toDF("a", "b") - checkAnswer( - df, - GeneratedAggregate( - partial = true, - Seq(df.col("b").expr), - Seq(Alias(Count(df.col("a").expr), "cnt")()), - unsafeEnabled = true, - _: SparkPlan), - Seq.empty - ) - } finally { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) - TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault) - } - } -} From d5a9af3230925c347d0904fe7f2402e468e80bc8 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 5 Aug 2015 21:50:35 -0700 Subject: [PATCH 178/340] [SPARK-9664] [SQL] Remove UDAFRegistration and add apply to UserDefinedAggregateFunction. https://issues.apache.org/jira/browse/SPARK-9664 Author: Yin Huai Closes #7982 from yhuai/udafRegister and squashes the following commits: 0cc2287 [Yin Huai] Remove UDAFRegistration and add apply to UserDefinedAggregateFunction. 
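As a hedged sketch of the API after this change (`sqlContext`, `df` and the `MyDoubleSum` test UDAF are borrowed from the test code touched below; they are illustrative, not a prescribed usage):

import org.apache.spark.sql.functions.callUDF

// UDAF registration now goes through the same udf registry as scalar UDFs.
val mySum = new MyDoubleSum                        // a UserDefinedAggregateFunction
sqlContext.udf.register("mydoublesum", mySum)
sqlContext.sql("SELECT mydoublesum(value) FROM t")

// The new apply() turns a UDAF into a Column directly; callUDF covers the
// registered-by-name case in the DataFrame API. The boolean overload,
// e.g. mySum(true, df("value")), requests distinct aggregation.
df.groupBy().agg(mySum(df("value")), callUDF("mydoublesum", df("value")))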
--- .../org/apache/spark/sql/SQLContext.scala | 3 -- .../apache/spark/sql/UDAFRegistration.scala | 36 ------------------- .../apache/spark/sql/UDFRegistration.scala | 16 +++++++++ .../spark/sql/execution/aggregate/udaf.scala | 8 ++--- .../apache/spark/sql/expressions/udaf.scala | 32 ++++++++++++++++- .../org/apache/spark/sql/functions.scala | 1 + .../spark/sql/hive/JavaDataFrameSuite.java | 26 ++++++++++++++ .../execution/AggregationQuerySuite.scala | 4 +-- 8 files changed, 80 insertions(+), 46 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ffc2baf7a8826..6f8ffb54402a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -291,9 +291,6 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient val udf: UDFRegistration = new UDFRegistration(this) - @transient - val udaf: UDAFRegistration = new UDAFRegistration(this) - /** * Returns true if the table is currently cached in-memory. * @group cachemgmt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala deleted file mode 100644 index 0d4e30f29255e..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{Expression} -import org.apache.spark.sql.execution.aggregate.ScalaUDAF -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction - -class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { - - private val functionRegistry = sqlContext.functionRegistry - - def register( - name: String, - func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { - def builder(children: Seq[Expression]) = ScalaUDAF(children, func) - functionRegistry.registerFunction(name, builder) - func - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 7cd7421a518c9..1f270560d7bc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -26,6 +26,8 @@ import org.apache.spark.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction import org.apache.spark.sql.types.DataType /** @@ -52,6 +54,20 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { functionRegistry.registerFunction(name, udf.builder) } + /** + * Register a user-defined aggregate function (UDAF). + * @param name the name of the UDAF. + * @param udaf the UDAF needs to be registered. + * @return the registered UDAF. + */ + def register( + name: String, + udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { + def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf) + functionRegistry.registerFunction(name, builder) + udaf + } + // scalastyle:off /* register 0-22 were generated by this script diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 5fafc916bfa0b..7619f3ec9f0a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -316,7 +316,7 @@ private[sql] case class ScalaUDAF( override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) - private[this] val childrenSchema: StructType = { + private[this] lazy val childrenSchema: StructType = { val inputFields = children.zipWithIndex.map { case (child, index) => StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) @@ -337,16 +337,16 @@ private[sql] case class ScalaUDAF( } } - private[this] val inputToScalaConverters: Any => Any = + private[this] lazy val inputToScalaConverters: Any => Any = CatalystTypeConverters.createToScalaConverter(childrenSchema) - private[this] val bufferValuesToCatalystConverters: Array[Any => Any] = { + private[this] lazy val bufferValuesToCatalystConverters: Array[Any => Any] = { bufferSchema.fields.map { field => CatalystTypeConverters.createToCatalystConverter(field.dataType) } } - private[this] val bufferValuesToScalaConverters: Array[Any => Any] = { + private[this] lazy val bufferValuesToScalaConverters: Array[Any => Any] = { bufferSchema.fields.map { field => CatalystTypeConverters.createToScalaConverter(field.dataType) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 278dd438fab4a..5180871585f25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.types._ import org.apache.spark.annotation.Experimental @@ -87,6 +90,33 @@ abstract class UserDefinedAggregateFunction extends Serializable { * aggregation buffer. */ def evaluate(buffer: Row): Any + + /** + * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments. + */ + @scala.annotation.varargs + def apply(exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression2( + ScalaUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = false) + Column(aggregateExpression) + } + + /** + * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments. + * If `isDistinct` is true, this UDAF is working on distinct input values. + */ + @scala.annotation.varargs + def apply(isDistinct: Boolean, exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression2( + ScalaUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = isDistinct) + Column(aggregateExpression) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5a10c3891ad6c..39aa905c8532a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2500,6 +2500,7 @@ object functions { * @group udf_funcs * @since 1.5.0 */ + @scala.annotation.varargs def callUDF(udfName: String, cols: Column*): Column = { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 613b2bcc80e37..21b053f07a3ba 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -29,8 +29,12 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import test.org.apache.spark.sql.hive.aggregate.MyDoubleSum; public class JavaDataFrameSuite { private transient JavaSparkContext sc; @@ -77,4 +81,26 @@ public void saveTableAndQueryIt() { " ROWS BETWEEN 1 preceding and 1 following) " + "FROM window_table").collectAsList()); } + + @Test + public void testUDAF() { + DataFrame df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value")); + UserDefinedAggregateFunction udaf = new MyDoubleSum(); + UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf); + 
// Create Columns for the UDAF. For now, callUDF does not take an argument to specific if + // we want to use distinct aggregation. + DataFrame aggregatedDF = + df.groupBy() + .agg( + udaf.apply(true, col("value")), + udaf.apply(col("value")), + registeredUDAF.apply(col("value")), + callUDF("mydoublesum", col("value"))); + + List expectedResult = new ArrayList(); + expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0)); + checkAnswer( + aggregatedDF, + expectedResult); + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6f0db27775e4d..4b35c8fd83533 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -73,8 +73,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be emptyDF.registerTempTable("emptyTable") // Register UDAFs - sqlContext.udaf.register("mydoublesum", new MyDoubleSum) - sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg) + sqlContext.udf.register("mydoublesum", new MyDoubleSum) + sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) } override def afterAll(): Unit = { From aead18ffca36830e854fba32a1cac11a0b2e31d5 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Thu, 6 Aug 2015 09:02:30 -0700 Subject: [PATCH 179/340] [SPARK-8266] [SQL] add function translate ![translate](http://www.w3resource.com/PostgreSQL/postgresql-translate-function.png) Author: zhichao.li Closes #7709 from zhichao-li/translate and squashes the following commits: 9418088 [zhichao.li] refine checking condition f2ab77a [zhichao.li] clone string 9d88f2d [zhichao.li] fix indent 6aa2962 [zhichao.li] style e575ead [zhichao.li] add python api 9d4bab0 [zhichao.li] add special case for fodable and refactor unittest eda7ad6 [zhichao.li] update to use TernaryExpression cdfd4be [zhichao.li] add function translate --- python/pyspark/sql/functions.py | 16 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/Expression.scala | 4 +- .../expressions/stringOperations.scala | 79 ++++++++++++++++++- .../expressions/StringExpressionsSuite.scala | 14 ++++ .../org/apache/spark/sql/functions.scala | 21 +++-- .../spark/sql/StringFunctionsSuite.scala | 6 ++ .../apache/spark/unsafe/types/UTF8String.java | 16 ++++ .../spark/unsafe/types/UTF8StringSuite.java | 31 ++++++++ 9 files changed, 180 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9f0d71d7960cf..b5c6a01f18858 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1290,6 +1290,22 @@ def length(col): return Column(sc._jvm.functions.length(_to_java_column(col))) +@ignore_unicode_prefix +@since(1.5) +def translate(srcCol, matching, replace): + """A function translate any character in the `srcCol` by a character in `matching`. + The characters in `replace` is corresponding to the characters in `matching`. + The translate will happen when any character in the string matching with the character + in the `matching`. 
+ + >>> sqlContext.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\ + .alias('r')).collect() + [Row(r=u'1a2s3ae')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.translate(_to_java_column(srcCol), matching, replace)) + + # ---------------------- Collection functions ------------------------------ @since(1.4) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 94c355f838fa0..cd5a90d788151 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -203,6 +203,7 @@ object FunctionRegistry { expression[Substring]("substr"), expression[Substring]("substring"), expression[SubstringIndex]("substring_index"), + expression[StringTranslate]("translate"), expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ef2fc2e8c29d4..0b98f555a1d60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -444,7 +444,7 @@ abstract class TernaryExpression extends Expression { override def nullable: Boolean = children.exists(_.nullable) /** - * Default behavior of evaluation according to the default nullability of BinaryExpression. + * Default behavior of evaluation according to the default nullability of TernaryExpression. * If subclass of BinaryExpression override nullable, probably should also override this. */ override def eval(input: InternalRow): Any = { @@ -463,7 +463,7 @@ abstract class TernaryExpression extends Expression { } /** - * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default + * Called by default [[eval]] implementation. If subclass of TernaryExpression keep the default * nullability, they can override this method to save null-check code. If we need full control * of evaluation process, we should override [[eval]]. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 0cc785d9f3a49..76666bd6b3d27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import java.text.DecimalFormat -import java.util.{Arrays, Locale} +import java.util.Arrays +import java.util.{Map => JMap, HashMap} +import java.util.Locale import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils @@ -349,6 +351,81 @@ case class EndsWith(left: Expression, right: Expression) } } +object StringTranslate { + + def buildDict(matchingString: UTF8String, replaceString: UTF8String) + : JMap[Character, Character] = { + val matching = matchingString.toString() + val replace = replaceString.toString() + val dict = new HashMap[Character, Character]() + var i = 0 + while (i < matching.length()) { + val rep = if (i < replace.length()) replace.charAt(i) else '\0' + if (null == dict.get(matching.charAt(i))) { + dict.put(matching.charAt(i), rep) + } + i += 1 + } + dict + } +} + +/** + * A function translate any character in the `srcExpr` by a character in `replaceExpr`. + * The characters in `replaceExpr` is corresponding to the characters in `matchingExpr`. + * The translate will happen when any character in the string matching with the character + * in the `matchingExpr`. + */ +case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + @transient private var lastMatching: UTF8String = _ + @transient private var lastReplace: UTF8String = _ + @transient private var dict: JMap[Character, Character] = _ + + override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = { + if (matchingEval != lastMatching || replaceEval != lastReplace) { + lastMatching = matchingEval.asInstanceOf[UTF8String].clone() + lastReplace = replaceEval.asInstanceOf[UTF8String].clone() + dict = StringTranslate.buildDict(lastMatching, lastReplace) + } + srcEval.asInstanceOf[UTF8String].translate(dict) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastMatching = ctx.freshName("lastMatching") + val termLastReplace = ctx.freshName("lastReplace") + val termDict = ctx.freshName("dict") + val classNameDict = classOf[JMap[Character, Character]].getCanonicalName + + ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;") + ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;") + ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;") + + nullSafeCodeGen(ctx, ev, (src, matching, replace) => { + val check = if (matchingExpr.foldable && replaceExpr.foldable) { + s"${termDict} == null" + } else { + s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})" + } + s"""if ($check) { + // Not all of them is literal or matching or replace value changed + ${termLastMatching} = ${matching}.clone(); + ${termLastReplace} = ${replace}.clone(); + ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict(${termLastMatching}, ${termLastReplace}); + } + ${ev.primitive} = ${src}.translate(${termDict}); + """ + 
}) + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = srcExpr :: matchingExpr :: replaceExpr :: Nil + override def prettyName: String = "translate" +} + /** * A function that returns the index (1-based) of the given string (left) in the comma- * delimited list (right). Returns 0, if the string wasn't found or if the given diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 23f36ca43d663..426dc272471ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -431,6 +431,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(SoundEx(Literal("!!")), "!!") } + test("translate") { + checkEvaluation( + StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae") + checkEvaluation(StringTranslate(Literal("translate"), Literal(""), Literal("123")), "translate") + checkEvaluation(StringTranslate(Literal("translate"), Literal("rnlt"), Literal("")), "asae") + // test for multiple mapping + checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("123")), "12cd") + checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("12")), "12cd") + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTranslate(Literal("花花世界"), Literal("花界"), Literal("ab")), "aa世b") + // scalastyle:on + } + test("TRIM/LTRIM/RTRIM") { val s = 'a.string.at(0) checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 39aa905c8532a..79c5f596661d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1100,11 +1100,11 @@ object functions { } /** - * Computes hex value of the given column. - * - * @group math_funcs - * @since 1.5.0 - */ + * Computes hex value of the given column. + * + * @group math_funcs + * @since 1.5.0 + */ def hex(column: Column): Column = Hex(column.expr) /** @@ -1863,6 +1863,17 @@ object functions { def substring_index(str: Column, delim: String, count: Int): Column = SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + /* Translate any character in the src by a character in replaceString. + * The characters in replaceString is corresponding to the characters in matchingString. + * The translate will happen when any character in the string matching with the character + * in the matchingString. + * + * @group string_funcs + * @since 1.5.0 + */ + def translate(src: Column, matchingString: String, replaceString: String): Column = + StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) + /** * Trim the spaces from both ends for the specified string column. 
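For readers skimming the patch, the dictionary semantics implemented by StringTranslate.buildDict and UTF8String.translate above can be summarised with a small, self-contained Scala sketch (illustrative only, not part of the change; the helper name translateSketch is ours): the first mapping wins when a character repeats in the matching string, a character whose replacement is missing maps to '\0' and is dropped from the output, and characters with no mapping pass through unchanged.

    // Illustrative sketch of the translate semantics, assuming single-char
    // (BMP) handling, in line with the TODO about code points in UTF8String.
    def translateSketch(src: String, matching: String, replace: String): String = {
      val dict = scala.collection.mutable.HashMap.empty[Char, Char]
      var i = 0
      while (i < matching.length) {
        // The first occurrence wins when a character repeats in `matching`.
        if (!dict.contains(matching.charAt(i))) {
          dict(matching.charAt(i)) =
            if (i < replace.length) replace.charAt(i) else '\u0000'
        }
        i += 1
      }
      val sb = new StringBuilder
      src.foreach { c =>
        dict.get(c) match {
          case None           => sb.append(c) // no mapping: keep the character
          case Some('\u0000') => ()           // replacement missing: drop it
          case Some(r)        => sb.append(r) // mapped: substitute
        }
      }
      sb.toString
    }

    // Mirrors the expected results in the test suites:
    //   translateSketch("translate", "rnlt", "123") == "1a2s3ae"
    //   translateSketch("abcd", "aba", "123")       == "12cd"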
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ab5da6ee79f1b..ca298b2434410 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -128,6 +128,12 @@ class StringFunctionsSuite extends QueryTest { // scalastyle:on } + test("string translate") { + val df = Seq(("translate", "")).toDF("a", "b") + checkAnswer(df.select(translate($"a", "rnlt", "123")), Row("1a2s3ae")) + checkAnswer(df.selectExpr("""translate(a, "rnlt", "")"""), Row("asae")) + } + test("string trim functions") { val df = Seq((" example ", "")).toDF("a", "b") diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index febbe3d4e54d1..d1014426c0f49 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -22,6 +22,7 @@ import java.io.UnsupportedEncodingException; import java.nio.ByteOrder; import java.util.Arrays; +import java.util.Map; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -795,6 +796,21 @@ public UTF8String[] split(UTF8String pattern, int limit) { return res; } + // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes + public UTF8String translate(Map dict) { + String srcStr = this.toString(); + + StringBuilder sb = new StringBuilder(); + for(int k = 0; k< srcStr.length(); k++) { + if (null == dict.get(srcStr.charAt(k))) { + sb.append(srcStr.charAt(k)); + } else if ('\0' != dict.get(srcStr.charAt(k))){ + sb.append(dict.get(srcStr.charAt(k))); + } + } + return fromString(sb.toString()); + } + @Override public String toString() { try { diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index b30c94c1c1f80..98aa8a2469a75 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -19,7 +19,9 @@ import java.io.UnsupportedEncodingException; import java.util.Arrays; +import java.util.HashMap; +import com.google.common.collect.ImmutableMap; import org.junit.Test; import static junit.framework.Assert.*; @@ -391,6 +393,35 @@ public void levenshteinDistance() { assertEquals(fromString("世界千世").levenshteinDistance(fromString("千a世b")),4); } + @Test + public void translate() { + assertEquals( + fromString("1a2s3ae"), + fromString("translate").translate(ImmutableMap.of( + 'r', '1', + 'n', '2', + 'l', '3', + 't', '\0' + ))); + assertEquals( + fromString("translate"), + fromString("translate").translate(new HashMap())); + assertEquals( + fromString("asae"), + fromString("translate").translate(ImmutableMap.of( + 'r', '\0', + 'n', '\0', + 'l', '\0', + 't', '\0' + ))); + assertEquals( + fromString("aa世b"), + fromString("花花世界").translate(ImmutableMap.of( + '花', 'a', + '界', 'b' + ))); + } + @Test public void createBlankString() { assertEquals(fromString(" "), blankString(1)); From 5b965d64ee1687145ba793da749659c8f67384e8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 09:10:57 -0700 Subject: [PATCH 180/340] [SPARK-9644] [SQL] Support update DecimalType with precision > 18 in UnsafeRow In order to 
support update a varlength (actually fixed length) object, the space should be preserved even it's null. And, we can't call setNullAt(i) for it anymore, we because setNullAt(i) will remove the offset of the preserved space, should call setDecimal(i, null, precision) instead. After this, we can do hash based aggregation on DecimalType with precision > 18. In a tests, this could decrease the end-to-end run time of aggregation query from 37 seconds (sort based) to 24 seconds (hash based). cc rxin Author: Davies Liu Closes #7978 from davies/update_decimal and squashes the following commits: bed8100 [Davies Liu] isSettable -> isMutable 923c9eb [Davies Liu] address comments and fix bug 385891d [Davies Liu] Merge branch 'master' of github.com:apache/spark into update_decimal 36a1872 [Davies Liu] fix tests cd6c524 [Davies Liu] support set decimal with precision > 18 --- .../sql/catalyst/expressions/UnsafeRow.java | 74 +++++++++++++++---- .../expressions/UnsafeRowWriters.java | 41 ++++++---- .../codegen/GenerateMutableProjection.scala | 15 +++- .../codegen/GenerateUnsafeProjection.scala | 53 +++++++------ .../spark/sql/catalyst/expressions/rows.scala | 8 +- .../expressions/UnsafeRowConverterSuite.scala | 17 ++++- .../UnsafeFixedWidthAggregationMap.java | 4 +- .../SortBasedAggregationIterator.scala | 4 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 2 +- .../spark/unsafe/PlatformDependent.java | 26 +++++++ 10 files changed, 183 insertions(+), 61 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index e3e1622de08ba..e829acb6285f1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -65,11 +65,11 @@ public static int calculateBitSetWidthInBytes(int numFields) { /** * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) */ - public static final Set settableFieldTypes; + public static final Set mutableFieldTypes; - // DecimalType(precision <= 18) is settable + // DecimalType is also mutable static { - settableFieldTypes = Collections.unmodifiableSet( + mutableFieldTypes = Collections.unmodifiableSet( new HashSet<>( Arrays.asList(new DataType[] { NullType, @@ -87,12 +87,16 @@ public static int calculateBitSetWidthInBytes(int numFields) { public static boolean isFixedLength(DataType dt) { if (dt instanceof DecimalType) { - return ((DecimalType) dt).precision() < Decimal.MAX_LONG_DIGITS(); + return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS(); } else { - return settableFieldTypes.contains(dt); + return mutableFieldTypes.contains(dt); } } + public static boolean isMutable(DataType dt) { + return mutableFieldTypes.contains(dt) || dt instanceof DecimalType; + } + ////////////////////////////////////////////////////////////////////////////// // Private fields and methods ////////////////////////////////////////////////////////////////////////////// @@ -238,17 +242,45 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } + /** + * Updates the decimal column. + * + * Note: In order to support update a decimal with precision > 18, CAN NOT call + * setNullAt() for this column. 
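One detail worth calling out for reviewers is the layout of the fixed-length word that setDecimal and getDecimal agree on for precision > 18: the value itself lives in a reserved 16-byte slot in the variable-length region, while the 8-byte fixed field packs the slot's offset together with the BigInteger signum and magnitude length. A hedged Scala sketch of that packing (helper names are ours, for illustration only):

    // Sketch of the 8-byte word used for DecimalType with precision > 18:
    // the upper 32 bits hold the cursor (offset of the reserved 16-byte slot),
    // bits 8..15 hold signum + 1, and the low 8 bits hold the number of
    // BigInteger magnitude words (at most 4).
    def packDecimalWord(cursor: Long, signum: Int, magLength: Int): Long =
      (cursor << 32) | (((signum + 1).toLong << 8) + magLength)

    def unpackDecimalWord(word: Long): (Long, Int, Int) = {
      val cursor    = word >>> 32                        // slot offset
      val signum    = (((word & 0xfff) >> 8) - 1).toInt  // -1, 0 or 1
      val magLength = (word & 0xff).toInt                // magnitude words
      (cursor, signum, magLength)
    }

    // When the value is null, the 16-byte slot is zeroed and the null bit is
    // set, but the word keeps cursor << 32 so a later setDecimal can reuse
    // the reserved space instead of going through setNullAt.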
+ */ @Override public void setDecimal(int ordinal, Decimal value, int precision) { assertIndexIsValid(ordinal); - if (value == null) { - setNullAt(ordinal); - } else { - if (precision <= Decimal.MAX_LONG_DIGITS()) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + // compact format + if (value == null) { + setNullAt(ordinal); + } else { setLong(ordinal, value.toUnscaledLong()); + } + } else { + // fixed length + long cursor = getLong(ordinal) >>> 32; + assert cursor > 0 : "invalid cursor " + cursor; + // zero-out the bytes + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor, 0L); + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor + 8, 0L); + + if (value == null) { + setNullAt(ordinal); + // keep the offset for future update + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); } else { - // TODO(davies): support update decimal (hold a bounded space even it's null) - throw new UnsupportedOperationException(); + + final BigInteger integer = value.toJavaBigDecimal().unscaledValue(); + final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, + PlatformDependent.BIG_INTEGER_MAG_OFFSET); + assert(mag.length <= 4); + + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, + baseObject, baseOffset + cursor, mag.length * 4); + setLong(ordinal, (cursor << 32) | ((long) (((integer.signum() + 1) << 8) + mag.length))); } } } @@ -343,6 +375,8 @@ public double getDouble(int ordinal) { return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); } + private static byte[] EMPTY = new byte[0]; + @Override public Decimal getDecimal(int ordinal, int precision, int scale) { if (isNullAt(ordinal)) { @@ -351,10 +385,20 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { return Decimal.apply(getLong(ordinal), precision, scale); } else { - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); + long offsetAndSize = getLong(ordinal); + long offset = offsetAndSize >>> 32; + int signum = ((int) (offsetAndSize & 0xfff) >> 8); + assert signum >=0 && signum <= 2 : "invalid signum " + signum; + int size = (int) (offsetAndSize & 0xff); + int[] mag = new int[size]; + PlatformDependent.copyMemory(baseObject, baseOffset + offset, + mag, PlatformDependent.INT_ARRAY_OFFSET, size * 4); + + // create a BigInteger using signum and mag + BigInteger v = new BigInteger(0, EMPTY); // create the initial object + PlatformDependent.UNSAFE.putInt(v, PlatformDependent.BIG_INTEGER_SIGNUM_OFFSET, signum - 1); + PlatformDependent.UNSAFE.putObjectVolatile(v, PlatformDependent.BIG_INTEGER_MAG_OFFSET, mag); + return Decimal.apply(new BigDecimal(v, scale), precision, scale); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 31928731545da..28e7ec0a0f120 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions; +import java.math.BigInteger; + import 
org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.MapData; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; @@ -47,29 +48,41 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input /** Writer for Decimal with precision larger than 18. */ public static class DecimalWriter { - + private static final int SIZE = 16; public static int getSize(Decimal input) { // bounded size - return 16; + return SIZE; } public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + final Object base = target.getBaseObject(); final long offset = target.getBaseOffset() + cursor; - final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - final int numBytes = bytes.length; - assert(numBytes <= 16); - // zero-out the bytes - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L); - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L); + PlatformDependent.UNSAFE.putLong(base, offset, 0L); + PlatformDependent.UNSAFE.putLong(base, offset + 8, 0L); + + if (input == null) { + target.setNullAt(ordinal); + // keep the offset and length for update + int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8; + PlatformDependent.UNSAFE.putLong(base, target.getBaseOffset() + fieldOffset, + ((long) cursor) << 32); + return SIZE; + } - // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, - target.getBaseObject(), offset, numBytes); + final BigInteger integer = input.toJavaBigDecimal().unscaledValue(); + int signum = integer.signum() + 1; + final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, + PlatformDependent.BIG_INTEGER_MAG_OFFSET); + assert(mag.length <= 4); + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, + base, target.getBaseOffset() + cursor, mag.length * 4); // Set the fixed length portion. 
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - return 16; + target.setLong(ordinal, (((long) cursor) << 32) | ((long) ((signum << 8) + mag.length))); + + return SIZE; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e4a8fc24dac2f..ac58423cd884d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.types.DecimalType // MutableProjection is not accessible in Java abstract class BaseMutableProjection extends MutableProjection @@ -43,14 +44,26 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) - evaluationCode.code + + if (e.dataType.isInstanceOf[DecimalType]) { + // Can't call setNullAt on DecimalType, because we need to keep the offset s""" + ${evaluationCode.code} + if (${evaluationCode.isNull}) { + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + } + """ + } else { + s""" + ${evaluationCode.code} if (${evaluationCode.isNull}) { mutableRow.setNullAt($i); } else { ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; } """ + } } // collect projections into blocks as function has 64kb codesize limit in JVM val projectionBlocks = new ArrayBuffer[String]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 71f8ea09f0770..d8912df694a10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -45,10 +45,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { + case NullType => true case t: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) - case NullType => true case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true case _ => false @@ -56,7 +56,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))" + s" + $DecimalWriter.getSize(${ev.primitive})" case StringType => s" + (${ev.isNull} ? 
0 : $StringWriter.getSize(${ev.primitive}))" case BinaryType => @@ -76,41 +76,41 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodeGenContext, fieldType: DataType, ev: GeneratedExpressionCode, - primitive: String, + target: String, index: Int, cursor: String): String = fieldType match { case _ if ctx.isPrimitiveType(fieldType) => - s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}" + s"${ctx.setColumn(target, fieldType, index, ev.primitive)}" case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" // make sure Decimal object has the same scale as DecimalType if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { - $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + $CompactDecimalWriter.write($target, $index, $cursor, ${ev.primitive}); } else { - $primitive.setNullAt($index); + $target.setNullAt($index); } """ case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => s""" // make sure Decimal object has the same scale as DecimalType if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { - $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + $cursor += $DecimalWriter.write($target, $index, $cursor, ${ev.primitive}); } else { - $primitive.setNullAt($index); + $cursor += $DecimalWriter.write($target, $index, $cursor, null); } """ case StringType => - s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $StringWriter.write($target, $index, $cursor, ${ev.primitive})" case BinaryType => - s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $BinaryWriter.write($target, $index, $cursor, ${ev.primitive})" case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $IntervalWriter.write($target, $index, $cursor, ${ev.primitive})" case _: StructType => - s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $StructWriter.write($target, $index, $cursor, ${ev.primitive})" case _: ArrayType => - s"$cursor += $ArrayWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $ArrayWriter.write($target, $index, $cursor, ${ev.primitive})" case _: MapType => - s"$cursor += $MapWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $MapWriter.write($target, $index, $cursor, ${ev.primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") @@ -146,13 +146,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) => val update = genFieldWriter(ctx, dt, ev, output, i, cursor) - s""" - if (${ev.isNull}) { - $output.setNullAt($i); - } else { - $update; - } - """ + if (dt.isInstanceOf[DecimalType]) { + // Can't call setNullAt() for DecimalType + s""" + if (${ev.isNull}) { + $cursor += $DecimalWriter.write($output, $i, $cursor, null); + } else { + $update; + } + """ + } else { + s""" + if (${ev.isNull}) { + $output.setNullAt($i); + } else { + $update; + } + """ + } }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 5e5de1d1dc6a7..7657fb535dcf4 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. @@ -39,6 +38,13 @@ abstract class MutableRow extends InternalRow { def setLong(i: Int, value: Long): Unit = { update(i, value) } def setFloat(i: Int, value: Float): Unit = { update(i, value) } def setDouble(i: Int, value: Double): Unit = { update(i, value) } + + /** + * Update the decimal column at `i`. + * + * Note: In order to support update decimal with precision > 18 in UnsafeRow, + * CAN NOT call setNullAt() for decimal column on UnsafeRow, call setDecimal(i, null, precision). + */ def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 59491c5ba160e..8c72203193630 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -123,7 +123,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DoubleType, StringType, BinaryType, - DecimalType.USER_DEFAULT + DecimalType.USER_DEFAULT, + DecimalType.SYSTEM_DEFAULT // ArrayType(IntegerType) ) val converter = UnsafeProjection.create(fieldTypes) @@ -151,6 +152,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) assert(createdFromNull.getDecimal(10, 10, 0) === null) + assert(createdFromNull.getDecimal(11, 38, 18) === null) // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those @@ -169,6 +171,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) r.setDecimal(10, Decimal(10), 10) + r.setDecimal(11, Decimal(10.00, 38, 18), 38) // r.update(11, Array(11)) r } @@ -187,10 +190,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9)) assert(setToNullAfterCreation.getDecimal(10, 10, 0) === rowWithNoNullColumns.getDecimal(10, 10, 0)) + assert(setToNullAfterCreation.getDecimal(11, 38, 18) === + rowWithNoNullColumns.getDecimal(11, 38, 18)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- fieldTypes.indices) { - setToNullAfterCreation.setNullAt(i) + // Cann't call setNullAt() on DecimalType + if (i == 11) { + setToNullAfterCreation.setDecimal(11, null, 38) + } else { + setToNullAfterCreation.setNullAt(i) + } } // There are some garbage left in the var-length area assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes())) @@ -206,6 +216,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) // 
setToNullAfterCreation.update(9, "world".getBytes) setToNullAfterCreation.setDecimal(10, Decimal(10), 10) + setToNullAfterCreation.setDecimal(11, Decimal(10.00, 38, 18), 38) // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) @@ -220,6 +231,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) assert(setToNullAfterCreation.getDecimal(10, 10, 0) === rowWithNoNullColumns.getDecimal(10, 10, 0)) + assert(setToNullAfterCreation.getDecimal(11, 38, 18) === + rowWithNoNullColumns.getDecimal(11, 38, 18)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 43d06ce9bdfa3..02458030b00e9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -72,7 +72,7 @@ public final class UnsafeFixedWidthAggregationMap { */ public static boolean supportsAggregationBufferSchema(StructType schema) { for (StructField field: schema.fields()) { - if (!UnsafeRow.isFixedLength(field.dataType())) { + if (!UnsafeRow.isMutable(field.dataType())) { return false; } } @@ -111,8 +111,6 @@ public UnsafeFixedWidthAggregationMap( // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); - assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 + - UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length())); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 78bcee16c9d00..40f6bff53d2b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} -import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap -import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator /** @@ -57,7 +55,7 @@ class SortBasedAggregationIterator( val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength) + val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) val buffer = if (useUnsafeBuffer) { val unsafeProjection = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index b513c970ccfe2..e03473041c3e9 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -93,7 +93,7 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { testWithMemoryLeakDetection("supported schemas") { assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) - assert(!supportsAggregationBufferSchema( + assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) assert( diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java index 192c6714b2406..b2de2a2590f05 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe; import java.lang.reflect.Field; +import java.math.BigInteger; import sun.misc.Unsafe; @@ -87,6 +88,14 @@ public static void putDouble(Object object, long offset, double value) { _UNSAFE.putDouble(object, offset, value); } + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } + + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } + public static long allocateMemory(long size) { return _UNSAFE.allocateMemory(size); } @@ -107,6 +116,10 @@ public static void freeMemory(long address) { public static final int DOUBLE_ARRAY_OFFSET; + // Support for resetting final fields while deserializing + public static final long BIG_INTEGER_SIGNUM_OFFSET; + public static final long BIG_INTEGER_MAG_OFFSET; + /** * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to * allow safepoint polling during a large copy. @@ -129,11 +142,24 @@ public static void freeMemory(long address) { INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + + long signumOffset = 0; + long magOffset = 0; + try { + signumOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("signum")); + magOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("mag")); + } catch (Exception ex) { + // should not happen + } + BIG_INTEGER_SIGNUM_OFFSET = signumOffset; + BIG_INTEGER_MAG_OFFSET = magOffset; } else { BYTE_ARRAY_OFFSET = 0; INT_ARRAY_OFFSET = 0; LONG_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; + BIG_INTEGER_SIGNUM_OFFSET = 0; + BIG_INTEGER_MAG_OFFSET = 0; } } From 93085c992e40dbc06714cb1a64c838e25e683a6f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 09:12:41 -0700 Subject: [PATCH 181/340] [SPARK-9482] [SQL] Fix thread-safey issue of using UnsafeProjection in join This PR also change to use `def` instead of `lazy val` for UnsafeProjection, because it's not thread safe. TODO: cleanup the debug code once the flaky test passed 100 times. 
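To make the thread-safety motivation concrete, here is a minimal sketch (our own illustration, not code from this patch; the class name is made up) of the pattern being changed. An UnsafeProjection writes into a single reused row, so one instance must not be shared across threads; exposing the projection as a def means each caller, for example each task touching a broadcast relation, builds and caches its own copy in a local val.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.types.StructType

    class SketchOperator(schema: StructType) {
      // Before: one lazily created instance, shared by every thread that
      // reaches this operator once it is broadcast; unsafe, because the
      // projection reuses an internal UnsafeRow across calls.
      // @transient private lazy val sharedProjection = UnsafeProjection.create(schema)

      // After: a fresh projection per call; each task does
      //   val resultProj = resultProjection
      // once per partition and then uses it single-threaded.
      def resultProjection: InternalRow => InternalRow = UnsafeProjection.create(schema)
    }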
Author: Davies Liu Closes #7940 from davies/semijoin and squashes the following commits: 93baac7 [Davies Liu] fix outerjoin 5c40ded [Davies Liu] address comments aa3de46 [Davies Liu] Merge branch 'master' of github.com:apache/spark into semijoin 7590a25 [Davies Liu] Merge branch 'master' of github.com:apache/spark into semijoin 2d4085b [Davies Liu] use def for resultProjection 0833407 [Davies Liu] Merge branch 'semijoin' of github.com:davies/spark into semijoin e0d8c71 [Davies Liu] use lazy val 6a59e8f [Davies Liu] Update HashedRelation.scala 0fdacaf [Davies Liu] fix broadcast and thread-safety of UnsafeProjection 2fc3ef6 [Davies Liu] reproduce failure in semijoin --- .../execution/joins/BroadcastHashJoin.scala | 6 ++--- .../joins/BroadcastHashOuterJoin.scala | 20 ++++++---------- .../joins/BroadcastNestedLoopJoin.scala | 17 ++++++++------ .../spark/sql/execution/joins/HashJoin.scala | 4 ++-- .../sql/execution/joins/HashOuterJoin.scala | 23 ++++++++++--------- .../sql/execution/joins/HashSemiJoin.scala | 8 +++---- .../sql/execution/joins/HashedRelation.scala | 4 ++-- .../joins/ShuffledHashOuterJoin.scala | 6 +++-- 8 files changed, 44 insertions(+), 44 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index ec1a148342fc6..f7a68e4f5d445 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} import org.apache.spark.util.ThreadUtils +import org.apache.spark.{InternalAccumulator, TaskContext} /** * :: DeveloperApi :: @@ -102,6 +102,6 @@ case class BroadcastHashJoin( object BroadcastHashJoin { - private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( + private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index e342fd914d321..a3626de49aeab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, 
Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.{InternalAccumulator, TaskContext} /** * :: DeveloperApi :: @@ -76,7 +75,7 @@ case class BroadcastHashOuterJoin( val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) } - }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) + }(BroadcastHashJoin.broadcastHashJoinExecutionContext) } protected override def doPrepare(): Unit = { @@ -98,19 +97,20 @@ case class BroadcastHashOuterJoin( case _ => } + val resultProj = resultProjection joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey)) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj) }) case RightOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj) }) case x => @@ -120,9 +120,3 @@ case class BroadcastHashOuterJoin( } } } - -object BroadcastHashOuterJoin { - - private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128)) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 83b726a8e2897..23aebf4b068b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -47,7 +47,7 @@ case class BroadcastNestedLoopJoin( override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true - @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { + private[this] def genResultProjection: InternalRow => InternalRow = { if (outputsUnsafeRows) { UnsafeProjection.create(schema) } else { @@ -88,6 +88,7 @@ case class BroadcastNestedLoopJoin( val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) + val resultProj = genResultProjection streamedIter.foreach { streamedRow => var i = 0 @@ -97,11 +98,11 @@ case class BroadcastNestedLoopJoin( val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy() + matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy() + matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true 
includedBroadcastTuples += i case _ => @@ -111,9 +112,9 @@ case class BroadcastNestedLoopJoin( (streamRowMatched, joinType, buildSide) match { case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy() + matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy() + matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() case _ => } } @@ -127,6 +128,8 @@ case class BroadcastNestedLoopJoin( val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) + val resultProj = genResultProjection + /** Rows from broadcasted joined with nulls. */ val broadcastRowsWithNulls: Seq[InternalRow] = { val buf: CompactBuffer[InternalRow] = new CompactBuffer() @@ -138,7 +141,7 @@ case class BroadcastNestedLoopJoin( joinedRow.withLeft(leftNulls) while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { - buf += resultProjection(joinedRow.withRight(rel(i))).copy() + buf += resultProj(joinedRow.withRight(rel(i))).copy() } i += 1 } @@ -147,7 +150,7 @@ case class BroadcastNestedLoopJoin( joinedRow.withRight(rightNulls) while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { - buf += resultProjection(joinedRow.withLeft(rel(i))).copy() + buf += resultProj(joinedRow.withLeft(rel(i))).copy() } i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 6b3d1652923fd..5e9cd9fd2345a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -52,14 +52,14 @@ trait HashJoin { override def canProcessUnsafeRows: Boolean = isUnsafeMode override def canProcessSafeRows: Boolean = !isUnsafeMode - @transient protected lazy val buildSideKeyGenerator: Projection = + protected def buildSideKeyGenerator: Projection = if (isUnsafeMode) { UnsafeProjection.create(buildKeys, buildPlan.output) } else { newMutableProjection(buildKeys, buildPlan.output)() } - @transient protected lazy val streamSideKeyGenerator: Projection = + protected def streamSideKeyGenerator: Projection = if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index a323aea4ea2c4..346337e64245c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -76,14 +76,14 @@ trait HashOuterJoin { override def canProcessUnsafeRows: Boolean = isUnsafeMode override def canProcessSafeRows: Boolean = !isUnsafeMode - @transient protected lazy val buildKeyGenerator: Projection = + protected def buildKeyGenerator: Projection = if (isUnsafeMode) { UnsafeProjection.create(buildKeys, buildPlan.output) } else { newMutableProjection(buildKeys, buildPlan.output)() } - @transient protected[this] lazy val streamedKeyGenerator: Projection = { + protected[this] def streamedKeyGenerator: Projection = { if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { @@ -91,7 +91,7 @@ trait HashOuterJoin { } } - 
@transient private[this] lazy val resultProjection: InternalRow => InternalRow = { + protected[this] def resultProjection: InternalRow => InternalRow = { if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { @@ -113,7 +113,8 @@ trait HashOuterJoin { protected[this] def leftOuterIterator( key: InternalRow, joinedRow: JoinedRow, - rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { + rightIter: Iterable[InternalRow], + resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (rightIter != null) { @@ -124,12 +125,12 @@ trait HashOuterJoin { List.empty } if (temp.isEmpty) { - resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } else { temp } } else { - resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } } ret.iterator @@ -138,24 +139,24 @@ trait HashOuterJoin { protected[this] def rightOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], - joinedRow: JoinedRow): Iterator[InternalRow] = { + joinedRow: JoinedRow, + resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (leftIter != null) { leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => - resultProjection(joinedRow).copy() + case l if boundCondition(joinedRow.withLeft(l)) => resultProjection(joinedRow).copy() } } else { List.empty } if (temp.isEmpty) { - resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } else { temp } } else { - resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } } ret.iterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 97fde8f975bfd..47a7d370f5415 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -43,14 +43,14 @@ trait HashSemiJoin { override def canProcessUnsafeRows: Boolean = supportUnsafe override def canProcessSafeRows: Boolean = !supportUnsafe - @transient protected lazy val leftKeyGenerator: Projection = + protected def leftKeyGenerator: Projection = if (supportUnsafe) { UnsafeProjection.create(leftKeys, left.output) } else { newMutableProjection(leftKeys, left.output)() } - @transient protected lazy val rightKeyGenerator: Projection = + protected def rightKeyGenerator: Projection = if (supportUnsafe) { UnsafeProjection.create(rightKeys, right.output) } else { @@ -62,12 +62,11 @@ trait HashSemiJoin { protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() - var currentRow: InternalRow = null // Create a Hash set of buildKeys val rightKey = rightKeyGenerator while (buildIter.hasNext) { - currentRow = buildIter.next() + val currentRow = buildIter.next() val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) @@ -76,6 +75,7 @@ trait HashSemiJoin { } } } + hashSet } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 58b4236f7b5b5..3f257ecdd156c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.joins -import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput} +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} import org.apache.spark.shuffle.ShuffleMemoryManager -import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -31,6 +30,7 @@ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.{SparkConf, SparkEnv} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index eee8ad800f98e..6a8c35efca8f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -60,19 +60,21 @@ case class ShuffledHashOuterJoin( case LeftOuter => val hashed = HashedRelation(rightIter, buildKeyGenerator) val keyGenerator = streamedKeyGenerator + val resultProj = resultProjection leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey)) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj) }) case RightOuter => val hashed = HashedRelation(leftIter, buildKeyGenerator) val keyGenerator = streamedKeyGenerator + val resultProj = resultProjection rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj) }) case FullOuter => From 9f94c85ff35df6289371f80edde51c2aa6c4bcdc Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 6 Aug 2015 09:53:53 -0700 Subject: [PATCH 182/340] [SPARK-9593] [SQL] [HOTFIX] Makes the Hadoop shims loading fix more robust This is a follow-up of #7929. We found that Jenkins SBT master build still fails because of the Hadoop shims loading issue. But the failure doesn't appear to be deterministic. My suspect is that Hadoop `VersionInfo` class may fail to inspect Hadoop version, and the shims loading branch is skipped. This PR tries to make the fix more robust: 1. When Hadoop version is available, we load `Hadoop20SShims` for versions <= 2.0.x as srowen suggested in PR #7929. 2. Otherwise, we use `Path.getPathWithoutSchemeAndAuthority` as a probe method, which doesn't exist in Hadoop 1.x or 2.0.x. If this method is not found, `Hadoop20SShims` is also loaded. 
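Condensed into one place (a sketch of our own, not code from the patch; the helper name is illustrative), the decision described above is roughly:

    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.util.VersionInfo

    // Returns true when Hadoop20SShims should be forced, per the two rules above.
    def shouldLoadHadoop20SShims(): Boolean = {
      val VersionPattern = """(\d+)\.(\d+).*""".r
      VersionInfo.getVersion match {
        case null =>
          // Version unknown: probe for a method that Hadoop 1.x/2.0.x lack.
          !classOf[Path].getDeclaredMethods
            .exists(_.getName == "getPathWithoutSchemeAndAuthority")
        case VersionPattern(major, minor) =>
          // 1.x and 2.0.x still need Hadoop20SShims even though ShimLoader
          // would pick Hadoop23Shims for major version 2.
          major.toInt < 2 || (major.toInt == 2 && minor.toInt == 0)
        case _ =>
          false // unexpected version string: leave ShimLoader's own choice alone
      }
    }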
Author: Cheng Lian Closes #7994 from liancheng/spark-9593/fix-hadoop-shims and squashes the following commits: e1d3d70 [Cheng Lian] Fixes typo in comments 8d971da [Cheng Lian] Makes the Hadoop shims loading fix more robust --- .../spark/sql/hive/client/ClientWrapper.scala | 88 ++++++++++++------- 1 file changed, 55 insertions(+), 33 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 211a3b879c1b3..3d05b583cf9e0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -68,45 +68,67 @@ private[hive] class ClientWrapper( // !! HACK ALERT !! // - // This method is a surgical fix for Hadoop version 2.0.0-mr1-cdh4.1.1, which is used by Spark EC2 - // scripts. We should remove this after upgrading Spark EC2 scripts to some more recent Hadoop - // version in the future. - // // Internally, Hive `ShimLoader` tries to load different versions of Hadoop shims by checking - // version information gathered from Hadoop jar files. If the major version number is 1, - // `Hadoop20SShims` will be loaded. Otherwise, if the major version number is 2, `Hadoop23Shims` - // will be chosen. + // major version number gathered from Hadoop jar files: + // + // - For major version number 1, load `Hadoop20SShims`, where "20S" stands for Hadoop 0.20 with + // security. + // - For major version number 2, load `Hadoop23Shims`, where "23" stands for Hadoop 0.23. // - // However, part of APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical - // reasons. So 2.0.0-mr1-cdh4.1.1 is actually more Hadoop-1-like and should be used together with - // `Hadoop20SShims`, but `Hadoop20SShims` is chose because the major version number here is 2. + // However, APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical reasons. It + // turns out that Hadoop 2.0.x versions should also be used together with `Hadoop20SShims`, but + // `Hadoop23Shims` is chosen because the major version number here is 2. // - // Here we check for this specific version and loads `Hadoop20SShims` via reflection. Note that - // we can't check for string literal "2.0.0-mr1-cdh4.1.1" because the obtained version string - // comes from Maven artifact org.apache.hadoop:hadoop-common:2.0.0-cdh4.1.1, which doesn't have - // the "mr1" tag in its version string. + // To fix this issue, we try to inspect Hadoop version via `org.apache.hadoop.utils.VersionInfo` + // and load `Hadoop20SShims` for Hadoop 1.x and 2.0.x versions. If Hadoop version information is + // not available, we decide whether to override the shims or not by checking for existence of a + // probe method which doesn't exist in Hadoop 1.x or 2.0.x versions. 
private def overrideHadoopShims(): Unit = { - val VersionPattern = """2\.0\.0.*cdh4.*""".r - - VersionInfo.getVersion match { - case VersionPattern() => - val shimClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" - logInfo(s"Loading Hadoop shims $shimClassName") - - try { - val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") - // scalastyle:off classforname - val shimsClass = Class.forName(shimClassName) - // scalastyle:on classforname - val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) - shimsField.setAccessible(true) - shimsField.set(null, shims) - } catch { case cause: Throwable => - logError(s"Failed to load $shimClassName") - // Falls back to normal Hive `ShimLoader` logic + val hadoopVersion = VersionInfo.getVersion + val VersionPattern = """(\d+)\.(\d+).*""".r + + hadoopVersion match { + case null => + logError("Failed to inspect Hadoop version") + + // Using "Path.getPathWithoutSchemeAndAuthority" as the probe method. + val probeMethod = "getPathWithoutSchemeAndAuthority" + if (!classOf[Path].getDeclaredMethods.exists(_.getName == probeMethod)) { + logInfo( + s"Method ${classOf[Path].getCanonicalName}.$probeMethod not found, " + + s"we are probably using Hadoop 1.x or 2.0.x") + loadHadoop20SShims() + } + + case VersionPattern(majorVersion, minorVersion) => + logInfo(s"Inspected Hadoop version: $hadoopVersion") + + // Loads Hadoop20SShims for 1.x and 2.0.x versions + val (major, minor) = (majorVersion.toInt, minorVersion.toInt) + if (major < 2 || (major == 2 && minor == 0)) { + loadHadoop20SShims() } + } + + // Logs the actual loaded Hadoop shims class + val loadedShimsClassName = ShimLoader.getHadoopShims.getClass.getCanonicalName + logInfo(s"Loaded $loadedShimsClassName for Hadoop version $hadoopVersion") + } - case _ => + private def loadHadoop20SShims(): Unit = { + val hadoop20SShimsClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" + logInfo(s"Loading Hadoop shims $hadoop20SShimsClassName") + + try { + val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") + // scalastyle:off classforname + val shimsClass = Class.forName(hadoop20SShimsClassName) + // scalastyle:on classforname + val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) + shimsField.setAccessible(true) + shimsField.set(null, shims) + } catch { case cause: Throwable => + throw new RuntimeException(s"Failed to load $hadoop20SShimsClassName", cause) } } From c5c6aded641048a3e66ac79d9e84d34e4b1abae7 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 6 Aug 2015 10:08:33 -0700 Subject: [PATCH 183/340] [SPARK-9112] [ML] Implement Stats for LogisticRegression I have added support for stats in LogisticRegression. The API is similar to that of LinearRegression with LogisticRegressionTrainingSummary and LogisticRegressionSummary I have some queries and asked them inline. 
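Because the commit message only names the new classes, a short usage sketch may help (our own example, written against the API added below; it assumes a DataFrame called training of labelled points is in scope):

    import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}

    val model = new LogisticRegression().fit(training)

    // Training-time statistics captured while fitting.
    val summary = model.summary
    println(s"iterations: ${summary.totalIterations}")
    println(s"objective history: ${summary.objectiveHistory.mkString(", ")}")

    // The binary summary adds curves computed from the predictions DataFrame.
    summary match {
      case b: BinaryLogisticRegressionSummary =>
        println(s"area under ROC: ${b.areaUnderROC}")
        b.roc.show()                 // FPR / TPR
        b.fMeasureByThreshold.show() // threshold / F-Measure
      case _ =>
    }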
Author: MechCoder Closes #7538 from MechCoder/log_reg_stats and squashes the following commits: 2e9f7c7 [MechCoder] Change defs into lazy vals d775371 [MechCoder] Clean up class inheritance 9586125 [MechCoder] Add abstraction to handle Multiclass Metrics 40ad8ef [MechCoder] minor 640376a [MechCoder] remove unnecessary dataframe stuff and add docs 80d9954 [MechCoder] Added tests fbed861 [MechCoder] DataFrame support for metrics 70a0fc4 [MechCoder] [SPARK-9112] [ML] Implement Stats for LogisticRegression --- .../classification/LogisticRegression.scala | 166 +++++++++++++++++- .../JavaLogisticRegressionSuite.java | 9 + .../LogisticRegressionSuite.scala | 37 +++- 3 files changed, 209 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 0d073839259c6..f55134d258857 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -30,10 +30,12 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.storage.StorageLevel /** @@ -284,7 +286,13 @@ class LogisticRegression(override val uid: String) if (handlePersistence) instances.unpersist() - copyValues(new LogisticRegressionModel(uid, weights, intercept)) + val model = copyValues(new LogisticRegressionModel(uid, weights, intercept)) + val logRegSummary = new BinaryLogisticRegressionTrainingSummary( + model.transform(dataset), + $(probabilityCol), + $(labelCol), + objectiveHistory) + model.setSummary(logRegSummary) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) @@ -319,6 +327,38 @@ class LogisticRegressionModel private[ml] ( override val numClasses: Int = 2 + private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + def summary: LogisticRegressionTrainingSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + "No training summary available for this LogisticRegressionModel", + new NullPointerException()) + } + + private[classification] def setSummary( + summary: LogisticRegressionTrainingSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** Indicates whether a training summary exists for this model instance. */ + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Evaluates the model on a testset. + * @param dataset Test dataset to evaluate model on. + */ + // TODO: decide on a good name before exposing to public API + private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol)) + } + /** * Predict label for the given feature vector. * The behavior of this can be adjusted using [[thresholds]]. 
@@ -440,6 +480,128 @@ private[classification] class MultiClassSummarizer extends Serializable { } } +/** + * Abstraction for multinomial Logistic Regression Training results. + */ +sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { + + /** objective function (scaled loss + regularization) at each iteration. */ + def objectiveHistory: Array[Double] + + /** Number of training iterations until termination */ + def totalIterations: Int = objectiveHistory.length + +} + +/** + * Abstraction for Logistic Regression Results for a given model. + */ +sealed trait LogisticRegressionSummary extends Serializable { + + /** Dataframe outputted by the model's `transform` method. */ + def predictions: DataFrame + + /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */ + def probabilityCol: String + + /** Field in "predictions" which gives the the true label of each sample. */ + def labelCol: String + +} + +/** + * :: Experimental :: + * Logistic regression training results. + * @param predictions dataframe outputted by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the calibrated probability of + * each sample as a vector. + * @param labelCol field in "predictions" which gives the true label of each sample. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +@Experimental +class BinaryLogisticRegressionTrainingSummary private[classification] ( + predictions: DataFrame, + probabilityCol: String, + labelCol: String, + val objectiveHistory: Array[Double]) + extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol) + with LogisticRegressionTrainingSummary { + +} + +/** + * :: Experimental :: + * Binary Logistic regression results for a given model. + * @param predictions dataframe outputted by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the calibrated probability of + * each sample. + * @param labelCol field in "predictions" which gives the true label of each sample. + */ +@Experimental +class BinaryLogisticRegressionSummary private[classification] ( + @transient override val predictions: DataFrame, + override val probabilityCol: String, + override val labelCol: String) extends LogisticRegressionSummary { + + private val sqlContext = predictions.sqlContext + import sqlContext.implicits._ + + /** + * Returns a BinaryClassificationMetrics object. + */ + // TODO: Allow the user to vary the number of bins using a setBins method in + // BinaryClassificationMetrics. For now the default is set to 100. + @transient private val binaryMetrics = new BinaryClassificationMetrics( + predictions.select(probabilityCol, labelCol).map { + case Row(score: Vector, label: Double) => (score(1), label) + }, 100 + ) + + /** + * Returns the receiver operating characteristic (ROC) curve, + * which is an Dataframe having two fields (FPR, TPR) + * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic + */ + @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") + + /** + * Computes the area under the receiver operating characteristic (ROC) curve. + */ + lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() + + /** + * Returns the precision-recall curve, which is an Dataframe containing + * two fields recall, precision with (0.0, 1.0) prepended to it. 
+ */ + @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") + + /** + * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. + */ + @transient lazy val fMeasureByThreshold: DataFrame = { + binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") + } + + /** + * Returns a dataframe with two fields (threshold, precision) curve. + * Every possible probability obtained in transforming the dataset are used + * as thresholds used in calculating the precision. + */ + @transient lazy val precisionByThreshold: DataFrame = { + binaryMetrics.precisionByThreshold().toDF("threshold", "precision") + } + + /** + * Returns a dataframe with two fields (threshold, recall) curve. + * Every possible probability obtained in transforming the dataset are used + * as thresholds used in calculating the recall. + */ + @transient lazy val recallByThreshold: DataFrame = { + binaryMetrics.recallByThreshold().toDF("threshold", "recall") + } +} + /** * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used * in binary classification for samples in sparse or dense vector in a online fashion. diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index fb1de51163f2e..7e9aa383728f0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -152,4 +152,13 @@ public void logisticRegressionPredictorClassifierMethods() { } } } + + @Test + public void logisticRegressionTrainingSummary() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset); + + LogisticRegressionTrainingSummary summary = model.summary(); + assert(summary.totalIterations() == summary.objectiveHistory().length); + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index da13dcb42d1ca..8c3d4590f5ae9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -723,6 +723,41 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptR relTol 1E-5) - assert(model1.weights ~= weightsR absTol 1E-6) + assert(model1.weights ~== weightsR absTol 1E-6) + } + + test("evaluate on test set") { + // Evaluate on test set should be same as that of the transformed training data. 
+ val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(dataset) + val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary] + + val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary] + assert(summary.areaUnderROC === sameSummary.areaUnderROC) + assert(summary.roc.collect() === sameSummary.roc.collect()) + assert(summary.pr.collect === sameSummary.pr.collect()) + assert( + summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect()) + assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect()) + assert( + summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect()) + } + + test("statistics on training data") { + // Test that loss is monotonically decreasing. + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(dataset) + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } } From 076ec056818a65216eaf51aa5b3bd8f697c34748 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 6 Aug 2015 10:09:58 -0700 Subject: [PATCH 184/340] [SPARK-9533] [PYSPARK] [ML] Add missing methods in Word2Vec ML After https://github.com/apache/spark/pull/7263 it is pretty straightforward to Python wrappers. Author: MechCoder Closes #7930 from MechCoder/spark-9533 and squashes the following commits: 1bea394 [MechCoder] make getVectors a lazy val 5522756 [MechCoder] [SPARK-9533] [PySpark] [ML] Add missing methods in Word2Vec ML --- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- python/pyspark/ml/feature.py | 40 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index b4f46cef798dd..29acc3eb5865f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -153,7 +153,7 @@ class Word2VecModel private[ml] ( * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and * and the vector the DenseVector that it is mapped to. */ - val getVectors: DataFrame = { + @transient lazy val getVectors: DataFrame = { val sc = SparkContext.getOrCreate() val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3f04c41ac5ab6..cb4dfa21298ce 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -15,11 +15,16 @@ # limitations under the License. 
# +import sys +if sys.version > '3': + basestring = str + from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer from pyspark.mllib.common import inherit_doc +from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', @@ -954,6 +959,23 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has >>> sent = ("a b " * 100 + "a c " * 10).split(" ") >>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"]) >>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc) + >>> model.getVectors().show() + +----+--------------------+ + |word| vector| + +----+--------------------+ + | a|[-0.3511952459812...| + | b|[0.29077222943305...| + | c|[0.02315592765808...| + +----+--------------------+ + ... + >>> model.findSynonyms("a", 2).show() + +----+-------------------+ + |word| similarity| + +----+-------------------+ + | b|0.29255685145799626| + | c|-0.5414068302988307| + +----+-------------------+ + ... >>> model.transform(doc).head().model DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) """ @@ -1047,6 +1069,24 @@ class Word2VecModel(JavaModel): Model fitted by Word2Vec. """ + def getVectors(self): + """ + Returns the vector representation of the words as a dataframe + with two fields, word and vector. + """ + return self._call_java("getVectors") + + def findSynonyms(self, word, num): + """ + Find "num" number of words closest in similarity to "word". + word can be a string or vector representation. + Returns a dataframe with two fields word and similarity (which + gives the cosine similarity). + """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) + return self._call_java("findSynonyms", word, num) + @inherit_doc class PCA(JavaEstimator, HasInputCol, HasOutputCol): From 98e69467d4fda2c26a951409b5b7c6f1e9345ce4 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 6 Aug 2015 10:29:40 -0700 Subject: [PATCH 185/340] [SPARK-9615] [SPARK-9616] [SQL] [MLLIB] Bugs related to FrequentItems when merging and with Tungsten In short: 1- FrequentItems should not use the InternalRow representation, because the keys in the map get messed up. For example, every key in the Map correspond to the very last element observed in the partition, when the elements are strings. 
2- Merging two partitions had a bug: **Existing behavior with size 3** Partition A -> Map(1 -> 3, 2 -> 3, 3 -> 4) Partition B -> Map(4 -> 25) Result -> Map() **Correct Behavior:** Partition A -> Map(1 -> 3, 2 -> 3, 3 -> 4) Partition B -> Map(4 -> 25) Result -> Map(3 -> 1, 4 -> 22) cc mengxr rxin JoshRosen Author: Burak Yavuz Closes #7945 from brkyvz/freq-fix and squashes the following commits: 07fa001 [Burak Yavuz] address 2 1dc61a8 [Burak Yavuz] address 1 506753e [Burak Yavuz] fixed and added reg test 47bfd50 [Burak Yavuz] pushing --- .../sql/execution/stat/FrequentItems.scala | 26 +++++++++++-------- .../apache/spark/sql/DataFrameStatSuite.scala | 24 ++++++++++++++--- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 9329148aa233c..db463029aedf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -20,17 +20,15 @@ package org.apache.spark.sql.execution.stat import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.{Row, Column, DataFrame} private[sql] object FrequentItems extends Logging { /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */ private class FreqItemCounter(size: Int) extends Serializable { val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long] - /** * Add a new example to the counts if it exists, otherwise deduct the count * from existing items. @@ -42,9 +40,15 @@ private[sql] object FrequentItems extends Logging { if (baseMap.size < size) { baseMap += key -> count } else { - // TODO: Make this more efficient... A flatMap? 
- baseMap.retain((k, v) => v > count) - baseMap.transform((k, v) => v - count) + val minCount = baseMap.values.min + val remainder = count - minCount + if (remainder >= 0) { + baseMap += key -> count // something will get kicked out, so we can add this + baseMap.retain((k, v) => v > minCount) + baseMap.transform((k, v) => v - minCount) + } else { + baseMap.transform((k, v) => v - count) + } } } this @@ -90,12 +94,12 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) }.toArray - val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { val thisMap = counts(i) - val key = row.get(i, colInfo(i)._2) + val key = row.get(i) thisMap.add(key, 1L) i += 1 } @@ -110,13 +114,13 @@ private[sql] object FrequentItems extends Logging { baseCounts } ) - val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_)) - val resultRow = InternalRow(justItems : _*) + val justItems = freqItems.map(m => m.baseMap.keys.toArray) + val resultRow = Row(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => StructField(v._1 + "_freqItems", ArrayType(v._2, false)) } val schema = StructType(outputCols).toAttributes - new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow))) + new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 07a675e64f527..0e7659f443ecd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -123,12 +123,30 @@ class DataFrameStatSuite extends QueryTest { val results = df.stat.freqItems(Array("numbers", "letters"), 0.1) val items = results.collect().head - items.getSeq[Int](0) should contain (1) - items.getSeq[String](1) should contain (toLetter(1)) + assert(items.getSeq[Int](0).contains(1)) + assert(items.getSeq[String](1).contains(toLetter(1))) val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1) val items2 = singleColResults.collect().head - items2.getSeq[Double](0) should contain (-1.0) + assert(items2.getSeq[Double](0).contains(-1.0)) + } + + test("Frequent Items 2") { + val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4) + // this is a regression test, where when merging partitions, we omitted values with higher + // counts than those that existed in the map when the map was full. This test should also fail + // if anything like SPARK-9614 is observed once again + val df = rows.mapPartitionsWithIndex { (idx, iter) => + if (idx == 3) { // must come from one of the later merges, therefore higher partition index + Iterator("3", "3", "3", "3", "3") + } else { + Iterator("0", "1", "2", "3", "4") + } + }.toDF("a") + val results = df.stat.freqItems(Array("a"), 0.25) + val items = results.collect().head.getSeq[String](0) + assert(items.contains("3")) + assert(items.length === 1) } test("sampleBy") { From 5e1b0ef07942a041195b3decd05d86c289bc8d2b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 6 Aug 2015 10:39:16 -0700 Subject: [PATCH 186/340] [SPARK-9659][SQL] Rename inSet to isin to match Pandas function. 
Inspiration drawn from this blog post: https://lab.getbase.com/pandarize-spark-dataframes/ Author: Reynold Xin Closes #7977 from rxin/isin and squashes the following commits: 9b1d3d6 [Reynold Xin] Added return. 2197d37 [Reynold Xin] Fixed test case. 7c1b6cf [Reynold Xin] Import warnings. 4f4a35d [Reynold Xin] [SPARK-9659][SQL] Rename inSet to isin to match Pandas function. --- python/pyspark/sql/column.py | 20 ++++++++++++++++++- .../scala/org/apache/spark/sql/Column.scala | 13 +++++++++++- .../spark/sql/ColumnExpressionSuite.scala | 14 ++++++------- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 0a85da7443d3d..8af8637cf948d 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -16,6 +16,7 @@ # import sys +import warnings if sys.version >= '3': basestring = str @@ -254,12 +255,29 @@ def inSet(self, *cols): [Row(age=5, name=u'Bob')] >>> df[df.age.inSet([1, 2, 3])].collect() [Row(age=2, name=u'Alice')] + + .. note:: Deprecated in 1.5, use :func:`Column.isin` instead. + """ + warnings.warn("inSet is deprecated. Use isin() instead.") + return self.isin(*cols) + + @ignore_unicode_prefix + @since(1.5) + def isin(self, *cols): + """ + A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.isin("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.isin([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] """ if len(cols) == 1 and isinstance(cols[0], (list, set)): cols = cols[0] cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] sc = SparkContext._active_spark_context - jc = getattr(self._jc, "in")(_to_seq(sc, cols)) + jc = getattr(self._jc, "isin")(_to_seq(sc, cols)) return Column(jc) # order diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b25dcbca82b9f..75365fbcec757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -627,8 +627,19 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ + @deprecated("use isin", "1.5.0") @scala.annotation.varargs - def in(list: Any*): Column = In(expr, list.map(lit(_).expr)) + def in(list: Any*): Column = isin(list : _*) + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the evaluated values of the arguments. + * + * @group expr_ops + * @since 1.5.0 + */ + @scala.annotation.varargs + def isin(list: Any*): Column = In(expr, list.map(lit(_).expr)) /** * SQL like expression. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b351380373259..e1b3443d74993 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -345,23 +345,23 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { test("in") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - checkAnswer(df.filter($"a".in(1, 2)), + checkAnswer(df.filter($"a".isin(1, 2)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".in(3, 2)), + checkAnswer(df.filter($"a".isin(3, 2)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".in(3, 1)), + checkAnswer(df.filter($"a".isin(3, 1)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"b".in("y", "x")), + checkAnswer(df.filter($"b".isin("y", "x")), df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".in("z", "x")), + checkAnswer(df.filter($"b".isin("z", "x")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".in("z", "y")), + checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") intercept[AnalysisException] { - df2.filter($"a".in($"b")) + df2.filter($"a".isin($"b")) } } From 6e009cb9c4d7a395991e10dab427f37019283758 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 6 Aug 2015 10:40:54 -0700 Subject: [PATCH 187/340] [SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info Author: Wenchen Fan Closes #7955 from cloud-fan/toSeq and squashes the following commits: 21665e2 [Wenchen Fan] fix hive again... 
4addf29 [Wenchen Fan] fix hive bc16c59 [Wenchen Fan] minor fix 33d802c [Wenchen Fan] pass data type info to InternalRow.toSeq 3dd033e [Wenchen Fan] move the default special getters implementation from InternalRow to BaseGenericInternalRow --- .../spark/sql/catalyst/InternalRow.scala | 132 ++---------------- .../sql/catalyst/expressions/Projection.scala | 12 +- .../expressions/SpecificMutableRow.scala | 5 +- .../codegen/GenerateProjection.scala | 8 +- .../spark/sql/catalyst/expressions/rows.scala | 132 +++++++++++++++++- .../expressions/CodeGenerationSuite.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 51 +++---- .../columnar/InMemoryColumnarTableScan.scala | 11 +- .../spark/sql/execution/debug/package.scala | 4 +- .../apache/spark/sql/sources/interfaces.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 54 +++---- .../spark/sql/hive/HiveInspectors.scala | 6 +- .../hive/execution/ScriptTransformation.scala | 21 ++- .../spark/sql/hive/hiveWriterContainers.scala | 24 ++-- .../spark/sql/hive/HiveInspectorSuite.scala | 10 +- 15 files changed, 259 insertions(+), 217 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 7d17cca808791..85b4bf3b6aef5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types.{DataType, StructType} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -32,8 +31,6 @@ abstract class InternalRow extends SpecializedGetters with Serializable { // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString - override def toString: String = mkString("[", ",", "]") - /** * Make a copy of the current [[InternalRow]] object. */ @@ -50,136 +47,25 @@ abstract class InternalRow extends SpecializedGetters with Serializable { false } - // Subclasses of InternalRow should implement all special getters and equals/hashCode, - // or implement this genericGet. 
- protected def genericGet(ordinal: Int): Any = throw new IllegalStateException( - "Concrete internal rows should implement genericGet, " + - "or implement all special getters and equals/hashCode") - - // default implementation (slow) - private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null - override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - override def getByte(ordinal: Int): Byte = getAs(ordinal) - override def getShort(ordinal: Int): Short = getAs(ordinal) - override def getInt(ordinal: Int): Int = getAs(ordinal) - override def getLong(ordinal: Int): Long = getAs(ordinal) - override def getFloat(ordinal: Int): Float = getAs(ordinal) - override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) - override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) - override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) - override def getArray(ordinal: Int): ArrayData = getAs(ordinal) - override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) - override def getMap(ordinal: Int): MapData = getAs(ordinal) - override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[InternalRow]) { - return false - } - - val other = o.asInstanceOf[InternalRow] - if (other eq null) { - return false - } - - val len = numFields - if (len != other.numFields) { - return false - } - - var i = 0 - while (i < len) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = genericGet(i) - val o2 = other.genericGet(i) - o1 match { - case b1: Array[Byte] => - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - case f1: Float if java.lang.Float.isNaN(f1) => - if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { - return false - } - case d1: Double if java.lang.Double.isNaN(d1) => - if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { - return false - } - case _ => if (o1 != o2) { - return false - } - } - } - i += 1 - } - true - } - - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 - var i = 0 - val len = numFields - while (i < len) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - genericGet(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => java.util.Arrays.hashCode(a) - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } - /* ---------------------- utility methods for Scala ---------------------- */ /** * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. 
*/ - // todo: remove this as it needs the generic getter - def toSeq: Seq[Any] = { - val n = numFields - val values = new Array[Any](n) + def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + val len = numFields + assert(len == fieldTypes.length) + + val values = new Array[Any](len) var i = 0 - while (i < n) { - values.update(i, genericGet(i)) + while (i < len) { + values(i) = get(i, fieldTypes(i)) i += 1 } values } - /** Displays all elements of this sequence in a string (without a separator). */ - def mkString: String = toSeq.mkString - - /** Displays all elements of this sequence in a string using a separator string. */ - def mkString(sep: String): String = toSeq.mkString(sep) - - /** - * Displays all elements of this traversable or iterator in a string using - * start, end, and separator strings. - */ - def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType)) } object InternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 4296b4b123fc0..59ce7fc4f2c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -203,7 +203,11 @@ class JoinedRow extends InternalRow { this } - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + assert(fieldTypes.length == row1.numFields + row2.numFields) + val (left, right) = fieldTypes.splitAt(row1.numFields) + row1.toSeq(left) ++ row2.toSeq(right) + } override def numFields: Int = row1.numFields + row2.numFields @@ -276,11 +280,11 @@ class JoinedRow extends InternalRow { if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" } else if (row1 eq null) { - row2.mkString("[", ",", "]") + row2.toString } else if (row2 eq null) { - row1.mkString("[", ",", "]") + row1.toString } else { - mkString("[", ",", "]") + s"{${row1.toString} + ${row2.toString}}" } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index b94df6bd66e04..4f56f94bd4ca4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -192,7 +192,8 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. 
*/ -final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { +final class SpecificMutableRow(val values: Array[MutableValue]) + extends MutableRow with BaseGenericInternalRow { def this(dataTypes: Seq[DataType]) = this( @@ -213,8 +214,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def numFields: Int = values.length - override def toSeq: Seq[Any] = values.map(_.boxed) - override def setNullAt(i: Int): Unit = { values(i).isNull = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c04fe734d554e..c744e84d822e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -25,6 +26,8 @@ import org.apache.spark.sql.types._ */ abstract class BaseProjection extends Projection {} +abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow + /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] @@ -171,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificRow((InternalRow) r); } - final class SpecificRow extends ${classOf[MutableRow].getName} { + final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { $columns @@ -184,7 +187,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - protected Object genericGet(int i) { + @Override + public Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 7657fb535dcf4..207e667792660 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -21,6 +21,130 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +/** + * An extended version of [[InternalRow]] that implements all special getters, toString + * and equals/hashCode by `genericGet`. 
+ */ +trait BaseGenericInternalRow extends InternalRow { + + protected def genericGet(ordinal: Int): Any + + // default implementation (slow) + private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def toString(): String = { + if (numFields == 0) { + "[empty row]" + } else { + val sb = new StringBuilder + sb.append("[") + sb.append(genericGet(0)) + val len = numFields + var i = 1 + while (i < len) { + sb.append(",") + sb.append(genericGet(i)) + i += 1 + } + sb.append("]") + sb.toString() + } + } + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[BaseGenericInternalRow]) { + return false + } + + val other = o.asInstanceOf[BaseGenericInternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = genericGet(i) + val o2 = other.genericGet(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numFields + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + genericGet(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} + /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. 
* Setting a value through a primitive function implicitly marks that column as not null. @@ -82,7 +206,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow { +class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -90,7 +214,7 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRo override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length @@ -109,7 +233,7 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(values: Array[Any]) extends MutableRow { +class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -117,7 +241,7 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow { override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e310aee221666..e323467af5f4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -87,7 +87,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = GenerateMutableProjection.generate(expressions)() - val actual = plan(new GenericMutableRow(length)).toSeq + val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index af1a8ecca9b57..5cbd52bc0590e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -66,7 +66,7 @@ private[sql] sealed trait ColumnStats extends Serializable { * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. 
*/ - def collectedStatistics: InternalRow + def collectedStatistics: GenericInternalRow } /** @@ -75,7 +75,8 @@ private[sql] sealed trait ColumnStats extends Serializable { private[sql] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) - override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } private[sql] class BooleanColumnStats extends ColumnStats { @@ -92,8 +93,8 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class ByteColumnStats extends ColumnStats { @@ -110,8 +111,8 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class ShortColumnStats extends ColumnStats { @@ -128,8 +129,8 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class IntColumnStats extends ColumnStats { @@ -146,8 +147,8 @@ private[sql] class IntColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class LongColumnStats extends ColumnStats { @@ -164,8 +165,8 @@ private[sql] class LongColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class FloatColumnStats extends ColumnStats { @@ -182,8 +183,8 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class DoubleColumnStats extends ColumnStats { @@ -200,8 +201,8 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class StringColumnStats extends ColumnStats { @@ -218,8 +219,8 @@ private[sql] class StringColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def 
collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class BinaryColumnStats extends ColumnStats { @@ -230,8 +231,8 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { @@ -248,8 +249,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { @@ -262,8 +263,8 @@ private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } private[sql] class DateColumnStats extends IntColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 5d5b0697d7016..d553bb6169ecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -148,7 +148,7 @@ private[sql] case class InMemoryRelation( } val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.toSeq)) + .flatMap(_.values)) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -330,10 +330,11 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema - .zip(cachedBatch.stats.toSeq) - .map { case (a, s) => s"${a.name}: $s" } - .mkString(", ") + def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { + case (a, i) => + val value = cachedBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") logInfo(s"Skipping partition based on stats $statsString") false } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c37007f1eece7..dd3858ea2b520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -156,8 +156,8 @@ package object debug { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => - case (row: InternalRow, StructType(fields)) => - row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } + case (row: InternalRow, s: StructType) => + row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (a: ArrayData, 
ArrayType(elemType, _)) => a.foreach(elemType, (_, e) => { typeCheck(e, elemType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7126145ddc010..c04557e5a0818 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -461,8 +461,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val spec = discoverPartitions() val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = values.toSeq.zip(partitionColumnTypes).map { - case (value, dataType) => Literal.create(value, dataType) + val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => + Literal.create(values.get(i, dt), dt) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => Cast(literal, field.dataType).eval() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 16e0187ed20a0..d0430d2a60e75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -19,33 +19,36 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - InternalRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) + createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, - InternalRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testDecimalColumnStats(InternalRow(null, null, 0)) + createRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, 
createRow(null, null, 0)) + testDecimalColumnStats(createRow(null, null, 0)) + + def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: InternalRow): Unit = { + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } @@ -61,11 +64,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) - assertResult(10, "Wrong null count")(stats.get(2, null)) - assertResult(20, "Wrong row count")(stats.get(3, null)) - assertResult(stats.get(4, null), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -73,14 +76,15 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName val columnType = FIXED_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new FixedDecimalColumnStats(15, 10) - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } @@ -96,11 +100,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) - assertResult(10, "Wrong null count")(stats.get(2, null)) - assertResult(20, "Wrong row count")(stats.get(3, null)) - assertResult(stats.get(4, null), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 39d798d072aeb..9824dad239596 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -390,8 +390,10 @@ private[hive] trait HiveInspectors { (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { - (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + val row = o.asInstanceOf[InternalRow] + soi.getAllStructFieldRefs.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) } struct } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index a6a343d395995..ade27454b9d29 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -88,6 +88,7 @@ case class ScriptTransformation( // external process. That process's output will be read by this current thread. val writerThread = new ScriptTransformationWriterThread( inputIterator, + input.map(_.dataType), outputProjection, inputSerde, inputSoi, @@ -201,6 +202,7 @@ case class ScriptTransformation( private class ScriptTransformationWriterThread( iter: Iterator[InternalRow], + inputSchema: Seq[DataType], outputProjection: Projection, @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: ObjectInspector, @@ -226,12 +228,25 @@ private class ScriptTransformationWriterThread( // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception var threwException: Boolean = true + val len = inputSchema.length try { iter.map(outputProjection).foreach { row => if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - outputStream.write(data) + val data = if (len == 0) { + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") + } else { + val sb = new StringBuilder + sb.append(row.get(0, inputSchema(0))) + var i = 1 + while (i < len) { + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + sb.append(row.get(i, inputSchema(i))) + i += 1 + } + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) + sb.toString() + } + outputStream.write(data.getBytes("utf-8")) } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 684ea1d137b49..8dc796b056a72 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -211,18 +211,18 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( } } - val dynamicPartPath = dynamicPartColNames - .zip(row.toSeq.takeRight(dynamicPartColNames.length)) - .map { case (col, rawVal) => - val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - FileUtils.escapePathName(string, 
defaultPartName) - } - s"/$col=$colString" - }.mkString + val nonDynamicPartLen = row.numFields - dynamicPartColNames.length + val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => + val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) + val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string, defaultPartName) + } + s"/$colName=$colString" + }.mkString def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 99e95fb921301..81a70b8d42267 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -133,8 +133,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { - row1.zip(row2.toSeq).foreach { case (r1, r2) => + def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = { + row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) => checkValue(r1, r2) } } @@ -211,8 +211,10 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) val inspector = toInspector(dt) - checkValues(row, - unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) + checkValues( + row, + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow], + dt) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } From 2eca46a17a3d46a605804ff89c010017da91e1bc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 11:15:37 -0700 Subject: [PATCH 188/340] Revert "[SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info" This reverts commit 6e009cb9c4d7a395991e10dab427f37019283758. 
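For context on what this revert (and the re-apply in PATCH 192 below) toggles, here is a minimal, hypothetical Scala sketch of the two toSeq shapes; the row values and field types are made up for illustration and are not part of the patch:

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType}

  val row: InternalRow = InternalRow(1, 2L, true)

  // Shape re-applied in PATCH 192: callers supply the field types, so specialized
  // rows without a generic getter (e.g. UnsafeRow) can still be converted to a Seq.
  val typed: Seq[Any] = row.toSeq(Seq(IntegerType, LongType, BooleanType))

  // Shape restored by this revert: no type info needed; the row's generic getter
  // (genericGet) is used for every field instead.
  // val untyped: Seq[Any] = row.toSeq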
--- .../spark/sql/catalyst/InternalRow.scala | 132 ++++++++++++++++-- .../sql/catalyst/expressions/Projection.scala | 12 +- .../expressions/SpecificMutableRow.scala | 5 +- .../codegen/GenerateProjection.scala | 8 +- .../spark/sql/catalyst/expressions/rows.scala | 132 +----------------- .../expressions/CodeGenerationSuite.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 51 ++++--- .../columnar/InMemoryColumnarTableScan.scala | 11 +- .../spark/sql/execution/debug/package.scala | 4 +- .../apache/spark/sql/sources/interfaces.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 54 ++++--- .../spark/sql/hive/HiveInspectors.scala | 6 +- .../hive/execution/ScriptTransformation.scala | 21 +-- .../spark/sql/hive/hiveWriterContainers.scala | 24 ++-- .../spark/sql/hive/HiveInspectorSuite.scala | 10 +- 15 files changed, 217 insertions(+), 259 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 85b4bf3b6aef5..7d17cca808791 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -31,6 +32,8 @@ abstract class InternalRow extends SpecializedGetters with Serializable { // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString + override def toString: String = mkString("[", ",", "]") + /** * Make a copy of the current [[InternalRow]] object. */ @@ -47,25 +50,136 @@ abstract class InternalRow extends SpecializedGetters with Serializable { false } + // Subclasses of InternalRow should implement all special getters and equals/hashCode, + // or implement this genericGet. 
+ protected def genericGet(ordinal: Int): Any = throw new IllegalStateException( + "Concrete internal rows should implement genericGet, " + + "or implement all special getters and equals/hashCode") + + // default implementation (slow) + private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[InternalRow]) { + return false + } + + val other = o.asInstanceOf[InternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = genericGet(i) + val o2 = other.genericGet(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numFields + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + genericGet(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } + /* ---------------------- utility methods for Scala ---------------------- */ /** * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. 
*/ - def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { - val len = numFields - assert(len == fieldTypes.length) - - val values = new Array[Any](len) + // todo: remove this as it needs the generic getter + def toSeq: Seq[Any] = { + val n = numFields + val values = new Array[Any](n) var i = 0 - while (i < len) { - values(i) = get(i, fieldTypes(i)) + while (i < n) { + values.update(i, genericGet(i)) i += 1 } values } - def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType)) + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. + */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) } object InternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 59ce7fc4f2c63..4296b4b123fc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -203,11 +203,7 @@ class JoinedRow extends InternalRow { this } - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { - assert(fieldTypes.length == row1.numFields + row2.numFields) - val (left, right) = fieldTypes.splitAt(row1.numFields) - row1.toSeq(left) ++ row2.toSeq(right) - } + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq override def numFields: Int = row1.numFields + row2.numFields @@ -280,11 +276,11 @@ class JoinedRow extends InternalRow { if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" } else if (row1 eq null) { - row2.toString + row2.mkString("[", ",", "]") } else if (row2 eq null) { - row1.toString + row1.mkString("[", ",", "]") } else { - s"{${row1.toString} + ${row2.toString}}" + mkString("[", ",", "]") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 4f56f94bd4ca4..b94df6bd66e04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -192,8 +192,7 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. 
*/ -final class SpecificMutableRow(val values: Array[MutableValue]) - extends MutableRow with BaseGenericInternalRow { +final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { def this(dataTypes: Seq[DataType]) = this( @@ -214,6 +213,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) override def numFields: Int = values.length + override def toSeq: Seq[Any] = values.map(_.boxed) + override def setNullAt(i: Int): Unit = { values(i).isNull = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c744e84d822e8..c04fe734d554e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -26,8 +25,6 @@ import org.apache.spark.sql.types._ */ abstract class BaseProjection extends Projection {} -abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow - /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] @@ -174,7 +171,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificRow((InternalRow) r); } - final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { + final class SpecificRow extends ${classOf[MutableRow].getName} { $columns @@ -187,8 +184,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - @Override - public Object genericGet(int i) { + protected Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 207e667792660..7657fb535dcf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -21,130 +21,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -/** - * An extended version of [[InternalRow]] that implements all special getters, toString - * and equals/hashCode by `genericGet`. 
- */ -trait BaseGenericInternalRow extends InternalRow { - - protected def genericGet(ordinal: Int): Any - - // default implementation (slow) - private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null - override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - override def getByte(ordinal: Int): Byte = getAs(ordinal) - override def getShort(ordinal: Int): Short = getAs(ordinal) - override def getInt(ordinal: Int): Int = getAs(ordinal) - override def getLong(ordinal: Int): Long = getAs(ordinal) - override def getFloat(ordinal: Int): Float = getAs(ordinal) - override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) - override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) - override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) - override def getArray(ordinal: Int): ArrayData = getAs(ordinal) - override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) - override def getMap(ordinal: Int): MapData = getAs(ordinal) - override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - - override def toString(): String = { - if (numFields == 0) { - "[empty row]" - } else { - val sb = new StringBuilder - sb.append("[") - sb.append(genericGet(0)) - val len = numFields - var i = 1 - while (i < len) { - sb.append(",") - sb.append(genericGet(i)) - i += 1 - } - sb.append("]") - sb.toString() - } - } - - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[BaseGenericInternalRow]) { - return false - } - - val other = o.asInstanceOf[BaseGenericInternalRow] - if (other eq null) { - return false - } - - val len = numFields - if (len != other.numFields) { - return false - } - - var i = 0 - while (i < len) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = genericGet(i) - val o2 = other.genericGet(i) - o1 match { - case b1: Array[Byte] => - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - case f1: Float if java.lang.Float.isNaN(f1) => - if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { - return false - } - case d1: Double if java.lang.Double.isNaN(d1) => - if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { - return false - } - case _ => if (o1 != o2) { - return false - } - } - } - i += 1 - } - true - } - - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 - var i = 0 - val len = numFields - while (i < len) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - genericGet(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => java.util.Arrays.hashCode(a) - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } -} - /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. 
* Setting a value through a primitive function implicitly marks that column as not null. @@ -206,7 +82,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { +class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -214,7 +90,7 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values + override def toSeq: Seq[Any] = values override def numFields: Int = values.length @@ -233,7 +109,7 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { +class GenericMutableRow(values: Array[Any]) extends MutableRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -241,7 +117,7 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericI override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values + override def toSeq: Seq[Any] = values override def numFields: Int = values.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e323467af5f4a..e310aee221666 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -87,7 +87,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = GenerateMutableProjection.generate(expressions)() - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericMutableRow(length)).toSeq val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 5cbd52bc0590e..af1a8ecca9b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -66,7 +66,7 @@ private[sql] sealed trait ColumnStats extends Serializable { * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. 
*/ - def collectedStatistics: GenericInternalRow + def collectedStatistics: InternalRow } /** @@ -75,8 +75,7 @@ private[sql] sealed trait ColumnStats extends Serializable { private[sql] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) + override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L) } private[sql] class BooleanColumnStats extends ColumnStats { @@ -93,8 +92,8 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { @@ -111,8 +110,8 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ShortColumnStats extends ColumnStats { @@ -129,8 +128,8 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class IntColumnStats extends ColumnStats { @@ -147,8 +146,8 @@ private[sql] class IntColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class LongColumnStats extends ColumnStats { @@ -165,8 +164,8 @@ private[sql] class LongColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FloatColumnStats extends ColumnStats { @@ -183,8 +182,8 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DoubleColumnStats extends ColumnStats { @@ -201,8 +200,8 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class StringColumnStats extends ColumnStats { @@ -219,8 +218,8 @@ private[sql] class StringColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, 
sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class BinaryColumnStats extends ColumnStats { @@ -231,8 +230,8 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(null, null, nullCount, count, sizeInBytes) } private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { @@ -249,8 +248,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { @@ -263,8 +262,8 @@ private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(null, null, nullCount, count, sizeInBytes) } private[sql] class DateColumnStats extends IntColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index d553bb6169ecc..5d5b0697d7016 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -148,7 +148,7 @@ private[sql] case class InMemoryRelation( } val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.values)) + .flatMap(_.toSeq)) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -330,11 +330,10 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { - case (a, i) => - val value = cachedBatch.stats.get(i, a.dataType) - s"${a.name}: $value" - }.mkString(", ") + def statsString: String = relation.partitionStatistics.schema + .zip(cachedBatch.stats.toSeq) + .map { case (a, s) => s"${a.name}: $s" } + .mkString(", ") logInfo(s"Skipping partition based on stats $statsString") false } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dd3858ea2b520..c37007f1eece7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -156,8 +156,8 @@ package object debug { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => - case (row: InternalRow, s: StructType) => - row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } + case (row: InternalRow, StructType(fields)) => + row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (a: ArrayData, ArrayType(elemType, 
_)) => a.foreach(elemType, (_, e) => { typeCheck(e, elemType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c04557e5a0818..7126145ddc010 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -461,8 +461,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val spec = discoverPartitions() val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => - Literal.create(values.get(i, dt), dt) + val literals = values.toSeq.zip(partitionColumnTypes).map { + case (value, dataType) => Literal.create(value, dataType) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => Cast(literal, field.dataType).eval() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index d0430d2a60e75..16e0187ed20a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -19,36 +19,33 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - createRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) + InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, - createRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) - testDecimalColumnStats(createRow(null, null, 0)) - - def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) + InternalRow(Double.MaxValue, 
Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) + testDecimalColumnStats(InternalRow(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: GenericInternalRow): Unit = { + initialStatistics: InternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { case (actual, expected) => assert(actual === expected) } } @@ -64,11 +61,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) + assertResult(10, "Wrong null count")(stats.get(2, null)) + assertResult(20, "Wrong row count")(stats.get(3, null)) + assertResult(stats.get(4, null), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -76,15 +73,14 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( - initialStatistics: GenericInternalRow): Unit = { + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName val columnType = FIXED_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new FixedDecimalColumnStats(15, 10) - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { case (actual, expected) => assert(actual === expected) } } @@ -100,11 +96,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) + assertResult(10, "Wrong null count")(stats.get(2, null)) + assertResult(20, "Wrong row count")(stats.get(3, null)) + assertResult(stats.get(4, null), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 9824dad239596..39d798d072aeb 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -390,10 +390,8 @@ private[hive] trait HiveInspectors { (o: Any) => { if (o != null) { val struct = soi.create() - val row = o.asInstanceOf[InternalRow] - soi.getAllStructFieldRefs.zip(wrappers).zipWithIndex.foreach { - case ((field, wrapper), i) => - soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) } struct } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index ade27454b9d29..a6a343d395995 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -88,7 +88,6 @@ case class ScriptTransformation( // external process. That process's output will be read by this current thread. val writerThread = new ScriptTransformationWriterThread( inputIterator, - input.map(_.dataType), outputProjection, inputSerde, inputSoi, @@ -202,7 +201,6 @@ case class ScriptTransformation( private class ScriptTransformationWriterThread( iter: Iterator[InternalRow], - inputSchema: Seq[DataType], outputProjection: Projection, @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: ObjectInspector, @@ -228,25 +226,12 @@ private class ScriptTransformationWriterThread( // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception var threwException: Boolean = true - val len = inputSchema.length try { iter.map(outputProjection).foreach { row => if (inputSerde == null) { - val data = if (len == 0) { - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") - } else { - val sb = new StringBuilder - sb.append(row.get(0, inputSchema(0))) - var i = 1 - while (i < len) { - sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - sb.append(row.get(i, inputSchema(i))) - i += 1 - } - sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) - sb.toString() - } - outputStream.write(data.getBytes("utf-8")) + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + outputStream.write(data) } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 8dc796b056a72..684ea1d137b49 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -211,18 +211,18 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( } } - val nonDynamicPartLen = row.numFields - dynamicPartColNames.length - val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => - val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) - val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) - val colString = - if 
(string == null || string.isEmpty) { - defaultPartName - } else { - FileUtils.escapePathName(string, defaultPartName) - } - s"/$colName=$colString" - }.mkString + val dynamicPartPath = dynamicPartColNames + .zip(row.toSeq.takeRight(dynamicPartColNames.length)) + .map { case (col, rawVal) => + val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string, defaultPartName) + } + s"/$col=$colString" + }.mkString def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 81a70b8d42267..99e95fb921301 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -133,8 +133,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = { - row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) => + def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { + row1.zip(row2.toSeq).foreach { case (r1, r2) => checkValue(r1, r2) } } @@ -211,10 +211,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) val inspector = toInspector(dt) - checkValues( - row, - unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow], - dt) + checkValues(row, + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } From cdd53b762bf358616b313e3334b5f6945caf9ab1 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 6 Aug 2015 11:15:54 -0700 Subject: [PATCH 189/340] [SPARK-9632] [SQL] [HOT-FIX] Fix build. seems https://github.com/apache/spark/pull/7955 breaks the build. Author: Yin Huai Closes #8001 from yhuai/SPARK-9632-fixBuild and squashes the following commits: 6c257dd [Yin Huai] Fix build. --- .../scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 7657fb535dcf4..fd42fac3d2cd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. From 0d7aac99da660cc42eb5a9be8e262bd9bd8a770f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 6 Aug 2015 19:29:42 +0100 Subject: [PATCH 190/340] [SPARK-9641] [DOCS] spark.shuffle.service.port is not documented Document spark.shuffle.service.{enabled,port} CC sryza tgravescs This is pretty minimal; is there more to say here about the service? 
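As a hypothetical illustration of the two options documented below (the service itself still has to be set up on the cluster, as the added documentation notes; the application name is a placeholder), a driver-side configuration might look like:

  import org.apache.spark.SparkConf

  val conf = new SparkConf()
    .setAppName("ExampleApp")                         // placeholder name
    .set("spark.dynamicAllocation.enabled", "true")   // requires the external shuffle service
    .set("spark.shuffle.service.enabled", "true")     // tell executors to use the external service
    .set("spark.shuffle.service.port", "7337")        // default port, shown explicitly here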
Author: Sean Owen Closes #7991 from srowen/SPARK-9641 and squashes the following commits: 3bb946e [Sean Owen] Add link to docs for setup and config of external shuffle service 2302e01 [Sean Owen] Document spark.shuffle.service.{enabled,port} --- docs/configuration.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 24b606356a149..c60dd16839c02 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -473,6 +473,25 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryFraction. + + spark.shuffle.service.enabled + false + + Enables the external shuffle service. This service preserves the shuffle files written by + executors so the executors can be safely removed. This must be enabled if + spark.dynamicAllocation.enabled is "true". The external shuffle service + must be set up in order to enable it. See + dynamic allocation + configuration and setup documentation for more information. + + + + spark.shuffle.service.port + 7337 + + Port on which the external shuffle service will run. + + spark.shuffle.sort.bypassMergeThreshold 200 From a1bbf1bc5c51cd796015ac159799cf024de6fa07 Mon Sep 17 00:00:00 2001 From: Nilanjan Raychaudhuri Date: Thu, 6 Aug 2015 12:50:08 -0700 Subject: [PATCH 191/340] [SPARK-8978] [STREAMING] Implements the DirectKafkaRateController MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: Dean Wampler Author: Nilanjan Raychaudhuri Author: François Garillot Closes #7796 from dragos/topic/streaming-bp/kafka-direct and squashes the following commits: 50d1f21 [Nilanjan Raychaudhuri] Taking care of the remaining nits 648c8b1 [Dean Wampler] Refactored rate controller test to be more predictable and run faster. e43f678 [Nilanjan Raychaudhuri] fixing doc and nits ce19d2a [Dean Wampler] Removing an unreliable assertion. 9615320 [Dean Wampler] Give me a break... 6372478 [Dean Wampler] Found a few ways to make this test more robust... 9e69e37 [Dean Wampler] Attempt to fix flakey test that fails in CI, but not locally :( d3db1ea [Dean Wampler] Fixing stylecheck errors. d04a288 [Nilanjan Raychaudhuri] adding test to make sure rate controller is used to calculate maxMessagesPerPartition b6ecb67 [Nilanjan Raychaudhuri] Fixed styling issue 3110267 [Nilanjan Raychaudhuri] [SPARK-8978][Streaming] Implements the DirectKafkaRateController 393c580 [François Garillot] [SPARK-8978][Streaming] Implements the DirectKafkaRateController 51e78c6 [Nilanjan Raychaudhuri] Rename and fix build failure 2795509 [Nilanjan Raychaudhuri] Added missing RateController 19200f5 [Dean Wampler] Removed usage of infix notation. Changed a private variable name to be more consistent with usage. 
aa4a70b [François Garillot] [SPARK-8978][Streaming] Implements the DirectKafkaController --- .../kafka/DirectKafkaInputDStream.scala | 47 ++++++++-- .../kafka/DirectKafkaStreamSuite.scala | 89 +++++++++++++++++++ 2 files changed, 127 insertions(+), 9 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 48a1933d92f85..8a177077775c6 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -29,7 +29,8 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.StreamInputInfo +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.RateEstimator /** * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where @@ -61,7 +62,7 @@ class DirectKafkaInputDStream[ val kafkaParams: Map[String, String], val fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R -) extends InputDStream[R](ssc_) with Logging { + ) extends InputDStream[R](ssc_) with Logging { val maxRetries = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRetries", 1) @@ -71,14 +72,35 @@ class DirectKafkaInputDStream[ protected[streaming] override val checkpointData = new DirectKafkaInputDStreamCheckpointData + + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. 
+ */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + Some(new DirectKafkaRateController(id, + RateEstimator.create(ssc.conf, ssc_.graph.batchDuration))) + } else { + None + } + } + protected val kc = new KafkaCluster(kafkaParams) - protected val maxMessagesPerPartition: Option[Long] = { - val ratePerSec = context.sparkContext.getConf.getInt( + private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRatePerPartition", 0) - if (ratePerSec > 0) { + protected def maxMessagesPerPartition: Option[Long] = { + val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val numPartitions = currentOffsets.keys.size + + val effectiveRateLimitPerPartition = estimatedRateLimit + .filter(_ > 0) + .map(limit => Math.min(maxRateLimitPerPartition, (limit / numPartitions))) + .getOrElse(maxRateLimitPerPartition) + + if (effectiveRateLimitPerPartition > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 - Some((secsPerBatch * ratePerSec).toLong) + Some((secsPerBatch * effectiveRateLimitPerPartition).toLong) } else { None } @@ -170,11 +192,18 @@ class DirectKafkaInputDStream[ val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics)) batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => - logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") - generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( - context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) + logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") + generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) } } } + /** + * A RateController to retrieve the rate from RateEstimator. + */ + private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = () + } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 5b3c79444aa68..02225d5aa7cc5 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -20,6 +20,9 @@ package org.apache.spark.streaming.kafka import java.io.File import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset +import org.apache.spark.streaming.scheduler.rate.RateEstimator + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ @@ -350,6 +353,77 @@ class DirectKafkaStreamSuite ssc.stop() } + test("using rate controller") { + val topic = "backpressure" + val topicPartition = TopicAndPartition(topic, 0) + kafkaTestUtils.createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 100 + val estimator = new ConstantEstimator(100) + val messageKeys = (1 to 200).map(_.toString) + val messages = messageKeys.map((_, 1)).toMap + + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. 
+ // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + + val kafkaStream = withClue("Error creating direct stream") { + val kc = new KafkaCluster(kafkaParams) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val m = kc.getEarliestLeaderOffsets(Set(topicPartition)) + .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) + + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, m, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, estimator)) + } + } + + val collectedData = + new mutable.ArrayBuffer[Array[String]]() with mutable.SynchronizedBuffer[Array[String]] + + // Used for assertion failure messages. + def dataToString: String = + collectedData.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}") + + // This is to collect the raw data received from Kafka + kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => + val data = rdd.map { _._2 }.collect() + collectedData += data + } + + ssc.start() + + // Try different rate limits. + // Send data to Kafka and wait for arrays of data to appear matching the rate. + Seq(100, 50, 20).foreach { rate => + collectedData.clear() // Empty this buffer on each pass. + estimator.updateRate(rate) // Set a new rate. + // Expect blocks of data equal to "rate", scaled by the interval length in secs. + val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001) + kafkaTestUtils.sendMessages(topic, messages) + eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) { + // Assert that rate estimator values are used to determine maxMessagesPerPartition. + // Funky "-" in message makes the complete assertion message read better. + assert(collectedData.exists(_.size == expectedSize), + s" - No arrays of size $expectedSize for rate $rate found in $dataToString") + } + } + + ssc.stop() + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { @@ -381,3 +455,18 @@ object DirectKafkaStreamSuite { } } } + +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { + + def updateRate(newRate: Long): Unit = { + rate = newRate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(rate) +} + From 1f62f104c7a2aeac625b17d9e5ac62f1f10a2b21 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 6 Aug 2015 13:11:59 -0700 Subject: [PATCH 192/340] [SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info This re-applies #7955, which was reverted due to a race condition to fix build breaking. Author: Wenchen Fan Author: Reynold Xin Closes #8002 from rxin/InternalRow-toSeq and squashes the following commits: 332416a [Reynold Xin] Merge pull request #7955 from cloud-fan/toSeq 21665e2 [Wenchen Fan] fix hive again... 
4addf29 [Wenchen Fan] fix hive bc16c59 [Wenchen Fan] minor fix 33d802c [Wenchen Fan] pass data type info to InternalRow.toSeq 3dd033e [Wenchen Fan] move the default special getters implementation from InternalRow to BaseGenericInternalRow --- .../spark/sql/catalyst/InternalRow.scala | 132 ++---------------- .../sql/catalyst/expressions/Projection.scala | 12 +- .../expressions/SpecificMutableRow.scala | 5 +- .../codegen/GenerateProjection.scala | 8 +- .../spark/sql/catalyst/expressions/rows.scala | 132 +++++++++++++++++- .../expressions/CodeGenerationSuite.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 51 +++---- .../columnar/InMemoryColumnarTableScan.scala | 11 +- .../spark/sql/execution/debug/package.scala | 4 +- .../apache/spark/sql/sources/interfaces.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 54 +++---- .../spark/sql/hive/HiveInspectors.scala | 6 +- .../hive/execution/ScriptTransformation.scala | 21 ++- .../spark/sql/hive/hiveWriterContainers.scala | 24 ++-- .../spark/sql/hive/HiveInspectorSuite.scala | 10 +- 15 files changed, 259 insertions(+), 217 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 7d17cca808791..85b4bf3b6aef5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types.{DataType, StructType} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -32,8 +31,6 @@ abstract class InternalRow extends SpecializedGetters with Serializable { // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString - override def toString: String = mkString("[", ",", "]") - /** * Make a copy of the current [[InternalRow]] object. */ @@ -50,136 +47,25 @@ abstract class InternalRow extends SpecializedGetters with Serializable { false } - // Subclasses of InternalRow should implement all special getters and equals/hashCode, - // or implement this genericGet. 
- protected def genericGet(ordinal: Int): Any = throw new IllegalStateException( - "Concrete internal rows should implement genericGet, " + - "or implement all special getters and equals/hashCode") - - // default implementation (slow) - private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null - override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - override def getByte(ordinal: Int): Byte = getAs(ordinal) - override def getShort(ordinal: Int): Short = getAs(ordinal) - override def getInt(ordinal: Int): Int = getAs(ordinal) - override def getLong(ordinal: Int): Long = getAs(ordinal) - override def getFloat(ordinal: Int): Float = getAs(ordinal) - override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) - override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) - override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) - override def getArray(ordinal: Int): ArrayData = getAs(ordinal) - override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) - override def getMap(ordinal: Int): MapData = getAs(ordinal) - override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[InternalRow]) { - return false - } - - val other = o.asInstanceOf[InternalRow] - if (other eq null) { - return false - } - - val len = numFields - if (len != other.numFields) { - return false - } - - var i = 0 - while (i < len) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = genericGet(i) - val o2 = other.genericGet(i) - o1 match { - case b1: Array[Byte] => - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - case f1: Float if java.lang.Float.isNaN(f1) => - if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { - return false - } - case d1: Double if java.lang.Double.isNaN(d1) => - if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { - return false - } - case _ => if (o1 != o2) { - return false - } - } - } - i += 1 - } - true - } - - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 - var i = 0 - val len = numFields - while (i < len) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - genericGet(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => java.util.Arrays.hashCode(a) - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } - /* ---------------------- utility methods for Scala ---------------------- */ /** * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. 
*/ - // todo: remove this as it needs the generic getter - def toSeq: Seq[Any] = { - val n = numFields - val values = new Array[Any](n) + def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + val len = numFields + assert(len == fieldTypes.length) + + val values = new Array[Any](len) var i = 0 - while (i < n) { - values.update(i, genericGet(i)) + while (i < len) { + values(i) = get(i, fieldTypes(i)) i += 1 } values } - /** Displays all elements of this sequence in a string (without a separator). */ - def mkString: String = toSeq.mkString - - /** Displays all elements of this sequence in a string using a separator string. */ - def mkString(sep: String): String = toSeq.mkString(sep) - - /** - * Displays all elements of this traversable or iterator in a string using - * start, end, and separator strings. - */ - def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType)) } object InternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 4296b4b123fc0..59ce7fc4f2c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -203,7 +203,11 @@ class JoinedRow extends InternalRow { this } - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + assert(fieldTypes.length == row1.numFields + row2.numFields) + val (left, right) = fieldTypes.splitAt(row1.numFields) + row1.toSeq(left) ++ row2.toSeq(right) + } override def numFields: Int = row1.numFields + row2.numFields @@ -276,11 +280,11 @@ class JoinedRow extends InternalRow { if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" } else if (row1 eq null) { - row2.mkString("[", ",", "]") + row2.toString } else if (row2 eq null) { - row1.mkString("[", ",", "]") + row1.toString } else { - mkString("[", ",", "]") + s"{${row1.toString} + ${row2.toString}}" } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index b94df6bd66e04..4f56f94bd4ca4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -192,7 +192,8 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. 
*/ -final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { +final class SpecificMutableRow(val values: Array[MutableValue]) + extends MutableRow with BaseGenericInternalRow { def this(dataTypes: Seq[DataType]) = this( @@ -213,8 +214,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def numFields: Int = values.length - override def toSeq: Seq[Any] = values.map(_.boxed) - override def setNullAt(i: Int): Unit = { values(i).isNull = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c04fe734d554e..c744e84d822e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -25,6 +26,8 @@ import org.apache.spark.sql.types._ */ abstract class BaseProjection extends Projection {} +abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow + /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] @@ -171,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificRow((InternalRow) r); } - final class SpecificRow extends ${classOf[MutableRow].getName} { + final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { $columns @@ -184,7 +187,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - protected Object genericGet(int i) { + @Override + public Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index fd42fac3d2cd4..11d10b2d8a48b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -22,6 +22,130 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +/** + * An extended version of [[InternalRow]] that implements all special getters, toString + * and equals/hashCode by `genericGet`. 
+ */ +trait BaseGenericInternalRow extends InternalRow { + + protected def genericGet(ordinal: Int): Any + + // default implementation (slow) + private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def toString(): String = { + if (numFields == 0) { + "[empty row]" + } else { + val sb = new StringBuilder + sb.append("[") + sb.append(genericGet(0)) + val len = numFields + var i = 1 + while (i < len) { + sb.append(",") + sb.append(genericGet(i)) + i += 1 + } + sb.append("]") + sb.toString() + } + } + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[BaseGenericInternalRow]) { + return false + } + + val other = o.asInstanceOf[BaseGenericInternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = genericGet(i) + val o2 = other.genericGet(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numFields + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + genericGet(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} + /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. 
* Setting a value through a primitive function implicitly marks that column as not null. @@ -83,7 +207,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow { +class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -91,7 +215,7 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRo override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length @@ -110,7 +234,7 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(values: Array[Any]) extends MutableRow { +class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -118,7 +242,7 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow { override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e310aee221666..e323467af5f4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -87,7 +87,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = GenerateMutableProjection.generate(expressions)() - val actual = plan(new GenericMutableRow(length)).toSeq + val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index af1a8ecca9b57..5cbd52bc0590e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -66,7 +66,7 @@ private[sql] sealed trait ColumnStats extends Serializable { * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. 
*/ - def collectedStatistics: InternalRow + def collectedStatistics: GenericInternalRow } /** @@ -75,7 +75,8 @@ private[sql] sealed trait ColumnStats extends Serializable { private[sql] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) - override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } private[sql] class BooleanColumnStats extends ColumnStats { @@ -92,8 +93,8 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class ByteColumnStats extends ColumnStats { @@ -110,8 +111,8 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class ShortColumnStats extends ColumnStats { @@ -128,8 +129,8 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class IntColumnStats extends ColumnStats { @@ -146,8 +147,8 @@ private[sql] class IntColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class LongColumnStats extends ColumnStats { @@ -164,8 +165,8 @@ private[sql] class LongColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class FloatColumnStats extends ColumnStats { @@ -182,8 +183,8 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class DoubleColumnStats extends ColumnStats { @@ -200,8 +201,8 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class StringColumnStats extends ColumnStats { @@ -218,8 +219,8 @@ private[sql] class StringColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def 
collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class BinaryColumnStats extends ColumnStats { @@ -230,8 +231,8 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { @@ -248,8 +249,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { @@ -262,8 +263,8 @@ private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } private[sql] class DateColumnStats extends IntColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 5d5b0697d7016..d553bb6169ecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -148,7 +148,7 @@ private[sql] case class InMemoryRelation( } val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.toSeq)) + .flatMap(_.values)) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -330,10 +330,11 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema - .zip(cachedBatch.stats.toSeq) - .map { case (a, s) => s"${a.name}: $s" } - .mkString(", ") + def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { + case (a, i) => + val value = cachedBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") logInfo(s"Skipping partition based on stats $statsString") false } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c37007f1eece7..dd3858ea2b520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -156,8 +156,8 @@ package object debug { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => - case (row: InternalRow, StructType(fields)) => - row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } + case (row: InternalRow, s: StructType) => + row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (a: ArrayData, 
ArrayType(elemType, _)) => a.foreach(elemType, (_, e) => { typeCheck(e, elemType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7126145ddc010..c04557e5a0818 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -461,8 +461,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val spec = discoverPartitions() val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = values.toSeq.zip(partitionColumnTypes).map { - case (value, dataType) => Literal.create(value, dataType) + val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => + Literal.create(values.get(i, dt), dt) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => Cast(literal, field.dataType).eval() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 16e0187ed20a0..d0430d2a60e75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -19,33 +19,36 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - InternalRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) + createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, - InternalRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testDecimalColumnStats(InternalRow(null, null, 0)) + createRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, 
createRow(null, null, 0)) + testDecimalColumnStats(createRow(null, null, 0)) + + def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: InternalRow): Unit = { + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } @@ -61,11 +64,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) - assertResult(10, "Wrong null count")(stats.get(2, null)) - assertResult(20, "Wrong row count")(stats.get(3, null)) - assertResult(stats.get(4, null), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -73,14 +76,15 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName val columnType = FIXED_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new FixedDecimalColumnStats(15, 10) - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } @@ -96,11 +100,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) - assertResult(10, "Wrong null count")(stats.get(2, null)) - assertResult(20, "Wrong row count")(stats.get(3, null)) - assertResult(stats.get(4, null), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 39d798d072aeb..9824dad239596 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -390,8 +390,10 @@ private[hive] trait HiveInspectors { (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { - (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + val row = o.asInstanceOf[InternalRow] + soi.getAllStructFieldRefs.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) } struct } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index a6a343d395995..ade27454b9d29 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -88,6 +88,7 @@ case class ScriptTransformation( // external process. That process's output will be read by this current thread. val writerThread = new ScriptTransformationWriterThread( inputIterator, + input.map(_.dataType), outputProjection, inputSerde, inputSoi, @@ -201,6 +202,7 @@ case class ScriptTransformation( private class ScriptTransformationWriterThread( iter: Iterator[InternalRow], + inputSchema: Seq[DataType], outputProjection: Projection, @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: ObjectInspector, @@ -226,12 +228,25 @@ private class ScriptTransformationWriterThread( // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception var threwException: Boolean = true + val len = inputSchema.length try { iter.map(outputProjection).foreach { row => if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - outputStream.write(data) + val data = if (len == 0) { + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") + } else { + val sb = new StringBuilder + sb.append(row.get(0, inputSchema(0))) + var i = 1 + while (i < len) { + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + sb.append(row.get(i, inputSchema(i))) + i += 1 + } + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) + sb.toString() + } + outputStream.write(data.getBytes("utf-8")) } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 684ea1d137b49..8dc796b056a72 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -211,18 +211,18 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( } } - val dynamicPartPath = dynamicPartColNames - .zip(row.toSeq.takeRight(dynamicPartColNames.length)) - .map { case (col, rawVal) => - val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - FileUtils.escapePathName(string, 
defaultPartName) - } - s"/$col=$colString" - }.mkString + val nonDynamicPartLen = row.numFields - dynamicPartColNames.length + val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => + val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) + val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string, defaultPartName) + } + s"/$colName=$colString" + }.mkString def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 99e95fb921301..81a70b8d42267 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -133,8 +133,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { - row1.zip(row2.toSeq).foreach { case (r1, r2) => + def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = { + row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) => checkValue(r1, r2) } } @@ -211,8 +211,10 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) val inspector = toInspector(dt) - checkValues(row, - unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) + checkValues( + row, + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow], + dt) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } From 54c0789a05a783ce90e0e9848079be442a82966b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 6 Aug 2015 13:29:31 -0700 Subject: [PATCH 193/340] [SPARK-9493] [ML] add featureIndex to handle vector features in IsotonicRegression This PR contains the following changes: * add `featureIndex` to handle vector features (in order to chain isotonic regression easily with output from logistic regression * make getter/setter names consistent with params * remove inheritance from Regressor because it is tricky to handle both `DoubleType` and `VectorType` * simplify test data generation jkbradley zapletal-martin Author: Xiangrui Meng Closes #7952 from mengxr/SPARK-9493 and squashes the following commits: 8818ac3 [Xiangrui Meng] address comments 05e2216 [Xiangrui Meng] address comments 8d08090 [Xiangrui Meng] add featureIndex to handle vector features make getter/setter names consistent with params remove inheritance from Regressor --- .../ml/regression/IsotonicRegression.scala | 202 +++++++++++++----- .../regression/IsotonicRegressionSuite.scala | 82 ++++--- 2 files changed, 194 insertions(+), 90 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 4ece8cf8cf0b6..f570590960a62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -17,44 +17,113 @@ package org.apache.spark.ml.regression +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.PredictorParams 
-import org.apache.spark.ml.param.{Param, ParamMap, BooleanParam} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} -import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} -import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{DoubleType, DataType} -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit, udf} +import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel /** * Params for isotonic regression. */ -private[regression] trait IsotonicRegressionParams extends PredictorParams { +private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol + with HasLabelCol with HasPredictionCol with Logging { /** - * Param for weight column name. - * TODO: Move weightCol to sharedParams. - * + * Param for weight column name (default: none). * @group param */ + // TODO: Move weightCol to sharedParams. final val weightCol: Param[String] = - new Param[String](this, "weightCol", "weight column name") + new Param[String](this, "weightCol", + "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") /** @group getParam */ final def getWeightCol: String = $(weightCol) /** - * Param for isotonic parameter. - * Isotonic (increasing) or antitonic (decreasing) sequence. + * Param for whether the output sequence should be isotonic/increasing (true) or + * antitonic/decreasing (false). * @group param */ final val isotonic: BooleanParam = - new BooleanParam(this, "isotonic", "isotonic (increasing) or antitonic (decreasing) sequence") + new BooleanParam(this, "isotonic", + "whether the output sequence should be isotonic/increasing (true) or" + + "antitonic/decreasing (false)") /** @group getParam */ - final def getIsotonicParam: Boolean = $(isotonic) + final def getIsotonic: Boolean = $(isotonic) + + /** + * Param for the index of the feature if [[featuresCol]] is a vector column (default: `0`), no + * effect otherwise. + * @group param + */ + final val featureIndex: IntParam = new IntParam(this, "featureIndex", + "The index of the feature if featuresCol is a vector column, no effect otherwise.") + + /** @group getParam */ + final def getFeatureIndex: Int = $(featureIndex) + + setDefault(isotonic -> true, featureIndex -> 0) + + /** Checks whether the input has weight column. */ + protected[ml] def hasWeightCol: Boolean = { + isDefined(weightCol) && $(weightCol) != "" + } + + /** + * Extracts (label, feature, weight) from input dataset. 
+ */ + protected[ml] def extractWeightedLabeledPoints( + dataset: DataFrame): RDD[(Double, Double, Double)] = { + val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { + val idx = $(featureIndex) + val extract = udf { v: Vector => v(idx) } + extract(col($(featuresCol))) + } else { + col($(featuresCol)) + } + val w = if (hasWeightCol) { + col($(weightCol)) + } else { + lit(1.0) + } + dataset.select(col($(labelCol)), f, w) + .map { case Row(label: Double, feature: Double, weights: Double) => + (label, feature, weights) + } + } + + /** + * Validates and transforms input schema. + * @param schema input schema + * @param fitting whether this is in fitting or prediction + * @return output schema + */ + protected[ml] def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + if (fitting) { + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + if (hasWeightCol) { + SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType) + } else { + logInfo("The weight column is not defined. Treat all instance weights as 1.0.") + } + } + val featuresType = schema($(featuresCol)).dataType + require(featuresType == DoubleType || featuresType.isInstanceOf[VectorUDT]) + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } } /** @@ -67,52 +136,46 @@ private[regression] trait IsotonicRegressionParams extends PredictorParams { * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. */ @Experimental -class IsotonicRegression(override val uid: String) - extends Regressor[Double, IsotonicRegression, IsotonicRegressionModel] - with IsotonicRegressionParams { +class IsotonicRegression(override val uid: String) extends Estimator[IsotonicRegressionModel] + with IsotonicRegressionBase { def this() = this(Identifiable.randomUID("isoReg")) - /** - * Set the isotonic parameter. - * Default is true. - * @group setParam - */ - def setIsotonicParam(value: Boolean): this.type = set(isotonic, value) - setDefault(isotonic -> true) + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) - /** - * Set weight column param. - * Default is weight. 
- * @group setParam - */ - def setWeightParam(value: String): this.type = set(weightCol, value) - setDefault(weightCol -> "weight") + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) - override private[ml] def featuresDataType: DataType = DoubleType + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + /** @group setParam */ + def setIsotonic(value: Boolean): this.type = set(isotonic, value) - private[this] def extractWeightedLabeledPoints( - dataset: DataFrame): RDD[(Double, Double, Double)] = { + /** @group setParam */ + def setWeightCol(value: String): this.type = set(weightCol, value) - dataset.select($(labelCol), $(featuresCol), $(weightCol)) - .map { case Row(label: Double, features: Double, weights: Double) => - (label, features, weights) - } - } + /** @group setParam */ + def setFeatureIndex(value: Int): this.type = set(featureIndex, value) - override protected def train(dataset: DataFrame): IsotonicRegressionModel = { - SchemaUtils.checkColumnType(dataset.schema, $(weightCol), DoubleType) + override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + + override def fit(dataset: DataFrame): IsotonicRegressionModel = { + validateAndTransformSchema(dataset.schema, fitting = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) - val parentModel = isotonicRegression.run(instances) + val oldModel = isotonicRegression.run(instances) - new IsotonicRegressionModel(uid, parentModel) + copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true) } } @@ -123,22 +186,49 @@ class IsotonicRegression(override val uid: String) * * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]]. * - * @param parentModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] - * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. + * @param oldModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] + * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. */ +@Experimental class IsotonicRegressionModel private[ml] ( override val uid: String, - private[ml] val parentModel: MLlibIsotonicRegressionModel) - extends RegressionModel[Double, IsotonicRegressionModel] - with IsotonicRegressionParams { + private val oldModel: MLlibIsotonicRegressionModel) + extends Model[IsotonicRegressionModel] with IsotonicRegressionBase { - override def featuresDataType: DataType = DoubleType + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) - override protected def predict(features: Double): Double = { - parentModel.predict(features) - } + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setFeatureIndex(value: Int): this.type = set(featureIndex, value) + + /** Boundaries in increasing order for which predictions are known. 
*/ + def boundaries: Vector = Vectors.dense(oldModel.boundaries) + + /** + * Predictions associated with the boundaries at the same index, monotone because of isotonic + * regression. + */ + def predictions: Vector = Vectors.dense(oldModel.predictions) override def copy(extra: ParamMap): IsotonicRegressionModel = { - copyValues(new IsotonicRegressionModel(uid, parentModel), extra) + copyValues(new IsotonicRegressionModel(uid, oldModel), extra) + } + + override def transform(dataset: DataFrame): DataFrame = { + val predict = dataset.schema($(featuresCol)).dataType match { + case DoubleType => + udf { feature: Double => oldModel.predict(feature) } + case _: VectorUDT => + val idx = $(featureIndex) + udf { features: Vector => oldModel.predict(features(idx)) } + } + dataset.withColumn($(predictionCol), predict(col($(featuresCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 66e4b170bae80..c0ab00b68a2f3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -19,57 +19,46 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Row} class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { - private val schema = StructType( - Array( - StructField("label", DoubleType), - StructField("features", DoubleType), - StructField("weight", DoubleType))) - - private val predictionSchema = StructType(Array(StructField("features", DoubleType))) - private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { - val data = Seq.tabulate(labels.size)(i => Row(labels(i), i.toDouble, 1d)) - val parallelData = sc.parallelize(data) - - sqlContext.createDataFrame(parallelData, schema) + sqlContext.createDataFrame( + labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } + ).toDF("label", "features", "weight") } private def generatePredictionInput(features: Seq[Double]): DataFrame = { - val data = Seq.tabulate(features.size)(i => Row(features(i))) - - val parallelData = sc.parallelize(data) - sqlContext.createDataFrame(parallelData, predictionSchema) + sqlContext.createDataFrame(features.map(Tuple1.apply)) + .toDF("features") } test("isotonic regression predictions") { val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) - val trainer = new IsotonicRegression().setIsotonicParam(true) + val ir = new IsotonicRegression().setIsotonic(true) - val model = trainer.fit(dataset) + val model = ir.fit(dataset) val predictions = model .transform(dataset) - .select("prediction").map { - case Row(pred) => pred + .select("prediction").map { case Row(pred) => + pred }.collect() assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) - assert(model.parentModel.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) - assert(model.parentModel.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) - assert(model.parentModel.isotonic) + assert(model.boundaries === Vectors.dense(0, 1, 3, 4, 5, 
6, 7, 8)) + assert(model.predictions === Vectors.dense(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.getIsotonic) } test("antitonic regression predictions") { val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1)) - val trainer = new IsotonicRegression().setIsotonicParam(false) + val ir = new IsotonicRegression().setIsotonic(false) - val model = trainer.fit(dataset) + val model = ir.fit(dataset) val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0)) val predictions = model @@ -94,9 +83,10 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val ir = new IsotonicRegression() assert(ir.getLabelCol === "label") assert(ir.getFeaturesCol === "features") - assert(ir.getWeightCol === "weight") assert(ir.getPredictionCol === "prediction") - assert(ir.getIsotonicParam === true) + assert(!ir.isDefined(ir.weightCol)) + assert(ir.getIsotonic) + assert(ir.getFeatureIndex === 0) val model = ir.fit(dataset) model.transform(dataset) @@ -105,21 +95,22 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.getLabelCol === "label") assert(model.getFeaturesCol === "features") - assert(model.getWeightCol === "weight") assert(model.getPredictionCol === "prediction") - assert(model.getIsotonicParam === true) + assert(!model.isDefined(model.weightCol)) + assert(model.getIsotonic) + assert(model.getFeatureIndex === 0) assert(model.hasParent) } test("set parameters") { val isotonicRegression = new IsotonicRegression() - .setIsotonicParam(false) - .setWeightParam("w") + .setIsotonic(false) + .setWeightCol("w") .setFeaturesCol("f") .setLabelCol("l") .setPredictionCol("p") - assert(isotonicRegression.getIsotonicParam === false) + assert(!isotonicRegression.getIsotonic) assert(isotonicRegression.getWeightCol === "w") assert(isotonicRegression.getFeaturesCol === "f") assert(isotonicRegression.getLabelCol === "l") @@ -130,7 +121,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val dataset = generateIsotonicInput(Seq(1, 2, 3)) intercept[IllegalArgumentException] { - new IsotonicRegression().setWeightParam("w").fit(dataset) + new IsotonicRegression().setWeightCol("w").fit(dataset) } intercept[IllegalArgumentException] { @@ -145,4 +136,27 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset) } } + + test("vector features column with feature index") { + val dataset = sqlContext.createDataFrame(Seq( + (4.0, Vectors.dense(0.0, 1.0)), + (3.0, Vectors.dense(0.0, 2.0)), + (5.0, Vectors.sparse(2, Array(1), Array(3.0)))) + ).toDF("label", "features") + + val ir = new IsotonicRegression() + .setFeatureIndex(1) + + val model = ir.fit(dataset) + + val features = generatePredictionInput(Seq(2.0, 3.0, 4.0, 5.0)) + + val predictions = model + .transform(features) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(3.5, 5.0, 5.0, 5.0)) + } } From abfedb9cd70af60c8290bd2f5a5cec1047845ba0 Mon Sep 17 00:00:00 2001 From: Christian Kadner Date: Thu, 6 Aug 2015 14:15:42 -0700 Subject: [PATCH 194/340] [SPARK-9211] [SQL] [TEST] normalize line separators before generating MD5 hash The golden answer file names for the existing Hive comparison tests were generated using a MD5 hash of the query text which uses Unix-style line separator characters `\n` (LF). 
This PR ensures that all occurrences of the Windows-style line separator `\r\n` (CR) are replaced with `\n` (LF) before generating the MD5 hash to produce an identical MD5 hash for golden answer file names generated on Windows. Author: Christian Kadner Closes #7563 from ckadner/SPARK-9211_working and squashes the following commits: d541db0 [Christian Kadner] [SPARK-9211][SQL] normalize line separators before MD5 hash --- .../spark/sql/hive/execution/HiveComparisonTest.scala | 2 +- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 638b9c810372a..2bdb0e11878e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -124,7 +124,7 @@ abstract class HiveComparisonTest protected val cacheDigest = java.security.MessageDigest.getInstance("MD5") protected def getMd5(str: String): String = { val digest = java.security.MessageDigest.getInstance("MD5") - digest.update(str.getBytes("utf-8")) + digest.update(str.replaceAll(System.lineSeparator(), "\n").getBytes("utf-8")) new java.math.BigInteger(1, digest.digest).toString(16) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index edb27553671d1..83f9f3eaa3a5e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -427,7 +427,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |USING 'cat' AS (tKey, tValue) ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) test("transform with SerDe2") { @@ -446,7 +446,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('avro.schema.literal'='{"namespace": "testing.hive.avro.serde","name": |"src","type": "record","fields": [{"name":"key","type":"int"}]}') |FROM small_src - """.stripMargin.replaceAll("\n", " ")).collect().head + """.stripMargin.replaceAll(System.lineSeparator(), " ")).collect().head assert(expected(0) === res(0)) } @@ -458,7 +458,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('serialization.last.column.takes.rest'='true') USING 'cat' AS (tKey, tValue) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |WITH SERDEPROPERTIES ('serialization.last.column.takes.rest'='true') FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) createQueryTest("transform with SerDe4", """ @@ -467,7 +467,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('serialization.last.column.takes.rest'='true') USING 'cat' ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES |('serialization.last.column.takes.rest'='true') FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) createQueryTest("LIKE", "SELECT * FROM src WHERE value LIKE '%1%'") 
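(Aside, not part of the patch series: a minimal, self-contained Scala sketch of the normalization idea in the hunk above. The helper name `goldenFileName` is illustrative only; the real change is in `HiveComparisonTest.getMd5`, which uses `System.lineSeparator()` — `"\r\n"` on Windows — whereas the sketch hard-codes CRLF for clarity.)

```scala
import java.security.MessageDigest

// Hypothetical helper mirroring the patched getMd5: collapse Windows-style CRLF
// separators to LF before hashing, so identical query text yields the identical
// golden answer file name regardless of the platform the test ran on.
def goldenFileName(query: String): String = {
  val digest = MessageDigest.getInstance("MD5")
  val normalized = query.replaceAll("\r\n", "\n")
  digest.update(normalized.getBytes("utf-8"))
  new java.math.BigInteger(1, digest.digest).toString(16)
}

// goldenFileName("SELECT 1\r\nFROM src") == goldenFileName("SELECT 1\nFROM src")  // true
```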
From 21fdfd7d6f89adbd37066c169e6ba9ccd337683e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 6 Aug 2015 14:33:29 -0700 Subject: [PATCH 195/340] [SPARK-9548][SQL] Add a destructive iterator for BytesToBytesMap This pull request adds a destructive iterator to BytesToBytesMap. When used, the iterator frees pages as it traverses them. This is part of the effort to avoid starving when we have more than one operators that can exhaust memory. This is based on #7924, but fixes a bug there (Don't use destructive iterator in UnsafeKVExternalSorter). Closes #7924. Author: Liang-Chi Hsieh Author: Reynold Xin Closes #8003 from rxin/map-destructive-iterator and squashes the following commits: 6b618c3 [Reynold Xin] Don't use destructive iterator in UnsafeKVExternalSorter. a7bd8ec [Reynold Xin] Merge remote-tracking branch 'viirya/destructive_iter' into map-destructive-iterator 7652083 [Liang-Chi Hsieh] For comments: add destructiveIterator(), modify unit test, remove code block. 4a3e9de [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into destructive_iter 581e9e3 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into destructive_iter f0ff783 [Liang-Chi Hsieh] No need to free last page. 9e9d2a3 [Liang-Chi Hsieh] Add a destructive iterator for BytesToBytesMap. --- .../spark/unsafe/map/BytesToBytesMap.java | 33 +++++++++++++++-- .../map/AbstractBytesToBytesMapSuite.java | 37 ++++++++++++++++--- .../UnsafeFixedWidthAggregationMap.java | 7 +++- .../sql/execution/UnsafeKVExternalSorter.java | 5 ++- 4 files changed, 71 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 20347433e16b2..5ac3736ac62aa 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -227,22 +227,35 @@ public static final class BytesToBytesMapIterator implements Iterator private final Iterator dataPagesIterator; private final Location loc; - private MemoryBlock currentPage; + private MemoryBlock currentPage = null; private int currentRecordNumber = 0; private Object pageBaseObject; private long offsetInPage; + // If this iterator destructive or not. When it is true, it frees each page as it moves onto + // next one. + private boolean destructive = false; + private BytesToBytesMap bmap; + private BytesToBytesMapIterator( - int numRecords, Iterator dataPagesIterator, Location loc) { + int numRecords, Iterator dataPagesIterator, Location loc, + boolean destructive, BytesToBytesMap bmap) { this.numRecords = numRecords; this.dataPagesIterator = dataPagesIterator; this.loc = loc; + this.destructive = destructive; + this.bmap = bmap; if (dataPagesIterator.hasNext()) { advanceToNextPage(); } } private void advanceToNextPage() { + if (destructive && currentPage != null) { + dataPagesIterator.remove(); + this.bmap.taskMemoryManager.freePage(currentPage); + this.bmap.shuffleMemoryManager.release(currentPage.size()); + } currentPage = dataPagesIterator.next(); pageBaseObject = currentPage.getBaseObject(); offsetInPage = currentPage.getBaseOffset(); @@ -281,7 +294,21 @@ public void remove() { * `lookup()`, the behavior of the returned iterator is undefined. 
*/ public BytesToBytesMapIterator iterator() { - return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc); + return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this); + } + + /** + * Returns a destructive iterator for iterating over the entries of this map. It frees each page + * as it moves onto next one. Notice: it is illegal to call any method on the map after + * `destructiveIterator()` has been called. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * If any other lookups or operations are performed on this map while iterating over it, including + * `lookup()`, the behavior of the returned iterator is undefined. + */ + public BytesToBytesMapIterator destructiveIterator() { + return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this); } /** diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 0e23a64fb74bb..3c5003380162f 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -183,8 +183,7 @@ public void setAndRetrieveAKey() { } } - @Test - public void iteratorTest() throws Exception { + private void iteratorTestBase(boolean destructive) throws Exception { final int size = 4096; BytesToBytesMap map = new BytesToBytesMap( taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES); @@ -216,7 +215,14 @@ public void iteratorTest() throws Exception { } } final java.util.BitSet valuesSeen = new java.util.BitSet(size); - final Iterator iter = map.iterator(); + final Iterator iter; + if (destructive) { + iter = map.destructiveIterator(); + } else { + iter = map.iterator(); + } + int numPages = map.getNumDataPages(); + int countFreedPages = 0; while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); @@ -228,11 +234,22 @@ public void iteratorTest() throws Exception { if (keyLength == 0) { Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); } else { - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long key = PlatformDependent.UNSAFE.getLong( + keyAddress.getBaseObject(), keyAddress.getBaseOffset()); Assert.assertEquals(value, key); } valuesSeen.set((int) value); + if (destructive) { + // The iterator moves onto next page and frees previous page + if (map.getNumDataPages() < numPages) { + numPages = map.getNumDataPages(); + countFreedPages++; + } + } + } + if (destructive) { + // Latest page is not freed by iterator but by map itself + Assert.assertEquals(countFreedPages, numPages - 1); } Assert.assertEquals(size, valuesSeen.cardinality()); } finally { @@ -240,6 +257,16 @@ public void iteratorTest() throws Exception { } } + @Test + public void iteratorTest() throws Exception { + iteratorTestBase(false); + } + + @Test + public void destructiveIteratorTest() throws Exception { + iteratorTestBase(true); + } + @Test public void iteratingOverDataPagesWithWastedSpace() throws Exception { final int NUM_ENTRIES = 1000 * 1000; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 02458030b00e9..efb33530dac86 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -154,14 +154,17 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { } /** - * Returns an iterator over the keys and values in this map. + * Returns an iterator over the keys and values in this map. This uses destructive iterator of + * BytesToBytesMap. So it is illegal to call any other method on this map after `iterator()` has + * been called. * * For efficiency, each call returns the same object. */ public KVIterator iterator() { return new KVIterator() { - private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = map.iterator(); + private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = + map.destructiveIterator(); private final UnsafeRow key = new UnsafeRow(); private final UnsafeRow value = new UnsafeRow(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 6c1cf136d9b81..9a65c9d3a404a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -88,8 +88,11 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); - final int numKeyFields = keySchema.size(); + // We cannot use the destructive iterator here because we are reusing the existing memory + // pages in BytesToBytesMap to hold records during sorting. + // The only new memory we are allocating is the pointer/prefix array. BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); + final int numKeyFields = keySchema.size(); UnsafeRow row = new UnsafeRow(); while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); From 0a078303d08ad2bb92b9a8a6969563d75b512290 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 6 Aug 2015 14:35:30 -0700 Subject: [PATCH 196/340] [SPARK-9556] [SPARK-9619] [SPARK-9624] [STREAMING] Make BlockGenerator more robust and make all BlockGenerators subscribe to rate limit updates In some receivers, instead of using the default `BlockGenerator` in `ReceiverSupervisorImpl`, custom generator with their custom listeners are used for reliability (see [`ReliableKafkaReceiver`](https://github.com/apache/spark/blob/master/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala#L99) and [updated `KinesisReceiver`](https://github.com/apache/spark/pull/7825/files)). These custom generators do not receive rate updates. This PR modifies the code to allow custom `BlockGenerator`s to be created through the `ReceiverSupervisorImpl` so that they can be kept track and rate updates can be applied. In the process, I did some simplification, and de-flaki-fication of some rate controller related tests. In particular. - Renamed `Receiver.executor` to `Receiver.supervisor` (to match `ReceiverSupervisor`) - Made `RateControllerSuite` faster (by increasing batch interval) and less flaky - Changed a few internal API to return the current rate of block generators as Long instead of Option\[Long\] (was inconsistent at places). 
- Updated existing `ReceiverTrackerSuite` to test that custom block generators get rate updates as well. Author: Tathagata Das Closes #7913 from tdas/SPARK-9556 and squashes the following commits: 41d4461 [Tathagata Das] fix scala style eb9fd59 [Tathagata Das] Updated kinesis receiver d24994d [Tathagata Das] Updated BlockGeneratorSuite to use manual clock in BlockGenerator d70608b [Tathagata Das] Updated BlockGenerator with states and proper synchronization f6bd47e [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9556 31da173 [Tathagata Das] Fix bug 12116df [Tathagata Das] Add BlockGeneratorSuite 74bd069 [Tathagata Das] Fix style 989bb5c [Tathagata Das] Made BlockGenerator fail is used after stop, and added better unit tests for it 3ff618c [Tathagata Das] Fix test b40eff8 [Tathagata Das] slight refactoring f0df0f1 [Tathagata Das] Scala style fixes 51759cb [Tathagata Das] Refactored rate controller tests and added the ability to update rate of any custom block generator --- .../org/apache/spark/util/ManualClock.scala | 2 +- .../kafka/ReliableKafkaReceiver.scala | 2 +- .../streaming/kinesis/KinesisReceiver.scala | 2 +- .../streaming/receiver/ActorReceiver.scala | 8 +- .../streaming/receiver/BlockGenerator.scala | 131 ++++++--- .../streaming/receiver/RateLimiter.scala | 3 +- .../spark/streaming/receiver/Receiver.scala | 52 ++-- .../receiver/ReceiverSupervisor.scala | 27 +- .../receiver/ReceiverSupervisorImpl.scala | 33 ++- .../spark/streaming/CheckpointSuite.scala | 16 +- .../spark/streaming/ReceiverSuite.scala | 31 +-- .../receiver/BlockGeneratorSuite.scala | 253 ++++++++++++++++++ .../scheduler/RateControllerSuite.scala | 64 ++--- .../scheduler/ReceiverTrackerSuite.scala | 129 +++++---- 14 files changed, 534 insertions(+), 219 deletions(-) create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala index 1718554061985..e7a65d74a440e 100644 --- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -58,7 +58,7 @@ private[spark] class ManualClock(private var time: Long) extends Clock { */ def waitTillTime(targetTime: Long): Long = synchronized { while (time < targetTime) { - wait(100) + wait(10) } getTimeMillis() } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index 75f0dfc22b9dc..764d170934aa6 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -96,7 +96,7 @@ class ReliableKafkaReceiver[ blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]() // Initialize the block generator for storing Kafka message. 
- blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf) + blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") { logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " + diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index a4baeec0846b4..22324e821ce94 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -136,7 +136,7 @@ private[kinesis] class KinesisReceiver( * The KCL creates and manages the receiving/processing thread pool through Worker.run(). */ override def onStart() { - blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, SparkEnv.get.conf) + blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) workerId = Utils.localHostName() + ":" + UUID.randomUUID() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index cd309788a7717..7ec74016a1c2c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -144,7 +144,7 @@ private[streaming] class ActorReceiver[T: ClassTag]( receiverSupervisorStrategy: SupervisorStrategy ) extends Receiver[T](storageLevel) with Logging { - protected lazy val supervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), + protected lazy val actorSupervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), "Supervisor" + streamId) class Supervisor extends Actor { @@ -191,11 +191,11 @@ private[streaming] class ActorReceiver[T: ClassTag]( } def onStart(): Unit = { - supervisor - logInfo("Supervision tree for receivers initialized at:" + supervisor.path) + actorSupervisor + logInfo("Supervision tree for receivers initialized at:" + actorSupervisor.path) } def onStop(): Unit = { - supervisor ! PoisonPill + actorSupervisor ! PoisonPill } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 92b51ce39234c..794dece370b2c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -21,10 +21,10 @@ import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.SystemClock +import org.apache.spark.util.{Clock, SystemClock} /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { @@ -69,16 +69,35 @@ private[streaming] trait BlockGeneratorListener { * named blocks at regular intervals. 
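Both receiver changes above follow the same shape: the receiver no longer constructs a BlockGenerator itself but asks its supervisor for one, so the generator is registered and receives rate updates. A sketch of that pattern in a custom receiver (illustrative only; class and field names are hypothetical, and `supervisor` is private[streaming], so this only works for receivers inside the org.apache.spark.streaming package tree):

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.storage.{StorageLevel, StreamBlockId}
    import org.apache.spark.streaming.receiver._

    class MyReceiver extends Receiver[String](StorageLevel.MEMORY_ONLY) {
      @volatile private var blockGenerator: BlockGenerator = _

      override def onStart(): Unit = {
        // Ask the supervisor for the generator; it is tracked, started and stopped by the
        // supervisor, and it will receive UpdateRateLimit messages from the driver.
        blockGenerator = supervisor.createBlockGenerator(new BlockGeneratorListener {
          override def onAddData(data: Any, metadata: Any): Unit = {}
          override def onGenerateBlock(blockId: StreamBlockId): Unit = {}
          override def onError(message: String, throwable: Throwable): Unit =
            reportError(message, throwable)
          override def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit =
            store(arrayBuffer.asInstanceOf[ArrayBuffer[String]])
        })
      }

      override def onStop(): Unit = {}
    }
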
This class starts two threads, * one to periodically start a new batch and prepare the previous batch of as a block, * the other to push the blocks into the block manager. + * + * Note: Do not create BlockGenerator instances directly inside receivers. Use + * `ReceiverSupervisor.createBlockGenerator` to create a BlockGenerator and use it. */ private[streaming] class BlockGenerator( listener: BlockGeneratorListener, receiverId: Int, - conf: SparkConf + conf: SparkConf, + clock: Clock = new SystemClock() ) extends RateLimiter(conf) with Logging { private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any]) - private val clock = new SystemClock() + /** + * The BlockGenerator can be in 5 possible states, in the order as follows. + * - Initialized: Nothing has been started + * - Active: start() has been called, and it is generating blocks on added data. + * - StoppedAddingData: stop() has been called, the adding of data has been stopped, + * but blocks are still being generated and pushed. + * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but + * they are still being pushed. + * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. + */ + private object GeneratorState extends Enumeration { + type GeneratorState = Value + val Initialized, Active, StoppedAddingData, StoppedGeneratingBlocks, StoppedAll = Value + } + import GeneratorState._ + private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms") require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value") @@ -89,59 +108,100 @@ private[streaming] class BlockGenerator( private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } @volatile private var currentBuffer = new ArrayBuffer[Any] - @volatile private var stopped = false + @volatile private var state = Initialized /** Start block generating and pushing threads. */ - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Started BlockGenerator") + def start(): Unit = synchronized { + if (state == Initialized) { + state = Active + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Started BlockGenerator") + } else { + throw new SparkException( + s"Cannot start BlockGenerator as its not in the Initialized state [state = $state]") + } } - /** Stop all threads. */ - def stop() { + /** + * Stop everything in the right order such that all the data added is pushed out correctly. + * - First, stop adding data to the current buffer. + * - Second, stop generating blocks. + * - Finally, wait for queue of to-be-pushed blocks to be drained. + */ + def stop(): Unit = { + // Set the state to stop adding data + synchronized { + if (state == Active) { + state = StoppedAddingData + } else { + logWarning(s"Cannot stop BlockGenerator as its not in the Active state [state = $state]") + return + } + } + + // Stop generating blocks and set the state for block pushing thread to start draining the queue logInfo("Stopping BlockGenerator") blockIntervalTimer.stop(interruptTimer = false) - stopped = true - logInfo("Waiting for block pushing thread") + synchronized { state = StoppedGeneratingBlocks } + + // Wait for the queue to drain and mark generated as stopped + logInfo("Waiting for block pushing thread to terminate") blockPushingThread.join() + synchronized { state = StoppedAll } logInfo("Stopped BlockGenerator") } /** - * Push a single data item into the buffer. 
All received data items - * will be periodically pushed into BlockManager. + * Push a single data item into the buffer. */ - def addData (data: Any): Unit = synchronized { - waitToPush() - currentBuffer += data + def addData(data: Any): Unit = synchronized { + if (state == Active) { + waitToPush() + currentBuffer += data + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } /** * Push a single data item into the buffer. After buffering the data, the - * `BlockGeneratorListener.onAddData` callback will be called. All received data items - * will be periodically pushed into BlockManager. + * `BlockGeneratorListener.onAddData` callback will be called. */ def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { - waitToPush() - currentBuffer += data - listener.onAddData(data, metadata) + if (state == Active) { + waitToPush() + currentBuffer += data + listener.onAddData(data, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } /** * Push multiple data items into the buffer. After buffering the data, the - * `BlockGeneratorListener.onAddData` callback will be called. All received data items - * will be periodically pushed into BlockManager. Note that all the data items is guaranteed - * to be present in a single block. + * `BlockGeneratorListener.onAddData` callback will be called. Note that all the data items + * are atomically added to the buffer, and are hence guaranteed to be present in a single block. */ def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { - dataIterator.foreach { data => - waitToPush() - currentBuffer += data + if (state == Active) { + dataIterator.foreach { data => + waitToPush() + currentBuffer += data + } + listener.onAddData(dataIterator, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") } - listener.onAddData(dataIterator, metadata) } + def isActive(): Boolean = state == Active + + def isStopped(): Boolean = state == StoppedAll + /** Change the buffer to which single records are added to. */ private def updateCurrentBuffer(time: Long): Unit = synchronized { try { @@ -165,18 +225,21 @@ private[streaming] class BlockGenerator( /** Keep pushing blocks to the BlockManager. */ private def keepPushingBlocks() { logInfo("Started block pushing thread") + + def isGeneratingBlocks = synchronized { state == Active || state == StoppedAddingData } try { - while (!stopped) { - Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { + while (isGeneratingBlocks) { + Option(blocksForPushing.poll(10, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => } } - // Push out the blocks that are still left + + // At this point, state is StoppedGeneratingBlock. So drain the queue of to-be-pushed blocks. 
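Taken together with the state enum earlier in this file, the generator now enforces an explicit lifecycle. A rough sketch of the legal call sequence, as the supervisor (not a receiver) would drive it; the names below are placeholders:

    val generator = new BlockGenerator(listener, streamId, conf)  // state: Initialized
    generator.start()                                             // state: Active
    generator.addData(record)                                     // allowed only while Active
    generator.stop()     // stops adding, stops generating, drains the queue, then StoppedAll
    generator.stop()     // a second stop() only logs a warning
    // generator.addData(record) or generator.start() here would throw SparkException
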
logInfo("Pushing out the last " + blocksForPushing.size() + " blocks") while (!blocksForPushing.isEmpty) { - logDebug("Getting block ") val block = blocksForPushing.take() + logDebug(s"Pushing block $block") pushBlock(block) logInfo("Blocks left to push " + blocksForPushing.size()) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index f663def4c0511..bca1fbc8fda2f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -45,8 +45,7 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { /** * Return the current rate limit. If no limit has been set so far, it returns {{{Long.MaxValue}}}. */ - def getCurrentLimit: Long = - rateLimiter.getRate.toLong + def getCurrentLimit: Long = rateLimiter.getRate.toLong /** * Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 7504fa44d9fae..554aae0117b24 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -116,12 +116,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * being pushed into Spark's memory. */ def store(dataItem: T) { - executor.pushSingle(dataItem) + supervisor.pushSingle(dataItem) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ def store(dataBuffer: ArrayBuffer[T]) { - executor.pushArrayBuffer(dataBuffer, None, None) + supervisor.pushArrayBuffer(dataBuffer, None, None) } /** @@ -130,12 +130,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataBuffer: ArrayBuffer[T], metadata: Any) { - executor.pushArrayBuffer(dataBuffer, Some(metadata), None) + supervisor.pushArrayBuffer(dataBuffer, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: Iterator[T]) { - executor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator, None, None) } /** @@ -144,12 +144,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: java.util.Iterator[T], metadata: Any) { - executor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: java.util.Iterator[T]) { - executor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator, None, None) } /** @@ -158,7 +158,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: Iterator[T], metadata: Any) { - executor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator, Some(metadata), None) } /** @@ -167,7 +167,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * that Spark is configured to use. 
*/ def store(bytes: ByteBuffer) { - executor.pushBytes(bytes, None, None) + supervisor.pushBytes(bytes, None, None) } /** @@ -176,12 +176,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(bytes: ByteBuffer, metadata: Any) { - executor.pushBytes(bytes, Some(metadata), None) + supervisor.pushBytes(bytes, Some(metadata), None) } /** Report exceptions in receiving data. */ def reportError(message: String, throwable: Throwable) { - executor.reportError(message, throwable) + supervisor.reportError(message, throwable) } /** @@ -193,7 +193,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * The `message` will be reported to the driver. */ def restart(message: String) { - executor.restartReceiver(message) + supervisor.restartReceiver(message) } /** @@ -205,7 +205,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * The `message` and `exception` will be reported to the driver. */ def restart(message: String, error: Throwable) { - executor.restartReceiver(message, Some(error)) + supervisor.restartReceiver(message, Some(error)) } /** @@ -215,22 +215,22 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * in a background thread. */ def restart(message: String, error: Throwable, millisecond: Int) { - executor.restartReceiver(message, Some(error), millisecond) + supervisor.restartReceiver(message, Some(error), millisecond) } /** Stop the receiver completely. */ def stop(message: String) { - executor.stop(message, None) + supervisor.stop(message, None) } /** Stop the receiver completely due to an exception */ def stop(message: String, error: Throwable) { - executor.stop(message, Some(error)) + supervisor.stop(message, Some(error)) } /** Check if the receiver has started or not. */ def isStarted(): Boolean = { - executor.isReceiverStarted() + supervisor.isReceiverStarted() } /** @@ -238,7 +238,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * the receiving of data should be stopped. */ def isStopped(): Boolean = { - executor.isReceiverStopped() + supervisor.isReceiverStopped() } /** @@ -257,7 +257,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable private var id: Int = -1 /** Handler object that runs the receiver. This is instantiated lazily in the worker. */ - private[streaming] var executor_ : ReceiverSupervisor = null + @transient private var _supervisor : ReceiverSupervisor = null /** Set the ID of the DStream that this receiver is associated with. */ private[streaming] def setReceiverId(id_ : Int) { @@ -265,15 +265,17 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable } /** Attach Network Receiver executor to this receiver. */ - private[streaming] def attachExecutor(exec: ReceiverSupervisor) { - assert(executor_ == null) - executor_ = exec + private[streaming] def attachSupervisor(exec: ReceiverSupervisor) { + assert(_supervisor == null) + _supervisor = exec } - /** Get the attached executor. */ - private def executor: ReceiverSupervisor = { - assert(executor_ != null, "Executor has not been attached to this receiver") - executor_ + /** Get the attached supervisor. */ + private[streaming] def supervisor: ReceiverSupervisor = { + assert(_supervisor != null, + "A ReceiverSupervisor have not been attached to the receiver yet. 
Maybe you are starting " + + "some computation in the receiver before the Receiver.onStart() has been called.") + _supervisor } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index e98017a63756e..158d1ba2f183a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -44,8 +44,8 @@ private[streaming] abstract class ReceiverSupervisor( } import ReceiverState._ - // Attach the executor to the receiver - receiver.attachExecutor(this) + // Attach the supervisor to the receiver + receiver.attachSupervisor(this) private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("receiver-supervisor-future", 128)) @@ -60,7 +60,7 @@ private[streaming] abstract class ReceiverSupervisor( private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000) /** The current maximum rate limit for this receiver. */ - private[streaming] def getCurrentRateLimit: Option[Long] = None + private[streaming] def getCurrentRateLimit: Long = Long.MaxValue /** Exception associated with the stopping of the receiver */ @volatile protected var stoppingError: Throwable = null @@ -92,13 +92,30 @@ private[streaming] abstract class ReceiverSupervisor( optionalBlockId: Option[StreamBlockId] ) + /** + * Create a custom [[BlockGenerator]] that the receiver implementation can directly control + * using their provided [[BlockGeneratorListener]]. + * + * Note: Do not explicitly start or stop the `BlockGenerator`, the `ReceiverSupervisorImpl` + * will take care of it. + */ + def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator + /** Report errors. */ def reportError(message: String, throwable: Throwable) - /** Called when supervisor is started */ + /** + * Called when supervisor is started. + * Note that this must be called before the receiver.onStart() is called to ensure + * things like [[BlockGenerator]]s are started before the receiver starts sending data. + */ protected def onStart() { } - /** Called when supervisor is stopped */ + /** + * Called when supervisor is stopped. + * Note that this must be called after the receiver.onStop() is called to ensure + * things like [[BlockGenerator]]s are cleaned up after the receiver stops sending data. + */ protected def onStop(message: String, error: Option[Throwable]) { } /** Called when receiver is started. 
Return true if the driver accepts us */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 0d802f83549af..59ef58d232ee7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables @@ -81,15 +82,20 @@ private[streaming] class ReceiverSupervisorImpl( cleanupOldBlocks(threshTime) case UpdateRateLimit(eps) => logInfo(s"Received a new rate limit: $eps.") - blockGenerator.updateRate(eps) + registeredBlockGenerators.foreach { bg => + bg.updateRate(eps) + } } }) /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) + private val registeredBlockGenerators = new mutable.ArrayBuffer[BlockGenerator] + with mutable.SynchronizedBuffer[BlockGenerator] + /** Divides received data records into data blocks for pushing in BlockManager. */ - private val blockGenerator = new BlockGenerator(new BlockGeneratorListener { + private val defaultBlockGeneratorListener = new BlockGeneratorListener { def onAddData(data: Any, metadata: Any): Unit = { } def onGenerateBlock(blockId: StreamBlockId): Unit = { } @@ -101,14 +107,15 @@ private[streaming] class ReceiverSupervisorImpl( def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) { pushArrayBuffer(arrayBuffer, None, Some(blockId)) } - }, streamId, env.conf) + } + private val defaultBlockGenerator = createBlockGenerator(defaultBlockGeneratorListener) - override private[streaming] def getCurrentRateLimit: Option[Long] = - Some(blockGenerator.getCurrentLimit) + /** Get the current rate limit of the default block generator */ + override private[streaming] def getCurrentRateLimit: Long = defaultBlockGenerator.getCurrentLimit /** Push a single record of received data into block generator. */ def pushSingle(data: Any) { - blockGenerator.addData(data) + defaultBlockGenerator.addData(data) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. 
*/ @@ -162,11 +169,11 @@ private[streaming] class ReceiverSupervisorImpl( } override protected def onStart() { - blockGenerator.start() + registeredBlockGenerators.foreach { _.start() } } override protected def onStop(message: String, error: Option[Throwable]) { - blockGenerator.stop() + registeredBlockGenerators.foreach { _.stop() } env.rpcEnv.stop(endpoint) } @@ -183,6 +190,16 @@ private[streaming] class ReceiverSupervisorImpl( logInfo("Stopped receiver " + streamId) } + override def createBlockGenerator( + blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { + // Cleanup BlockGenerators that have already been stopped + registeredBlockGenerators --= registeredBlockGenerators.filter{ _.isStopped() } + + val newBlockGenerator = new BlockGenerator(blockGeneratorListener, streamId, env.conf) + registeredBlockGenerators += newBlockGenerator + newBlockGenerator + } + /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 67c2d900940ab..1bba7a143edf2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming import java.io.File -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag import com.google.common.base.Charsets @@ -33,7 +33,7 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} -import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver} +import org.apache.spark.streaming.scheduler.{ConstantEstimator, RateTestInputDStream, RateTestReceiver} import org.apache.spark.util.{Clock, ManualClock, Utils} /** @@ -397,26 +397,24 @@ class CheckpointSuite extends TestSuiteBase { ssc = new StreamingContext(conf, batchDuration) ssc.checkpoint(checkpointDir) - val dstream = new RateLimitInputDStream(ssc) { + val dstream = new RateTestInputDStream(ssc) { override val rateController = - Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + Some(new ReceiverRateController(id, new ConstantEstimator(200))) } - SingletonTestRateReceiver.reset() val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2)) output.register() runStreams(ssc, 5, 5) - SingletonTestRateReceiver.reset() ssc = new StreamingContext(checkpointDir) ssc.start() val outputNew = advanceTimeWithRealDelay(ssc, 2) - eventually(timeout(5.seconds)) { - assert(dstream.getCurrentRateLimit === Some(200)) + eventually(timeout(10.seconds)) { + assert(RateTestReceiver.getActive().nonEmpty) + assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === 200) } ssc.stop() - ssc = null } // This tests whether file input stream remembers what files were seen before diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 13b4d17c86183..01279b34f73dc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -129,32 +129,6 @@ class 
ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } } - test("block generator") { - val blockGeneratorListener = new FakeBlockGeneratorListener - val blockIntervalMs = 200 - val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") - val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) - val expectedBlocks = 5 - val waitTime = expectedBlocks * blockIntervalMs + (blockIntervalMs / 2) - val generatedData = new ArrayBuffer[Int] - - // Generate blocks - val startTime = System.currentTimeMillis() - blockGenerator.start() - var count = 0 - while(System.currentTimeMillis - startTime < waitTime) { - blockGenerator.addData(count) - generatedData += count - count += 1 - Thread.sleep(10) - } - blockGenerator.stop() - - val recordedData = blockGeneratorListener.arrayBuffers.flatten - assert(blockGeneratorListener.arrayBuffers.size > 0) - assert(recordedData.toSet === generatedData.toSet) - } - ignore("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener val blockIntervalMs = 100 @@ -348,6 +322,11 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } override protected def onReceiverStart(): Boolean = true + + override def createBlockGenerator( + blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { + null + } } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala new file mode 100644 index 0000000000000..a38cc603f2190 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.receiver + +import scala.collection.mutable + +import org.scalatest.BeforeAndAfter +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.util.ManualClock +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} + +class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { + + private val blockIntervalMs = 10 + private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") + @volatile private var blockGenerator: BlockGenerator = null + + after { + if (blockGenerator != null) { + blockGenerator.stop() + } + } + + test("block generation and data callbacks") { + val listener = new TestBlockGeneratorListener + val clock = new ManualClock() + + require(blockIntervalMs > 5) + require(listener.onAddDataCalled === false) + require(listener.onGenerateBlockCalled === false) + require(listener.onPushBlockCalled === false) + + // Verify that creating the generator does not start it + blockGenerator = new BlockGenerator(listener, 0, conf, clock) + assert(blockGenerator.isActive() === false, "block generator active before start()") + assert(blockGenerator.isStopped() === false, "block generator stopped before start()") + assert(listener.onAddDataCalled === false) + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + + // Verify start marks the generator active, but does not call the callbacks + blockGenerator.start() + assert(blockGenerator.isActive() === true, "block generator active after start()") + assert(blockGenerator.isStopped() === false, "block generator stopped after start()") + withClue("callbacks called before adding data") { + assert(listener.onAddDataCalled === false) + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + } + + // Verify whether addData() adds data that is present in generated blocks + val data1 = 1 to 10 + data1.foreach { blockGenerator.addData _ } + withClue("callbacks called on adding data without metadata and without block generation") { + assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + } + clock.advance(blockIntervalMs) // advance clock to generate blocks + withClue("blocks not generated or pushed") { + eventually(timeout(1 second)) { + assert(listener.onGenerateBlockCalled === true) + assert(listener.onPushBlockCalled === true) + } + } + listener.pushedData should contain theSameElementsInOrderAs (data1) + assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() + + // Verify addDataWithCallback() add data+metadata and and callbacks are called correctly + val data2 = 11 to 20 + val metadata2 = data2.map { _.toString } + data2.zip(metadata2).foreach { case (d, m) => blockGenerator.addDataWithCallback(d, m) } + assert(listener.onAddDataCalled === true) + listener.addedData should contain theSameElementsInOrderAs (data2) + listener.addedMetadata should contain theSameElementsInOrderAs (metadata2) + clock.advance(blockIntervalMs) // advance clock to generate blocks + eventually(timeout(1 second)) { + listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2) + } + + // Verify addMultipleDataWithCallback() 
add data+metadata and and callbacks are called correctly + val data3 = 21 to 30 + val metadata3 = "metadata" + blockGenerator.addMultipleDataWithCallback(data3.iterator, metadata3) + listener.addedMetadata should contain theSameElementsInOrderAs (metadata2 :+ metadata3) + clock.advance(blockIntervalMs) // advance clock to generate blocks + eventually(timeout(1 second)) { + listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2 ++ data3) + } + + // Stop the block generator by starting the stop on a different thread and + // then advancing the manual clock for the stopping to proceed. + val thread = stopBlockGenerator(blockGenerator) + eventually(timeout(1 second), interval(10 milliseconds)) { + clock.advance(blockIntervalMs) + assert(blockGenerator.isStopped() === true) + } + thread.join() + + // Verify that the generator cannot be used any more + intercept[SparkException] { + blockGenerator.addData(1) + } + intercept[SparkException] { + blockGenerator.addDataWithCallback(1, 1) + } + intercept[SparkException] { + blockGenerator.addMultipleDataWithCallback(Iterator(1), 1) + } + intercept[SparkException] { + blockGenerator.start() + } + blockGenerator.stop() // Calling stop again should be fine + } + + test("stop ensures correct shutdown") { + val listener = new TestBlockGeneratorListener + val clock = new ManualClock() + blockGenerator = new BlockGenerator(listener, 0, conf, clock) + require(listener.onGenerateBlockCalled === false) + blockGenerator.start() + assert(blockGenerator.isActive() === true, "block generator") + assert(blockGenerator.isStopped() === false) + + val data = 1 to 1000 + data.foreach { blockGenerator.addData _ } + + // Verify that stop() shutdowns everything in the right order + // - First, stop receiving new data + // - Second, wait for final block with all buffered data to be generated + // - Finally, wait for all blocks to be pushed + clock.advance(1) // to make sure that the timer for another interval to complete + val thread = stopBlockGenerator(blockGenerator) + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(blockGenerator.isActive() === false) + } + assert(blockGenerator.isStopped() === false) + + // Verify that data cannot be added + intercept[SparkException] { + blockGenerator.addData(1) + } + intercept[SparkException] { + blockGenerator.addDataWithCallback(1, null) + } + intercept[SparkException] { + blockGenerator.addMultipleDataWithCallback(Iterator(1), null) + } + + // Verify that stop() stays blocked until another block containing all the data is generated + // This intercept always succeeds, as the body either will either throw a timeout exception + // (expected as stop() should never complete) or a SparkException (unexpected as stop() + // completed and thread terminated). 
+ val exception = intercept[Exception] { + failAfter(200 milliseconds) { + thread.join() + throw new SparkException( + "BlockGenerator.stop() completed before generating timer was stopped") + } + } + exception should not be a [SparkException] + + + // Verify that the final data is present in the final generated block and + // pushed before complete stop + assert(blockGenerator.isStopped() === false) // generator has not stopped yet + clock.advance(blockIntervalMs) // force block generation + failAfter(1 second) { + thread.join() + } + assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped + assert(listener.pushedData === data, "All data not pushed by stop()") + } + + test("block push errors are reported") { + val listener = new TestBlockGeneratorListener { + @volatile var errorReported = false + override def onPushBlock( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + throw new SparkException("test") + } + override def onError(message: String, throwable: Throwable): Unit = { + errorReported = true + } + } + blockGenerator = new BlockGenerator(listener, 0, conf) + blockGenerator.start() + assert(listener.errorReported === false) + blockGenerator.addData(1) + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(listener.errorReported === true) + } + blockGenerator.stop() + } + + /** + * Helper method to stop the block generator with manual clock in a different thread, + * so that the main thread can advance the clock that allows the stopping to proceed. + */ + private def stopBlockGenerator(blockGenerator: BlockGenerator): Thread = { + val thread = new Thread() { + override def run(): Unit = { + blockGenerator.stop() + } + } + thread.start() + thread + } + + /** A listener for BlockGenerator that records the data in the callbacks */ + private class TestBlockGeneratorListener extends BlockGeneratorListener { + val pushedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val addedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val addedMetadata = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + @volatile var onGenerateBlockCalled = false + @volatile var onAddDataCalled = false + @volatile var onPushBlockCalled = false + + override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + pushedData ++= arrayBuffer + onPushBlockCalled = true + } + override def onError(message: String, throwable: Throwable): Unit = {} + override def onGenerateBlock(blockId: StreamBlockId): Unit = { + onGenerateBlockCalled = true + } + override def onAddData(data: Any, metadata: Any): Unit = { + addedData += data + addedMetadata += metadata + onAddDataCalled = true + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala index 921da773f6c11..1eb52b7029a21 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -18,10 +18,7 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable -import scala.reflect.ClassTag -import scala.util.control.NonFatal -import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -32,72 +29,63 @@ class RateControllerSuite extends TestSuiteBase { 
override def useManualClock: Boolean = false - test("rate controller publishes updates") { + override def batchDuration: Duration = Milliseconds(50) + + test("RateController - rate controller publishes updates after batches complete") { val ssc = new StreamingContext(conf, batchDuration) withStreamingContext(ssc) { ssc => - val dstream = new RateLimitInputDStream(ssc) + val dstream = new RateTestInputDStream(ssc) dstream.register() ssc.start() eventually(timeout(10.seconds)) { - assert(dstream.publishCalls > 0) + assert(dstream.publishedRates > 0) } } } - test("publish rates reach receivers") { + test("ReceiverRateController - published rates reach receivers") { val ssc = new StreamingContext(conf, batchDuration) withStreamingContext(ssc) { ssc => - val dstream = new RateLimitInputDStream(ssc) { + val estimator = new ConstantEstimator(100) + val dstream = new RateTestInputDStream(ssc) { override val rateController = - Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + Some(new ReceiverRateController(id, estimator)) } dstream.register() - SingletonTestRateReceiver.reset() ssc.start() - eventually(timeout(10.seconds)) { - assert(dstream.getCurrentRateLimit === Some(200)) + // Wait for receiver to start + eventually(timeout(5.seconds)) { + RateTestReceiver.getActive().nonEmpty } - } - } - test("multiple publish rates reach receivers") { - val ssc = new StreamingContext(conf, batchDuration) - withStreamingContext(ssc) { ssc => - val rates = Seq(100L, 200L, 300L) - - val dstream = new RateLimitInputDStream(ssc) { - override val rateController = - Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*))) + // Update rate in the estimator and verify whether the rate was published to the receiver + def updateRateAndVerify(rate: Long): Unit = { + estimator.updateRate(rate) + eventually(timeout(5.seconds)) { + assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === rate) + } } - SingletonTestRateReceiver.reset() - dstream.register() - - val observedRates = mutable.HashSet.empty[Long] - ssc.start() - eventually(timeout(20.seconds)) { - dstream.getCurrentRateLimit.foreach(observedRates += _) - // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver - observedRates should contain theSameElementsAs (rates :+ Long.MaxValue) + // Verify multiple rate update + Seq(100, 200, 300).foreach { rate => + updateRateAndVerify(rate) } } } } -private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator { - private var idx: Int = 0 +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { - private def nextRate(): Double = { - val rate = rates(idx) - idx = (idx + 1) % rates.size - rate + def updateRate(newRate: Long): Unit = { + rate = newRate } def compute( time: Long, elements: Long, processingDelay: Long, - schedulingDelay: Long): Option[Double] = Some(nextRate()) + schedulingDelay: Long): Option[Double] = Some(rate) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index afad5f16dbc71..dd292ba4dd949 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -17,48 +17,43 @@ package org.apache.spark.streaming.scheduler +import scala.collection.mutable.ArrayBuffer + 
import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkConf +import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ -import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.receiver._ /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { - val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") - - test("Receiver tracker - propagates rate limit") { - withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc => - object ReceiverStartedWaiter extends StreamingListener { - @volatile - var started = false - - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { - started = true - } - } - ssc.addStreamingListener(ReceiverStartedWaiter) + test("send rate update to receivers") { + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => ssc.scheduler.listenerBus.start(ssc.sc) - SingletonTestRateReceiver.reset() val newRateLimit = 100L - val inputDStream = new RateLimitInputDStream(ssc) + val inputDStream = new RateTestInputDStream(ssc) val tracker = new ReceiverTracker(ssc) tracker.start() try { // we wait until the Receiver has registered with the tracker, // otherwise our rate update is lost eventually(timeout(5 seconds)) { - assert(ReceiverStartedWaiter.started) + assert(RateTestReceiver.getActive().nonEmpty) } + + + // Verify that the rate of the block generator in the receiver get updated + val activeReceiver = RateTestReceiver.getActive().get tracker.sendRateUpdate(inputDStream.id, newRateLimit) - // this is an async message, we need to wait a bit for it to be processed - eventually(timeout(3 seconds)) { - assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + eventually(timeout(5 seconds)) { + assert(activeReceiver.getDefaultBlockGeneratorRateLimit() === newRateLimit, + "default block generator did not receive rate update") + assert(activeReceiver.getCustomBlockGeneratorRateLimit() === newRateLimit, + "other block generator did not receive rate update") } } finally { tracker.stop(false) @@ -67,69 +62,73 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } -/** - * An input DStream with a hard-coded receiver that gives access to internals for testing. - * - * @note Make sure to call {{{SingletonDummyReceiver.reset()}}} before using this in a test, - * or otherwise you may get {{{NotSerializableException}}} when trying to serialize - * the receiver. - * @see [[[SingletonDummyReceiver]]]. 
- */ -private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext) +/** An input DStream with for testing rate controlling */ +private[streaming] class RateTestInputDStream(@transient ssc_ : StreamingContext) extends ReceiverInputDStream[Int](ssc_) { - override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver - - def getCurrentRateLimit: Option[Long] = { - invokeExecutorMethod.getCurrentRateLimit - } + override def getReceiver(): Receiver[Int] = new RateTestReceiver(id) @volatile - var publishCalls = 0 + var publishedRates = 0 override val rateController: Option[RateController] = { - Some(new RateController(id, new ConstantEstimator(100.0)) { + Some(new RateController(id, new ConstantEstimator(100)) { override def publish(rate: Long): Unit = { - publishCalls += 1 + publishedRates += 1 } }) } +} - private def invokeExecutorMethod: ReceiverSupervisor = { - val c = classOf[Receiver[_]] - val ex = c.getDeclaredMethod("executor") - ex.setAccessible(true) - ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor] +/** A receiver implementation for testing rate controlling */ +private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + + private lazy val customBlockGenerator = supervisor.createBlockGenerator( + new BlockGeneratorListener { + override def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit = {} + override def onError(message: String, throwable: Throwable): Unit = {} + override def onGenerateBlock(blockId: StreamBlockId): Unit = {} + override def onAddData(data: Any, metadata: Any): Unit = {} + } + ) + + setReceiverId(receiverId) + + override def onStart(): Unit = { + customBlockGenerator + RateTestReceiver.registerReceiver(this) } -} -/** - * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when - * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being - * serialized when receivers are installed on executors. - * - * @note It's necessary to be a top-level object, or else serialization would create another - * one on the executor side and we won't be able to read its rate limit. - */ -private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) { + override def onStop(): Unit = { + RateTestReceiver.deregisterReceiver() + } + + override def preferredLocation: Option[String] = host - /** Reset the object to be usable in another test. */ - def reset(): Unit = { - executor_ = null + def getDefaultBlockGeneratorRateLimit(): Long = { + supervisor.getCurrentRateLimit + } + + def getCustomBlockGeneratorRateLimit(): Long = { + customBlockGenerator.getCurrentLimit } } /** - * Dummy receiver implementation + * A helper object to RateTestReceiver that give access to the currently active RateTestReceiver + * instance. 
*/ -private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None) - extends Receiver[Int](StorageLevel.MEMORY_ONLY) { +private[streaming] object RateTestReceiver { + @volatile private var activeReceiver: RateTestReceiver = null - setReceiverId(receiverId) - - override def onStart(): Unit = {} + def registerReceiver(receiver: RateTestReceiver): Unit = { + activeReceiver = receiver + } - override def onStop(): Unit = {} + def deregisterReceiver(): Unit = { + activeReceiver = null + } - override def preferredLocation: Option[String] = host + def getActive(): Option[RateTestReceiver] = Option(activeReceiver) } From 1723e34893f9b087727ea0e5c8b335645f42c295 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Thu, 6 Aug 2015 14:37:25 -0700 Subject: [PATCH 197/340] =?UTF-8?q?[DOCS]=20[STREAMING]=20make=20the=20exi?= =?UTF-8?q?sting=20parameter=20docs=20for=20OffsetRange=20ac=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …tually visible Author: cody koeninger Closes #7995 from koeninger/doc-fixes and squashes the following commits: 87af9ea [cody koeninger] [Docs][Streaming] make the existing parameter docs for OffsetRange actually visible --- .../org/apache/spark/streaming/kafka/OffsetRange.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index f326e7f1f6f8d..2f8981d4898bd 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -42,16 +42,16 @@ trait HasOffsetRanges { * :: Experimental :: * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class * can be created with `OffsetRange.create()`. + * @param topic Kafka topic name + * @param partition Kafka partition id + * @param fromOffset Inclusive starting offset + * @param untilOffset Exclusive ending offset */ @Experimental final class OffsetRange private( - /** Kafka topic name */ val topic: String, - /** Kafka partition id */ val partition: Int, - /** inclusive starting offset */ val fromOffset: Long, - /** exclusive ending offset */ val untilOffset: Long) extends Serializable { import OffsetRange.OffsetRangeTuple From 346209097e88fe79015359e40b49c32cc0bdc439 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 6 Aug 2015 14:39:36 -0700 Subject: [PATCH 198/340] [SPARK-9639] [STREAMING] Fix a potential NPE in Streaming JobScheduler Because `JobScheduler.stop(false)` may set `eventLoop` to null when `JobHandler` is running, then it's possible that when `post` is called, `eventLoop` happens to null. This PR fixed this bug and also set threads in `jobExecutor` to `daemon`. 
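The fix that follows uses the standard idiom for a volatile field that another thread may null out: read it into a local exactly once, then act only on the local. A generic sketch of the idiom with purely hypothetical names (the actual JobScheduler change is in the diff below):

    class Owner {
      // Another thread (for example a stop() call) may set this back to null at any time.
      @volatile private var eventLoop: String => Unit = null

      def post(event: String): Unit = {
        val loop = eventLoop        // read the volatile field once
        if (loop != null) {
          loop(event)               // safe: `loop` cannot become null between the check and the call
        }                           // else: the owner has already been stopped, drop the event
      }
    }
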
Author: zsxwing Closes #7960 from zsxwing/fix-npe and squashes the following commits: b0864c4 [zsxwing] Fix a potential NPE in Streaming JobScheduler --- .../streaming/scheduler/JobScheduler.scala | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 7e735562dca33..6d4cdc4aa6b10 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.scheduler -import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConversions._ import scala.util.{Failure, Success} @@ -25,7 +25,7 @@ import scala.util.{Failure, Success} import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ -import org.apache.spark.util.EventLoop +import org.apache.spark.util.{EventLoop, ThreadUtils} private[scheduler] sealed trait JobSchedulerEvent @@ -44,7 +44,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // https://gist.github.com/AlainODea/1375759b8720a3f9f094 private val jobSets: java.util.Map[Time, JobSet] = new ConcurrentHashMap[Time, JobSet] private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) - private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobExecutor = + ThreadUtils.newDaemonFixedThreadPool(numConcurrentJobs, "streaming-job-executor") private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() @@ -193,14 +194,25 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) try { - eventLoop.post(JobStarted(job)) - // Disable checks for existing output directories in jobs launched by the streaming - // scheduler, since we may need to write output to an existing directory during checkpoint - // recovery; see SPARK-4835 for more details. - PairRDDFunctions.disableOutputSpecValidation.withValue(true) { - job.run() + // We need to assign `eventLoop` to a temp variable. Otherwise, because + // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then + // it's possible that when `post` is called, `eventLoop` happens to null. + var _eventLoop = eventLoop + if (_eventLoop != null) { + _eventLoop.post(JobStarted(job)) + // Disable checks for existing output directories in jobs launched by the streaming + // scheduler, since we may need to write output to an existing directory during checkpoint + // recovery; see SPARK-4835 for more details. + PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + job.run() + } + _eventLoop = eventLoop + if (_eventLoop != null) { + _eventLoop.post(JobCompleted(job)) + } + } else { + // JobScheduler has been stopped. 
} - eventLoop.post(JobCompleted(job)) } finally { ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null) ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null) From 3504bf3aa9f7b75c0985f04ce2944833d8c5b5bd Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 6 Aug 2015 15:04:44 -0700 Subject: [PATCH 199/340] [SPARK-9630] [SQL] Clean up new aggregate operators (SPARK-9240 follow up) This is the followup of https://github.com/apache/spark/pull/7813. It renames `HybridUnsafeAggregationIterator` to `TungstenAggregationIterator` and makes it only work with `UnsafeRow`. Also, I add a `TungstenAggregate` that uses `TungstenAggregationIterator` and make `SortBasedAggregate` (renamed from `SortBasedAggregate`) only works with `SafeRow`. Author: Yin Huai Closes #7954 from yhuai/agg-followUp and squashes the following commits: 4d2f4fc [Yin Huai] Add comments and free map. 0d7ddb9 [Yin Huai] Add TungstenAggregationQueryWithControlledFallbackSuite to test fall back process. 91d69c2 [Yin Huai] Rename UnsafeHybridAggregationIterator to TungstenAggregateIteraotr and make it only work with UnsafeRow. --- .../expressions/aggregate/functions.scala | 14 +- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../sql/execution/UnsafeRowSerializer.scala | 20 +- .../sql/execution/aggregate/Aggregate.scala | 182 ----- .../aggregate/SortBasedAggregate.scala | 103 +++ .../SortBasedAggregationIterator.scala | 26 - .../aggregate/TungstenAggregate.scala | 102 +++ .../TungstenAggregationIterator.scala | 667 ++++++++++++++++++ .../UnsafeHybridAggregationIterator.scala | 372 ---------- .../spark/sql/execution/aggregate/utils.scala | 260 +++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../execution/AggregationQuerySuite.scala | 104 ++- 12 files changed, 1192 insertions(+), 663 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 88fb516e64aaf..a73024d6adba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -31,8 +31,11 @@ case class Average(child: Expression) extends AlgebraicAggregate { override def dataType: DataType = resultType // Expected input data type. - // TODO: Once we remove the old code path, we can use our analyzer to cast NullType - // to the default data type of the NumericType. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select avg(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. 
Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) private val resultType = child.dataType match { @@ -256,12 +259,19 @@ case class Sum(child: Expression) extends AlgebraicAggregate { override def dataType: DataType = resultType // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select sum(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) private val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) + // TODO: Remove this line once we remove the NullType from inputTypes. + case NullType => IntegerType case _ => child.dataType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index a730ffbb217c0..c5aaebe673225 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -191,8 +191,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // aggregate function to the corresponding attribute of the function. val aggregateFunctionMap = aggregateExpressions.map { agg => val aggregateFunction = agg.aggregateFunction + val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute (aggregateFunction, agg.isDistinct) -> - Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction -> attribtue) }.toMap val (functionsWithDistinct, functionsWithoutDistinct) = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 16498da080c88..39f8f992a9f00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream} +import java.io._ import java.nio.ByteBuffer import scala.reflect.ClassTag @@ -58,11 +58,26 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst */ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) + // When `out` is backed by ChainedBufferOutputStream, we will get an + // UnsupportedOperationException when we call dOut.writeInt because it internally calls + // ChainedBufferOutputStream's write(b: Int), which is not supported. + // To workaround this issue, we create an array for sorting the int value. + // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and + // run SparkSqlSerializer2SortMergeShuffleSuite. 
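// An illustrative, stand-alone check of the idea behind this workaround (all names in
// this sketch are arbitrary): encoding the four bytes by hand yields exactly the bytes
// DataOutputStream.writeInt would, while only ever calling
// write(Array[Byte], Int, Int) on the underlying stream.
object BigEndianLengthSketch {
  import java.io.{ByteArrayOutputStream, DataOutputStream, OutputStream}

  def writeIntBigEndian(out: OutputStream, size: Int): Unit = {
    val buf = new Array[Byte](4)
    buf(0) = ((size >>> 24) & 0xFF).toByte
    buf(1) = ((size >>> 16) & 0xFF).toByte
    buf(2) = ((size >>> 8) & 0xFF).toByte
    buf(3) = ((size >>> 0) & 0xFF).toByte
    out.write(buf, 0, 4)
  }

  def main(args: Array[String]): Unit = {
    val manual = new ByteArrayOutputStream()
    writeIntBigEndian(manual, 123456789)
    val viaDataOut = new ByteArrayOutputStream()
    new DataOutputStream(viaDataOut).writeInt(123456789)
    assert(java.util.Arrays.equals(manual.toByteArray, viaDataOut.toByteArray))
  }
}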
+ private[this] var intBuffer: Array[Byte] = new Array[Byte](4) private[this] val dOut: DataOutputStream = new DataOutputStream(out) override def writeValue[T: ClassTag](value: T): SerializationStream = { val row = value.asInstanceOf[UnsafeRow] - dOut.writeInt(row.getSizeInBytes) + val size = row.getSizeInBytes + // This part is based on DataOutputStream's writeInt. + // It is for dOut.writeInt(row.getSizeInBytes). + intBuffer(0) = ((size >>> 24) & 0xFF).toByte + intBuffer(1) = ((size >>> 16) & 0xFF).toByte + intBuffer(2) = ((size >>> 8) & 0xFF).toByte + intBuffer(3) = ((size >>> 0) & 0xFF).toByte + dOut.write(intBuffer, 0, 4) + row.writeToStream(out, writeBuffer) this } @@ -90,6 +105,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def close(): Unit = { writeBuffer = null + intBuffer = null dOut.writeInt(EOF) dOut.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala deleted file mode 100644 index cf568dc048674..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.StructType - -/** - * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types - * of the grouping expressions and aggregate functions, it determines if it uses - * sort-based aggregation and hybrid (hash-based with sort-based as the fallback) to - * process input rows. 
- */ -case class Aggregate( - requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - private[this] val allAggregateExpressions = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - private[this] val hasNonAlgebricAggregateFunctions = - !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) - - // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of - // grouping key and aggregation buffer is supported; and (3) all - // aggregate functions are algebraic. - private[this] val supportsHybridIterator: Boolean = { - val aggregationBufferSchema: StructType = - StructType.fromAttributes( - allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) - val groupKeySchema: StructType = - StructType.fromAttributes(groupingExpressions.map(_.toAttribute)) - - val schemaSupportsUnsafe: Boolean = - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupKeySchema) - - // TODO: Use the hybrid iterator for non-algebric aggregate functions. - sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions - } - - // We need to use sorted input if we have grouping expressions, and - // we cannot use the hybrid iterator or the hybrid is disabled. - private[this] val requiresSortedInput: Boolean = { - groupingExpressions.nonEmpty && !supportsHybridIterator - } - - override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions - - // If result expressions' data types are all fixed length, we generate unsafe rows - // (We have this requirement instead of check the result of UnsafeProjection.canSupport - // is because we use a mutable projection to generate the result). - override def outputsUnsafeRows: Boolean = { - // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength) - // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix - // any issue we get. - false - } - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - if (requiresSortedInput) { - // TODO: We should not sort the input rows if they are just in reversed order. - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - } else { - Seq.fill(children.size)(Nil) - } - } - - override def outputOrdering: Seq[SortOrder] = { - if (requiresSortedInput) { - // It is possible that the child.outputOrdering starts with the required - // ordering expressions (e.g. we require [a] as the sort expression and the - // child's outputOrdering is [a, b]). We can only guarantee the output rows - // are sorted by values of groupingExpressions. 
- groupingExpressions.map(SortOrder(_, Ascending)) - } else { - Nil - } - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - // Because the constructor of an aggregation iterator will read at least the first row, - // we need to get the value of iter.hasNext first. - val hasInput = iter.hasNext - val useHybridIterator = - hasInput && - supportsHybridIterator && - groupingExpressions.nonEmpty - if (useHybridIterator) { - UnsafeHybridAggregationIterator.createFromInputIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection _, - child.output, - iter, - outputsUnsafeRows) - } else { - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. - Iterator[InternalRow]() - } else { - val outputIter = SortBasedAggregationIterator.createFromInputIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection _ , - newProjection _, - child.output, - iter, - outputsUnsafeRows) - if (!hasInput && groupingExpressions.isEmpty) { - // There is no input and there is no grouping expressions. - // We need to output a single row as the output. - Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) - } else { - outputIter - } - } - } - } - } - - override def simpleString: String = { - val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) { - classOf[UnsafeHybridAggregationIterator].getSimpleName - } else { - classOf[SortBasedAggregationIterator].getSimpleName - } - - s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}""" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala new file mode 100644 index 0000000000000..ad428ad663f30 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.types.StructType + +case class SortBasedAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = false + + override def canProcessUnsafeRows: Boolean = false + + override def canProcessSafeRows: Boolean = true + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } + + override def outputOrdering: Seq[SortOrder] = { + groupingExpressions.map(SortOrder(_, Ascending)) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + // Because the constructor of an aggregation iterator will read at least the first row, + // we need to get the value of iter.hasNext first. + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator[InternalRow]() + } else { + val outputIter = SortBasedAggregationIterator.createFromInputIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection _, + newProjection _, + child.output, + iter, + outputsUnsafeRows) + if (!hasInput && groupingExpressions.isEmpty) { + // There is no input and there is no grouping expressions. + // We need to output a single row as the output. 
+ Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + } else { + outputIter + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}""" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 40f6bff53d2b7..67ebafde25ad3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -204,31 +204,5 @@ object SortBasedAggregationIterator { newMutableProjection, outputsUnsafeRows) } - - def createFromKVIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[InternalRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { - new SortBasedAggregationIterator( - groupingKeyAttributes, - valueAttributes, - inputKVIterator, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } // scalastyle:on } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala new file mode 100644 index 0000000000000..5a0b4d47d62f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} + +case class TungstenAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression2], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + + override def canProcessUnsafeRows: Boolean = true + + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + // This is for testing. We force TungstenAggregationIterator to fall back to sort-based + // aggregation once it has processed a given number of input rows. + private val testFallbackStartsAt: Option[Int] = { + sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match { + case null | "" => None + case fallbackStartsAt => Some(fallbackStartsAt.toInt) + } + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. 
+ Iterator.empty.asInstanceOf[Iterator[UnsafeRow]] + } else { + val aggregationIterator = + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + completeAggregateExpressions, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + iter.asInstanceOf[Iterator[UnsafeRow]], + testFallbackStartsAt) + + if (!hasInput && groupingExpressions.isEmpty) { + Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) + } else { + aggregationIterator + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + + testFallbackStartsAt match { + case None => s"TungstenAggregate ${groupingExpressions} ${allAggregateExpressions}" + case Some(fallbackStartsAt) => + s"TungstenAggregateWithControlledFallback ${groupingExpressions} " + + s"${allAggregateExpressions} fallbackStartsAt=$fallbackStartsAt" + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala new file mode 100644 index 0000000000000..b9d44aace1009 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -0,0 +1,667 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.types.StructType + +/** + * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. + * + * This iterator first uses hash-based aggregation to process input rows. It uses + * a hash map to store groups and their corresponding aggregation buffers. If we + * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]], + * it switches to sort-based aggregation. The process of the switch has the following step: + * - Step 1: Sort all entries of the hash map based on values of grouping expressions and + * spill them to disk. + * - Step 2: Create a external sorter based on the spilled sorted map entries. + * - Step 3: Redirect all input rows to the external sorter. + * - Step 4: Get a sorted [[KVIterator]] from the external sorter. + * - Step 5: Initialize sort-based aggregation. 
+ * Then, this iterator works in the way of sort-based aggregation. + * + * The code of this class is organized as follows: + * - Part 1: Initializing aggregate functions. + * - Part 2: Methods and fields used by setting aggregation buffer values, + * processing input rows from inputIter, and generating output + * rows. + * - Part 3: Methods and fields used by hash-based aggregation. + * - Part 4: The function used to switch this iterator from hash-based + * aggregation to sort-based aggregation. + * - Part 5: Methods and fields used by sort-based aggregation. + * - Part 6: Loads input and process input rows. + * - Part 7: Public methods of this iterator. + * - Part 8: A utility function used to generate a result when there is no + * input and there is no grouping expression. + * + * @param groupingExpressions + * expressions for grouping keys + * @param nonCompleteAggregateExpressions + * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]], + * [[PartialMerge]], or [[Final]]. + * @param completeAggregateExpressions + * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. + * @param initialInputBufferOffset + * If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]]. + * The input rows have the format of `grouping keys + aggregation buffer`. + * This offset indicates the starting position of aggregation buffer in a input row. + * @param resultExpressions + * expressions for generating output rows. + * @param newMutableProjection + * the function used to create mutable projections. + * @param originalInputAttributes + * attributes of representing input rows from `inputIter`. + * @param inputIter + * the iterator containing input [[UnsafeRow]]s. + */ +class TungstenAggregationIterator( + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression2], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + originalInputAttributes: Seq[Attribute], + inputIter: Iterator[UnsafeRow], + testFallbackStartsAt: Option[Int]) + extends Iterator[UnsafeRow] with Logging { + + /////////////////////////////////////////////////////////////////////////// + // Part 1: Initializing aggregate functions. + /////////////////////////////////////////////////////////////////////////// + + // A Seq containing all AggregateExpressions. + // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final + // are at the beginning of the allAggregateExpressions. + private[this] val allAggregateExpressions: Seq[AggregateExpression2] = + nonCompleteAggregateExpressions ++ completeAggregateExpressions + + // Check to make sure we do not have more than three modes in our AggregateExpressions. + // If we have, users are hitting a bug and we throw an IllegalStateException. + if (allAggregateExpressions.map(_.mode).distinct.length > 2) { + throw new IllegalStateException( + s"$allAggregateExpressions should have no more than 2 kinds of modes.") + } + + // + // The modes of AggregateExpressions. Right now, we can handle the following mode: + // - Partial-only: + // All AggregateExpressions have the mode of Partial. + // For this case, aggregationMode is (Some(Partial), None). + // - PartialMerge-only: + // All AggregateExpressions have the mode of PartialMerge). 
+ // For this case, aggregationMode is (Some(PartialMerge), None). + // - Final-only: + // All AggregateExpressions have the mode of Final. + // For this case, aggregationMode is (Some(Final), None). + // - Final-Complete: + // Some AggregateExpressions have the mode of Final and + // others have the mode of Complete. For this case, + // aggregationMode is (Some(Final), Some(Complete)). + // - Complete-only: + // nonCompleteAggregateExpressions is empty and we have AggregateExpressions + // with mode Complete in completeAggregateExpressions. For this case, + // aggregationMode is (None, Some(Complete)). + // - Grouping-only: + // There is no AggregateExpression. For this case, AggregationMode is (None,None). + // + private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = { + nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> + completeAggregateExpressions.map(_.mode).distinct.headOption + } + + // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates. + // If there is any functions that is not an AlgebraicAggregate, we throw an + // IllegalStateException. + private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = { + if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) { + throw new IllegalStateException( + "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.") + } + + allAggregateExpressions + .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate]) + .toArray + } + + /////////////////////////////////////////////////////////////////////////// + // Part 2: Methods and fields used by setting aggregation buffer values, + // processing input rows from inputIter, and generating output + // rows. + /////////////////////////////////////////////////////////////////////////// + + // The projection used to initialize buffer values. + private[this] val algebraicInitialProjection: MutableProjection = { + val initExpressions = allAggregateFunctions.flatMap(_.initialValues) + newMutableProjection(initExpressions, Nil)() + } + + // Creates a new aggregation buffer and initializes buffer values. + // This functions should be only called at most three times (when we create the hash map, + // when we switch to sort-based aggregation, and when we create the re-used buffer for + // sort-based aggregation). + private def createNewAggregationBuffer(): UnsafeRow = { + val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferRowSize: Int = bufferSchema.length + + val genericMutableBuffer = new GenericMutableRow(bufferRowSize) + val unsafeProjection = + UnsafeProjection.create(bufferSchema.map(_.dataType)) + val buffer = unsafeProjection.apply(genericMutableBuffer) + algebraicInitialProjection.target(buffer)(EmptyRow) + buffer + } + + // Creates a function used to process a row based on the given inputAttributes. 
+ private def generateProcessRow( + inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = { + + val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) + val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes) + val inputSchema = StructType.fromAttributes(inputAttributes) + val unsafeRowJoiner = + GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema) + + aggregationMode match { + // Partial-only + case (Some(Partial), None) => + val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions) + val algebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + algebraicUpdateProjection.target(currentBuffer) + algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // PartialMerge-only or Final-only + case (Some(PartialMerge), None) | (Some(Final), None) => + val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions) + // This projection is used to merge buffer values for all AlgebraicAggregates. + val algebraicMergeProjection = + newMutableProjection( + mergeExpressions, + aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(currentBuffer) + algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // Final-Complete + case (Some(Final), Some(Complete)) => + val nonCompleteAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.take(nonCompleteAggregateExpressions.length) + val completeAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + + val completeOffsetExpressions = + Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + val mergeExpressions = + nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions + val finalAlgebraicMergeProjection = + newMutableProjection( + mergeExpressions, + aggregationBufferAttributes ++ inputAttributes)() + + // We do not touch buffer values of aggregate functions with the Final mode. + val finalOffsetExpressions = + Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + val updateExpressions = + finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions) + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + val input = unsafeRowJoiner.join(currentBuffer, row) + // For all aggregate functions with mode Complete, update the given currentBuffer. + completeAlgebraicUpdateProjection.target(currentBuffer)(input) + + // For all aggregate functions with mode Final, merge buffer values in row to + // currentBuffer. 
+ finalAlgebraicMergeProjection.target(currentBuffer)(input) + } + + // Complete-only + case (None, Some(Complete)) => + val completeAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + + val updateExpressions = + completeAggregateFunctions.flatMap(_.updateExpressions) + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + completeAlgebraicUpdateProjection.target(currentBuffer) + // For all aggregate functions with mode Complete, update the given currentBuffer. + completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // Grouping only. + case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {} + + case other => + throw new IllegalStateException( + s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + } + } + + // Creates a function used to generate output rows. + private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + aggregationMode match { + // Partial-only or PartialMerge-only: every output row is basically the values of + // the grouping expressions and the corresponding aggregation buffer. + case (Some(Partial), None) | (Some(PartialMerge), None) => + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer) + } + + // Final-only, Complete-only and Final-Complete: a output row is generated based on + // resultExpressions. + case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer)) + } + + // Grouping-only: a output row is generated from values of grouping expressions. + case (None, None) => + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + resultProjection(currentGroupingKey) + } + + case other => + throw new IllegalStateException( + s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + } + } + + // An UnsafeProjection used to extract grouping keys from the input rows. + private[this] val groupProjection = + UnsafeProjection.create(groupingExpressions, originalInputAttributes) + + // A function used to process a input row. Its first argument is the aggregation buffer + // and the second argument is the input row. + private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit = + generateProcessRow(originalInputAttributes) + + // A function used to generate output rows based on the grouping keys (first argument) + // and the corresponding aggregation buffer (second argument). + private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow = + generateResultProjection() + + // An aggregation buffer containing initial buffer values. 
It is used to + // initialize other aggregation buffers. + private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + /////////////////////////////////////////////////////////////////////////// + // Part 3: Methods and fields used by hash-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // This is the hash map used for hash-based aggregation. It is backed by an + // UnsafeFixedWidthAggregationMap and it is used to store + // all groups and their corresponding aggregation buffers for hash-based aggregation. + private[this] val hashMap = new UnsafeFixedWidthAggregationMap( + initialAggregationBuffer, + StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), + StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), + TaskContext.get.taskMemoryManager(), + SparkEnv.get.shuffleMemoryManager, + 1024 * 16, // initial capacity + SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), + false // disable tracking of performance metrics + ) + + // The function used to read and process input rows. When processing input rows, + // it first uses hash-based aggregation by putting groups and their buffers in + // hashMap. If we could not allocate more memory for the map, we switch to + // sort-based aggregation (by calling switchToSortBasedAggregation). + private def processInputs(): Unit = { + while (!sortBased && inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey) + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, newInput) + } else { + processRow(buffer, newInput) + } + } + } + + // This function is only used for testing. It basically the same as processInputs except + // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have + // been processed. + private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { + var i = 0 + while (!sortBased && inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + val buffer: UnsafeRow = if (i < fallbackStartsAt) { + hashMap.getAggregationBuffer(groupingKey) + } else { + null + } + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, newInput) + } else { + processRow(buffer, newInput) + } + i += 1 + } + } + + // The iterator created from hashMap. It is used to generate output rows when we + // are using hash-based aggregation. + private[this] var aggregationBufferMapIterator: KVIterator[UnsafeRow, UnsafeRow] = null + + // Indicates if aggregationBufferMapIterator still has key-value pairs. + private[this] var mapIteratorHasNext: Boolean = false + + /////////////////////////////////////////////////////////////////////////// + // Part 4: The function used to switch this iterator from hash-based + // aggregation to sort-based aggregation. 
+ /////////////////////////////////////////////////////////////////////////// + + private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = { + logInfo("falling back to sort based aggregation.") + // Step 1: Get the ExternalSorter containing sorted entries of the map. + val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter() + + // Step 2: Free the memory used by the map. + hashMap.free() + + // Step 3: If we have aggregate function with mode Partial or Complete, + // we need to process input rows to get aggregation buffer. + // So, later in the sort-based aggregation iterator, we can do merge. + // If aggregate functions are with mode Final and PartialMerge, + // we just need to project the aggregation buffer from an input row. + val needsProcess = aggregationMode match { + case (Some(Partial), None) => true + case (None, Some(Complete)) => true + case (Some(Final), Some(Complete)) => true + case _ => false + } + + if (needsProcess) { + // First, we create a buffer. + val buffer = createNewAggregationBuffer() + + // Process firstKey and firstInput. + // Initialize buffer. + buffer.copyFrom(initialAggregationBuffer) + processRow(buffer, firstInput) + externalSorter.insertKV(firstKey, buffer) + + // Process the rest of input rows. + while (inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + buffer.copyFrom(initialAggregationBuffer) + processRow(buffer, newInput) + externalSorter.insertKV(groupingKey, buffer) + } + } else { + // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer. + // We need to project the aggregation buffer part from an input row. + val buffer = createNewAggregationBuffer() + // The originalInputAttributes are using cloneBufferAttributes. So, we need to use + // allAggregateFunctions.flatMap(_.cloneBufferAttributes). + val bufferExtractor = newMutableProjection( + allAggregateFunctions.flatMap(_.cloneBufferAttributes), + originalInputAttributes)() + bufferExtractor.target(buffer) + + // Insert firstKey and its buffer. + bufferExtractor(firstInput) + externalSorter.insertKV(firstKey, buffer) + + // Insert the rest of input rows. + while (inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + bufferExtractor(newInput) + externalSorter.insertKV(groupingKey, buffer) + } + } + + // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. + val newAggregationMode = aggregationMode match { + case (Some(Partial), None) => (Some(PartialMerge), None) + case (None, Some(Complete)) => (Some(Final), None) + case (Some(Final), Some(Complete)) => (Some(Final), None) + case other => other + } + aggregationMode = newAggregationMode + + // Basically the value of the KVIterator returned by externalSorter + // will just aggregation buffer. At here, we use cloneBufferAttributes. + val newInputAttributes: Seq[Attribute] = + allAggregateFunctions.flatMap(_.cloneBufferAttributes) + + // Set up new processRow and generateOutput. + processRow = generateProcessRow(newInputAttributes) + generateOutput = generateResultProjection() + + // Step 5: Get the sorted iterator from the externalSorter. + sortedKVIterator = externalSorter.sortedIterator() + + // Step 6: Pre-load the first key-value pair from the sorted iterator to make + // hasNext idempotent. + sortedInputHasNewGroup = sortedKVIterator.next() + + // Copy the first key and value (aggregation buffer). 
+ if (sortedInputHasNewGroup) { + val key = sortedKVIterator.getKey + val value = sortedKVIterator.getValue + nextGroupingKey = key.copy() + currentGroupingKey = key.copy() + firstRowInNextGroup = value.copy() + } + + // Step 7: set sortBased to true. + sortBased = true + } + + /////////////////////////////////////////////////////////////////////////// + // Part 5: Methods and fields used by sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // Indicates if we are using sort-based aggregation. Because we first try to use + // hash-based aggregation, its initial value is false. + private[this] var sortBased: Boolean = false + + // The KVIterator containing input rows for the sort-based aggregation. It will be + // set in switchToSortBasedAggregation when we switch to sort-based aggregation. + private[this] var sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = null + + // The grouping key of the current group. + private[this] var currentGroupingKey: UnsafeRow = null + + // The grouping key of next group. + private[this] var nextGroupingKey: UnsafeRow = null + + // The first row of next group. + private[this] var firstRowInNextGroup: UnsafeRow = null + + // Indicates if we has new group of rows from the sorted input iterator. + private[this] var sortedInputHasNewGroup: Boolean = false + + // The aggregation buffer used by the sort-based aggregation. + private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + // Processes rows in the current group. It will stop when it find a new group. + private def processCurrentSortedGroup(): Unit = { + // First, we need to copy nextGroupingKey to currentGroupingKey. + currentGroupingKey.copyFrom(nextGroupingKey) + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + + // The search will stop when we see the next group or there is no + // input row left in the iter. + // Pre-load the first key-value pair to make the condition of the while loop + // has no action (we do not trigger loading a new key-value pair + // when we evaluate the condition). + var hasNext = sortedKVIterator.next() + while (!findNextPartition && hasNext) { + // Get the grouping key and value (aggregation buffer). + val groupingKey = sortedKVIterator.getKey + val inputAggregationBuffer = sortedKVIterator.getValue + + // Check if the current row belongs the current input row. + if (currentGroupingKey.equals(groupingKey)) { + processRow(sortBasedAggregationBuffer, inputAggregationBuffer) + + hasNext = sortedKVIterator.next() + } else { + // We find a new group. + findNextPartition = true + // copyFrom will fail when + nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy() + firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy() + + } + } + // We have not seen a new group. It means that there is no new row in the input + // iter. The current group is the last group of the sortedKVIterator. + if (!findNextPartition) { + sortedInputHasNewGroup = false + sortedKVIterator.close() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 6: Loads input rows and setup aggregationBufferMapIterator if we + // have not switched to sort-based aggregation. 
+ /////////////////////////////////////////////////////////////////////////// + + // Starts to process input rows. + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } + + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Par 7: Iterator's public methods. + /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = { + (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext) + } + + override final def next(): UnsafeRow = { + if (hasNext) { + if (sortBased) { + // Process the current group. + processCurrentSortedGroup() + // Generate output row for the current group. + val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) + // Initialize buffer values for the next group. + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + + outputRow + } else { + // We did not fall back to sort-based aggregation. + val result = + generateOutput( + aggregationBufferMapIterator.getKey, + aggregationBufferMapIterator.getValue) + + // Pre-load next key-value pair form aggregationBufferMapIterator to make hasNext + // idempotent. + mapIteratorHasNext = aggregationBufferMapIterator.next() + + if (!mapIteratorHasNext) { + // If there is no input from aggregationBufferMapIterator, we copy current result. + val resultCopy = result.copy() + // Then, we free the map. + hashMap.free() + + resultCopy + } else { + result + } + } + } else { + // no more result + throw new NoSuchElementException + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 8: A utility function used to generate a output row when there is no + // input and there is no grouping expression. + /////////////////////////////////////////////////////////////////////////// + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { + if (groupingExpressions.isEmpty) { + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + // We create a output row and copy it. So, we can free the map. 
+ val resultCopy = + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() + hashMap.free() + resultCopy + } else { + throw new IllegalStateException( + "This method should not be called when groupingExpressions is not empty.") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala deleted file mode 100644 index b465787fe8cbd..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} -import org.apache.spark.sql.types.StructType - -/** - * An iterator used to evaluate [[AggregateFunction2]]. - * It first tries to use in-memory hash-based aggregation. If we cannot allocate more - * space for the hash map, we spill the sorted map entries, free the map, and then - * switch to sort-based aggregation. - */ -class UnsafeHybridAggregationIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[UnsafeRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) - extends AggregationIterator( - groupingKeyAttributes, - valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - require(groupingKeyAttributes.nonEmpty) - - /////////////////////////////////////////////////////////////////////////// - // Unsafe Aggregation buffers - /////////////////////////////////////////////////////////////////////////// - - // This is the Unsafe Aggregation Map used to store all buffers. 
- private[this] val buffers = new UnsafeFixedWidthAggregationMap( - newBuffer, - StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), - StructType.fromAttributes(groupingKeyAttributes), - TaskContext.get.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, - 1024 * 16, // initial capacity - SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), - false // disable tracking of performance metrics - ) - - override protected def newBuffer: UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - val bufferRowSize: Int = bufferSchema.length - - val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val unsafeProjection = - UnsafeProjection.create(bufferSchema.map(_.dataType)) - val buffer = unsafeProjection.apply(genericMutableBuffer) - initializeBuffer(buffer) - buffer - } - - /////////////////////////////////////////////////////////////////////////// - // Methods and variables related to switching to sort-based aggregation - /////////////////////////////////////////////////////////////////////////// - private[this] var sortBased = false - - private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _ - - // The value part of the input KV iterator is used to store original input values of - // aggregate functions, we need to convert them to aggregation buffers. - private def processOriginalInput( - firstKey: UnsafeRow, - firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { - new KVIterator[UnsafeRow, UnsafeRow] { - private[this] var isFirstRow = true - - private[this] var groupingKey: UnsafeRow = _ - - private[this] val buffer: UnsafeRow = newBuffer - - override def next(): Boolean = { - initializeBuffer(buffer) - if (isFirstRow) { - isFirstRow = false - groupingKey = firstKey - processRow(buffer, firstValue) - - true - } else if (inputKVIterator.next()) { - groupingKey = inputKVIterator.getKey() - val value = inputKVIterator.getValue() - processRow(buffer, value) - - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): UnsafeRow = { - buffer - } - - override def close(): Unit = { - // Do nothing. - } - } - } - - // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer. - // We need to project the aggregation buffer out. - private def projectInputBufferToUnsafe( - firstKey: UnsafeRow, - firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { - new KVIterator[UnsafeRow, UnsafeRow] { - private[this] var isFirstRow = true - - private[this] var groupingKey: UnsafeRow = _ - - private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - - private[this] val value: UnsafeRow = { - val genericMutableRow = new GenericMutableRow(bufferSchema.length) - UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow) - } - - private[this] val projectInputBuffer = { - newMutableProjection(bufferSchema, valueAttributes)().target(value) - } - - override def next(): Boolean = { - if (isFirstRow) { - isFirstRow = false - groupingKey = firstKey - projectInputBuffer(firstValue) - - true - } else if (inputKVIterator.next()) { - groupingKey = inputKVIterator.getKey() - projectInputBuffer(inputKVIterator.getValue()) - - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): UnsafeRow = { - value - } - - override def close(): Unit = { - // Do nothing. 
- } - } - } - - /** - * We need to fall back to sort based aggregation because we do not have enough memory - * for our in-memory hash map (i.e. `buffers`). - */ - private def switchToSortBasedAggregation( - currentGroupingKey: UnsafeRow, - currentRow: InternalRow): Unit = { - logInfo("falling back to sort based aggregation.") - - // Step 1: Get the ExternalSorter containing entries of the map. - val externalSorter = buffers.destructAndCreateExternalSorter() - - // Step 2: Free the memory used by the map. - buffers.free() - - // Step 3: If we have aggregate function with mode Partial or Complete, - // we need to process them to get aggregation buffer. - // So, later in the sort-based aggregation iterator, we can do merge. - // If aggregate functions are with mode Final and PartialMerge, - // we just need to project the aggregation buffer from the input. - val needsProcess = aggregationMode match { - case (Some(Partial), None) => true - case (None, Some(Complete)) => true - case (Some(Final), Some(Complete)) => true - case _ => false - } - - val processedIterator = if (needsProcess) { - processOriginalInput(currentGroupingKey, currentRow) - } else { - // The input value's format is groupingExprs + buffer. - // We need to project the buffer part out. - projectInputBufferToUnsafe(currentGroupingKey, currentRow) - } - - // Step 4: Redirect processedIterator to externalSorter. - while (processedIterator.next()) { - externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue()) - } - - // Step 5: Get the sorted iterator from the externalSorter. - val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator() - - // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator. - // For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator - // will be PartialMerge. For a aggregate function with mode Complete, - // its mode in the SortBasedAggregationIterator will be Final. - val newNonCompleteAggregateExpressions = allAggregateExpressions.map { - case AggregateExpression2(func, Partial, isDistinct) => - AggregateExpression2(func, PartialMerge, isDistinct) - case AggregateExpression2(func, Complete, isDistinct) => - AggregateExpression2(func, Final, isDistinct) - case other => other - } - val newNonCompleteAggregateAttributes = - nonCompleteAggregateAttributes ++ completeAggregateAttributes - - val newValueAttributes = - allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) - - sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator( - groupingKeyAttributes = groupingKeyAttributes, - valueAttributes = newValueAttributes, - inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]], - nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions, - nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = resultExpressions, - newMutableProjection = newMutableProjection, - outputsUnsafeRows = outputsUnsafeRows) - } - - /////////////////////////////////////////////////////////////////////////// - // Methods used to initialize this iterator. - /////////////////////////////////////////////////////////////////////////// - - /** Starts to read input rows and falls back to sort-based aggregation if necessary. 
*/ - protected def initialize(): Unit = { - var hasNext = inputKVIterator.next() - while (!sortBased && hasNext) { - val groupingKey = inputKVIterator.getKey() - val currentRow = inputKVIterator.getValue() - val buffer = buffers.getAggregationBuffer(groupingKey) - if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, currentRow) - sortBased = true - } else { - processRow(buffer, currentRow) - hasNext = inputKVIterator.next() - } - } - } - - // This is the starting point of this iterator. - initialize() - - // Creates the iterator for the Hash Aggregation Map after we have populated - // contents of that map. - private[this] val aggregationBufferMapIterator = buffers.iterator() - - private[this] var _mapIteratorHasNext = false - - // Pre-load the first key-value pair from the map to make hasNext idempotent. - if (!sortBased) { - _mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!_mapIteratorHasNext) { - buffers.free() - } - } - - /////////////////////////////////////////////////////////////////////////// - // Iterator's public methods - /////////////////////////////////////////////////////////////////////////// - - override final def hasNext: Boolean = { - (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext) - } - - - override final def next(): InternalRow = { - if (hasNext) { - if (sortBased) { - sortBasedAggregationIterator.next() - } else { - // We did not fall back to the sort-based aggregation. - val result = - generateOutput( - aggregationBufferMapIterator.getKey, - aggregationBufferMapIterator.getValue) - // Pre-load next key-value pair form aggregationBufferMapIterator. 
- _mapIteratorHasNext = aggregationBufferMapIterator.next() - - if (!_mapIteratorHasNext) { - val resultCopy = result.copy() - buffers.free() - resultCopy - } else { - result - } - } - } else { - // no more result - throw new NoSuchElementException - } - } -} - -object UnsafeHybridAggregationIterator { - // scalastyle:off - def createFromInputIterator( - groupingExprs: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = { - new UnsafeHybridAggregationIterator( - groupingExprs.map(_.toAttribute), - inputAttributes, - AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter), - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } - // scalastyle:on -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 960be08f84d94..80816a095ea8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -17,20 +17,41 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan} +import org.apache.spark.sql.types.StructType /** * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { + def supportsTungstenAggregate( + groupingExpressions: Seq[Expression], + aggregateBufferAttributes: Seq[Attribute]): Boolean = { + val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeProjection.canSupport(groupingExpressions) + } + def planAggregateWithoutDistinct( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + // Check if we can use TungstenAggregate. + val usesTungstenAggregate = + child.sqlContext.conf.unsafeEnabled && + aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) && + supportsTungstenAggregate( + groupingExpressions, + aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + + // 1. Create an Aggregate Operator for partial aggregations. 
val namedGroupingExpressions = groupingExpressions.map { case ne: NamedExpression => ne -> ne @@ -44,11 +65,23 @@ object Utils { val groupExpressionMap = namedGroupingExpressions.toMap val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } - val partialAggregate = - Aggregate( + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) + val partialResultExpressions = + namedGroupingAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + + val partialAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = namedGroupingExpressions.map(_._2), + nonCompleteAggregateExpressions = partialAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialResultExpressions, + child = child) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], groupingExpressions = namedGroupingExpressions.map(_._2), nonCompleteAggregateExpressions = partialAggregateExpressions, @@ -56,29 +89,57 @@ object Utils { completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, initialInputBufferOffset = 0, - resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes, + resultExpressions = partialResultExpressions, child = child) + } // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 } - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - val finalAggregate = - Aggregate( + + val finalAggregate = if (usesTungstenAggregate) { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case agg: AggregateExpression2 => + // aggregateFunctionMap contains unique aggregate functions. + val aggregateFunction = + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1 + aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. 
+ groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = namedGroupingAttributes.length, + resultExpressions = rewrittenResultExpressions, + child = partialAggregate) + } else { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, @@ -88,6 +149,7 @@ object Utils { initialInputBufferOffset = namedGroupingAttributes.length, resultExpressions = rewrittenResultExpressions, child = partialAggregate) + } finalAggregate :: Nil } @@ -96,10 +158,18 @@ object Utils { groupingExpressions: Seq[Expression], functionsWithDistinct: Seq[AggregateExpression2], functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct + val usesTungstenAggregate = + child.sqlContext.conf.unsafeEnabled && + aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) && + supportsTungstenAggregate( + groupingExpressions, + aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + // 1. Create an Aggregate Operator for partial aggregations. // The grouping expressions are original groupingExpressions and // distinct columns. For example, for avg(distinct value) ... 
group by key @@ -129,19 +199,26 @@ object Utils { val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute) - val partialAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, Partial, false) - } - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } + val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) val partialAggregateGroupingExpressions = (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) val partialAggregateResult = - namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes - val partialAggregate = - Aggregate( + namedGroupingAttributes ++ + distinctColumnAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + val partialAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = partialAggregateGroupingExpressions, + nonCompleteAggregateExpressions = partialAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialAggregateResult, + child = child) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, @@ -151,20 +228,27 @@ object Utils { initialInputBufferOffset = 0, resultExpressions = partialAggregateResult, child = child) + } // 2. Create an Aggregate Operator for partial merge aggregations. 
- val partialMergeAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, PartialMerge, false) - } + val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) val partialMergeAggregateResult = - namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes - val partialMergeAggregate = - Aggregate( + namedGroupingAttributes ++ + distinctColumnAttributes ++ + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + val partialMergeAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = partialMergeAggregateResult, + child = partialAggregate) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, @@ -174,48 +258,91 @@ object Utils { initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) + } // 3. Create an Aggregate Operator for partial merge aggregations. - val finalAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, Final, false) - } + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 } + // Create a map to store those rewritten aggregate functions. We always need to use + // both function and its corresponding isDistinct flag as the key because function itself + // does not knows if it is has distinct keyword or now. + val rewrittenAggregateFunctions = + mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2] val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) => + case agg @ AggregateExpression2(aggregateFunction, mode, true) => val rewrittenAggregateFunction = aggregateFunction.transformDown { case expr if distinctColumnExpressionMap.contains(expr) => distinctColumnExpressionMap(expr).toAttribute }.asInstanceOf[AggregateFunction2] + // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions + // to track the old version and the new version of this function. 
+ rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, false) + AggregateExpression2(rewrittenAggregateFunction, Complete, true) - val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct) + val aggregateFunctionAttribute = + aggregateFunctionMap(agg.aggregateFunction, true)._2 (rewrittenAggregateExpression -> aggregateFunctionAttribute) }.unzip - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - val finalAndCompleteAggregate = - Aggregate( + val finalAndCompleteAggregate = if (usesTungstenAggregate) { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + val function = agg.aggregateFunction + val isDistinct = agg.isDistinct + val aggregateFunction = + if (rewrittenAggregateFunctions.contains(function, isDistinct)) { + // If this function has been rewritten, we get the rewritten version from + // rewrittenAggregateFunctions. + rewrittenAggregateFunctions(function, isDistinct) + } else { + // Oterwise, we get it from aggregateFunctionMap, which contains unique + // aggregate functions that have not been rewritten. + aggregateFunctionMap(function, isDistinct)._1 + } + aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + completeAggregateExpressions = completeAggregateExpressions, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = rewrittenResultExpressions, + child = partialMergeAggregate) + } else { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. 
+ groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, @@ -225,6 +352,7 @@ object Utils { initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, resultExpressions = rewrittenResultExpressions, child = partialMergeAggregate) + } finalAndCompleteAggregate :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cef40dd324d9e..c64aa7a07dc2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -262,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = sql(sqlText) // First, check if we have GeneratedAggregate. val hasGeneratedAgg = df.queryExecution.executedPlan - .collect { case _: aggregate.Aggregate => true } + .collect { case _: aggregate.TungstenAggregate => true } .nonEmpty if (!hasGeneratedAgg) { fail( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4b35c8fd83533..7b5aa4763fd9e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -21,9 +21,9 @@ import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row} +import org.apache.spark.sql._ import org.scalatest.BeforeAndAfterAll -import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} +import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { @@ -141,6 +141,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Nil) } + test("null literal") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(null), + | COUNT(null), + | FIRST(null), + | LAST(null), + | MAX(null), + | MIN(null), + | SUM(null) + """.stripMargin), + Row(null, 0, null, null, null, null, null) :: Nil) + } + test("only do grouping") { checkAnswer( sqlContext.sql( @@ -266,13 +282,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |SELECT avg(value) FROM agg1 """.stripMargin), Row(11.125) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT avg(null) - """.stripMargin), - Row(null) :: Nil) } test("udaf") { @@ -364,7 +373,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be | max(distinct value1) |FROM agg2 """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100.0)) + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) checkAnswer( sqlContext.sql( @@ -402,6 +411,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: Row(3, null, 3.0, null, null, null) :: 
Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) } test("test count") { @@ -496,7 +522,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |FROM agg1 |GROUP BY key """.stripMargin).queryExecution.executedPlan.collect { - case agg: aggregate.Aggregate => agg + case agg: aggregate.SortBasedAggregate => agg + case agg: aggregate.TungstenAggregate => agg } val message = "We should fallback to the old aggregation code path if " + @@ -537,3 +564,58 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite { sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) } } + +class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") + } + + override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = { + (0 to 2).foreach { fallbackStartsAt => + sqlContext.setConf( + "spark.sql.TungstenAggregate.testFallbackStartsAt", + fallbackStartsAt.toString) + + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + val newActual = DataFrame(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => + } + } + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } +} From e234ea1b49d30bb6c8b8c001bd98c43de290dcff Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 6 Aug 2015 15:30:27 -0700 Subject: [PATCH 200/340] [SPARK-9645] [YARN] [CORE] Allow shuffle service to read shuffle files. Spark should not mess with the permissions of directories created by the cluster manager. Here, by setting the block manager dir permissions to 700, the shuffle service (running as the YARN user) wouldn't be able to serve shuffle files created by applications. Also, the code to protect the local app dir was missing in standalone's Worker; that has been now added. Since all processes run as the same user in standalone, `chmod 700` should not cause problems. 
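A minimal sketch of the worker-side behavior described above, using the `Utils.getOrCreateLocalRootDirs`, `Utils.createDirectory` and `Utils.chmod700` helpers touched in the diff below; the surrounding names (`conf`, `appLocalDirs`) are illustrative, not the exact Worker code:

  import org.apache.spark.util.Utils

  // One executor directory per configured local root dir for this application.
  // chmod 700 restricts it to the owning user; in standalone mode the worker,
  // executors and shuffle service all run as that user, so nothing breaks.
  // (The blockmgr dirs created by DiskBlockManager are deliberately NOT chmod'ed,
  // so a shuffle service running as a different user, e.g. on YARN, can read them.)
  val appLocalDirs: Seq[String] = Utils.getOrCreateLocalRootDirs(conf).map { dir =>
    val appDir = Utils.createDirectory(dir, namePrefix = "executor")
    Utils.chmod700(appDir)
    appDir.getAbsolutePath()
  }.toSeq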
Author: Marcelo Vanzin Closes #7966 from vanzin/SPARK-9645 and squashes the following commits: 6e07b31 [Marcelo Vanzin] Protect the app dir in standalone mode. 384ba6a [Marcelo Vanzin] [SPARK-9645] [yarn] [core] Allow shuffle service to read shuffle files. --- .../main/scala/org/apache/spark/deploy/worker/Worker.scala | 4 +++- .../scala/org/apache/spark/storage/DiskBlockManager.scala | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 6792d3310b06c..79b1536d94016 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -428,7 +428,9 @@ private[deploy] class Worker( // application finishes. val appLocalDirs = appDirectories.get(appId).getOrElse { Utils.getOrCreateLocalRootDirs(conf).map { dir => - Utils.createDirectory(dir, namePrefix = "executor").getAbsolutePath() + val appDir = Utils.createDirectory(dir, namePrefix = "executor") + Utils.chmod700(appDir) + appDir.getAbsolutePath() }.toSeq } appDirectories(appId) = appLocalDirs diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 5f537692a16c5..56a33d5ca7d60 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -133,7 +133,6 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") - Utils.chmod700(localDir) logInfo(s"Created local directory at $localDir") Some(localDir) } catch { From 681e3024b6c2fcb54b42180d94d3ba3eed52a2d4 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 6 Aug 2015 23:43:52 +0100 Subject: [PATCH 201/340] [SPARK-9633] [BUILD] SBT download locations outdated; need an update Remove 2 defunct SBT download URLs and replace with the 1 known download URL. Also, use https. Follow up on https://github.com/apache/spark/pull/7792 Author: Sean Owen Closes #7956 from srowen/SPARK-9633 and squashes the following commits: caa40bd [Sean Owen] Remove 2 defunct SBT download URLs and replace with the 1 known download URL. Also, use https. 
--- build/sbt-launch-lib.bash | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 7930a38b9674a..615f848394650 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -38,8 +38,7 @@ dlog () { acquire_sbt_jar () { SBT_VERSION=`awk -F "=" '/sbt\.version/ {print $2}' ./project/build.properties` - URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar - URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar + URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar JAR=build/sbt-launch-${SBT_VERSION}.jar sbt_jar=$JAR @@ -51,12 +50,10 @@ acquire_sbt_jar () { printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" if [ $(command -v curl) ]; then - (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\ - (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\ + curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\ mv "${JAR_DL}" "${JAR}" elif [ $(command -v wget) ]; then - (wget --quiet ${URL1} -O "${JAR_DL}" ||\ - (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\ + wget --quiet ${URL1} -O "${JAR_DL}" &&\ mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" From baf4587a569b49e39020c04c2785041bdd00789b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 6 Aug 2015 17:03:14 -0700 Subject: [PATCH 202/340] [SPARK-9691] [SQL] PySpark SQL rand function treats seed 0 as no seed https://issues.apache.org/jira/browse/SPARK-9691 jkbradley rxin Author: Yin Huai Closes #7999 from yhuai/pythonRand and squashes the following commits: 4187e0c [Yin Huai] Regression test. a985ef9 [Yin Huai] Use "if seed is not None" instead "if seed" because "if seed" returns false when seed is 0. --- python/pyspark/sql/functions.py | 4 ++-- python/pyspark/sql/tests.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b5c6a01f18858..95f46044d324a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -268,7 +268,7 @@ def rand(seed=None): """Generates a random column with i.i.d. samples from U[0.0, 1.0]. """ sc = SparkContext._active_spark_context - if seed: + if seed is not None: jc = sc._jvm.functions.rand(seed) else: jc = sc._jvm.functions.rand() @@ -280,7 +280,7 @@ def randn(seed=None): """Generates a column with i.i.d. samples from the standard normal distribution. """ sc = SparkContext._active_spark_context - if seed: + if seed is not None: jc = sc._jvm.functions.randn(seed) else: jc = sc._jvm.functions.randn() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ebd3ea8db6a43..1e3444dd9e3b4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -629,6 +629,16 @@ def test_rand_functions(self): for row in rndn: assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] + # If the specified seed is 0, we should use it. 
+ # https://issues.apache.org/jira/browse/SPARK-9691 + rnd1 = df.select('key', functions.rand(0)).collect() + rnd2 = df.select('key', functions.rand(0)).collect() + self.assertEqual(sorted(rnd1), sorted(rnd2)) + + rndn1 = df.select('key', functions.randn(0)).collect() + rndn2 = df.select('key', functions.randn(0)).collect() + self.assertEqual(sorted(rndn1), sorted(rndn2)) + def test_between_function(self): df = self.sc.parallelize([ Row(a=1, b=2, c=3), From 4e70e8256ce2f45b438642372329eac7b1e9e8cf Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 17:30:31 -0700 Subject: [PATCH 203/340] [SPARK-9228] [SQL] use tungsten.enabled in public for both of codegen/unsafe spark.sql.tungsten.enabled will be the default value for both codegen and unsafe, they are kept internally for debug/testing. cc marmbrus rxin Author: Davies Liu Closes #7998 from davies/tungsten and squashes the following commits: c1c16da [Davies Liu] update doc 1a47be1 [Davies Liu] use tungsten.enabled for both of codegen/unsafe --- docs/sql-programming-guide.md | 6 +++--- .../scala/org/apache/spark/sql/SQLConf.scala | 20 ++++++++++++------- .../spark/sql/execution/SparkPlan.scala | 8 +++++++- .../spark/sql/execution/joins/HashJoin.scala | 3 ++- .../sql/execution/joins/HashOuterJoin.scala | 2 +- .../sql/execution/joins/HashSemiJoin.scala | 3 ++- 6 files changed, 28 insertions(+), 14 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3ea77e82422fb..6c317175d3278 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1884,11 +1884,11 @@ that these options will be deprecated in future release as more optimizations ar - spark.sql.codegen + spark.sql.tungsten.enabled true - When true, code will be dynamically generated at runtime for expression evaluation in a specific - query. For some queries with complicated expression this option can lead to significant speed-ups. + When true, use the optimized Tungsten physical execution backend which explicitly manages memory + and dynamically generates bytecode for expression evaluation. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f836122b3e0e4..ef35c133d9cc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -223,14 +223,21 @@ private[spark] object SQLConf { defaultValue = Some(200), doc = "The default number of partitions to use when shuffling data for joins or aggregations.") - val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled", defaultValue = Some(true), + doc = "When true, use the optimized Tungsten physical execution backend which explicitly " + + "manages memory and dynamically generates bytecode for expression evaluation.") + + val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + defaultValue = Some(true), // use TUNGSTEN_ENABLED as default doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query.") + " a specific query.", + isPublic = false) val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(true), - doc = "When true, use the new optimized Tungsten physical execution backend.") + defaultValue = Some(true), // use TUNGSTEN_ENABLED as default + doc = "When true, use the new optimized Tungsten physical execution backend.", + isPublic = false) val DIALECT = stringConf( "spark.sql.dialect", @@ -427,7 +434,6 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ - private[sql] class SQLConf extends Serializable with CatalystConf { import SQLConf._ @@ -474,11 +480,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2f29067f5646a..3fff79cd1b281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -55,12 +55,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of codegenEnabled will be set by the desserializer after the constructor has run. + // the value of codegenEnabled/unsafeEnabled will be set by the desserializer after the + // constructor has run. val codegenEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.codegenEnabled } else { false } + val unsafeEnabled: Boolean = if (sqlContext != null) { + sqlContext.conf.unsafeEnabled + } else { + false + } /** * Whether the "prepare" method is called. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 5e9cd9fd2345a..22d46d1c3e3b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,7 +44,8 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + (self.codegenEnabled && self.unsafeEnabled + && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 346337e64245c..701bd3cd86372 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -67,7 +67,7 @@ trait HashOuterJoin { } protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && joinType != FullOuter + (self.codegenEnabled && self.unsafeEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 47a7d370f5415..82dd6eb7e7ed0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,7 +33,8 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output protected[this] def supportUnsafe: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) + (self.codegenEnabled && self.unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(rightKeys) && UnsafeProjection.canSupport(left.schema) && UnsafeProjection.canSupport(right.schema)) From 0867b23c74a3e6347d718b67ddabff17b468eded Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 6 Aug 2015 17:31:16 -0700 Subject: [PATCH 204/340] [SPARK-9650][SQL] Fix quoting behavior on interpolated column names Make sure that `$"column"` is consistent with other methods with respect to backticks. Adds a bunch of tests for various ways of constructing columns. 
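For illustration, the quoting rules this change enforces, mirroring the tests added to ColumnExpressionSuite further down (the sample DataFrame is hypothetical, and `import sqlContext.implicits._` is assumed to be in scope):

  // A column name containing spaces needs no special quoting, but a literal '.'
  // must be protected with backticks, otherwise it is parsed as a field separator.
  val df = Seq((1, "a")).toDF("name with space", "name.with.dot").as("a")

  df.select($"name with space")     // resolves the first column
  df.select($"`name.with.dot`")     // backticks keep the dot literal
  df.select($"a.`name.with.dot`")   // same column, qualified by the alias "a"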
Author: Michael Armbrust Closes #7969 from marmbrus/namesWithDots and squashes the following commits: 53ef3d7 [Michael Armbrust] [SPARK-9650][SQL] Fix quoting behavior on interpolated column names 2bf7a92 [Michael Armbrust] WIP --- .../sql/catalyst/analysis/unresolved.scala | 57 ++++++++++++++++ .../catalyst/plans/logical/LogicalPlan.scala | 42 +----------- .../scala/org/apache/spark/sql/Column.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/ColumnExpressionSuite.scala | 68 +++++++++++++++++++ 5 files changed, 128 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 03da45b09f928..43ee3191935eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.errors import org.apache.spark.sql.catalyst.expressions._ @@ -69,8 +70,64 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un } object UnresolvedAttribute { + /** + * Creates an [[UnresolvedAttribute]], parsing segments separated by dots ('.'). + */ def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\.")) + + /** + * Creates an [[UnresolvedAttribute]], from a single quoted string (for example using backticks in + * HiveQL. Since the string is consider quoted, no processing is done on the name. + */ def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) + + /** + * Creates an [[UnresolvedAttribute]] from a string in an embedded language. In this case + * we treat it as a quoted identifier, except for '.', which must be further quoted using + * backticks if it is part of a column name. + */ + def quotedString(name: String): UnresolvedAttribute = + new UnresolvedAttribute(parseAttributeName(name)) + + /** + * Used to split attribute name by dot with backticks rule. + * Backticks must appear in pairs, and the quoted string must be a complete name part, + * which means `ab..c`e.f is not allowed. + * Escape character is not supported now, so we can't use backtick inside name part. + */ + def parseAttributeName(name: String): Seq[String] = { + def e = new AnalysisException(s"syntax error in attribute name: $name") + val nameParts = scala.collection.mutable.ArrayBuffer.empty[String] + val tmp = scala.collection.mutable.ArrayBuffer.empty[Char] + var inBacktick = false + var i = 0 + while (i < name.length) { + val char = name(i) + if (inBacktick) { + if (char == '`') { + inBacktick = false + if (i + 1 < name.length && name(i + 1) != '.') throw e + } else { + tmp += char + } + } else { + if (char == '`') { + if (tmp.nonEmpty) throw e + inBacktick = true + } else if (char == '.') { + if (name(i - 1) == '.' 
|| i == name.length - 1) throw e + nameParts += tmp.mkString + tmp.clear() + } else { + tmp += char + } + } + i += 1 + } + if (inBacktick) throw e + nameParts += tmp.mkString + nameParts.toSeq + } } case class UnresolvedFunction( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 9b52f020093f0..c290e6acb361c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -179,47 +179,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(parseAttributeName(name), output, resolver) - } - - /** - * Internal method, used to split attribute name by dot with backticks rule. - * Backticks must appear in pairs, and the quoted string must be a complete name part, - * which means `ab..c`e.f is not allowed. - * Escape character is not supported now, so we can't use backtick inside name part. - */ - private def parseAttributeName(name: String): Seq[String] = { - val e = new AnalysisException(s"syntax error in attribute name: $name") - val nameParts = scala.collection.mutable.ArrayBuffer.empty[String] - val tmp = scala.collection.mutable.ArrayBuffer.empty[Char] - var inBacktick = false - var i = 0 - while (i < name.length) { - val char = name(i) - if (inBacktick) { - if (char == '`') { - inBacktick = false - if (i + 1 < name.length && name(i + 1) != '.') throw e - } else { - tmp += char - } - } else { - if (char == '`') { - if (tmp.nonEmpty) throw e - inBacktick = true - } else if (char == '.') { - if (name(i - 1) == '.' || i == name.length - 1) throw e - nameParts += tmp.mkString - tmp.clear() - } else { - tmp += char - } - } - i += 1 - } - if (inBacktick) throw e - nameParts += tmp.mkString - nameParts.toSeq + resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 75365fbcec757..27bd084847346 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -54,7 +54,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) - case _ => UnresolvedAttribute(name) + case _ => UnresolvedAttribute.quotedString(name) }) /** Creates a column based on the given expression. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 6f8ffb54402a7..075c0ea2544b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -343,7 +343,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit class StringToColumn(val sc: StringContext) { def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args : _*)) + new ColumnName(sc.s(args: _*)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index e1b3443d74993..6a09a3b72c081 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -32,6 +32,74 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { override def sqlContext(): SQLContext = ctx + test("column names with space") { + val df = Seq((1, "a")).toDF("name with space", "name.with.dot") + + checkAnswer( + df.select(df("name with space")), + Row(1) :: Nil) + + checkAnswer( + df.select($"name with space"), + Row(1) :: Nil) + + checkAnswer( + df.select(col("name with space")), + Row(1) :: Nil) + + checkAnswer( + df.select("name with space"), + Row(1) :: Nil) + + checkAnswer( + df.select(expr("`name with space`")), + Row(1) :: Nil) + } + + test("column names with dot") { + val df = Seq((1, "a")).toDF("name with space", "name.with.dot").as("a") + + checkAnswer( + df.select(df("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select($"`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(col("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select("`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(expr("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select(df("a.`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select($"a.`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(col("a.`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select("a.`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(expr("a.`name.with.dot`")), + Row("a") :: Nil) + } + test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") From 49b1504fe3733eb36a7fc6317ec19aeba5d46f97 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 17:36:12 -0700 Subject: [PATCH 205/340] Revert "[SPARK-9228] [SQL] use tungsten.enabled in public for both of codegen/unsafe" This reverts commit 4e70e8256ce2f45b438642372329eac7b1e9e8cf. 
--- docs/sql-programming-guide.md | 6 +++--- .../scala/org/apache/spark/sql/SQLConf.scala | 20 +++++++------------ .../spark/sql/execution/SparkPlan.scala | 8 +------- .../spark/sql/execution/joins/HashJoin.scala | 3 +-- .../sql/execution/joins/HashOuterJoin.scala | 2 +- .../sql/execution/joins/HashSemiJoin.scala | 3 +-- 6 files changed, 14 insertions(+), 28 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6c317175d3278..3ea77e82422fb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1884,11 +1884,11 @@ that these options will be deprecated in future release as more optimizations ar - spark.sql.tungsten.enabled + spark.sql.codegen true - When true, use the optimized Tungsten physical execution backend which explicitly manages memory - and dynamically generates bytecode for expression evaluation. + When true, code will be dynamically generated at runtime for expression evaluation in a specific + query. For some queries with complicated expression this option can lead to significant speed-ups. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ef35c133d9cc3..f836122b3e0e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -223,21 +223,14 @@ private[spark] object SQLConf { defaultValue = Some(200), doc = "The default number of partitions to use when shuffling data for joins or aggregations.") - val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled", - defaultValue = Some(true), - doc = "When true, use the optimized Tungsten physical execution backend which explicitly " + - "manages memory and dynamically generates bytecode for expression evaluation.") - val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", - defaultValue = Some(true), // use TUNGSTEN_ENABLED as default + defaultValue = Some(true), doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query.", - isPublic = false) + " a specific query.") val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(true), // use TUNGSTEN_ENABLED as default - doc = "When true, use the new optimized Tungsten physical execution backend.", - isPublic = false) + defaultValue = Some(true), + doc = "When true, use the new optimized Tungsten physical execution backend.") val DIALECT = stringConf( "spark.sql.dialect", @@ -434,6 +427,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). 
*/ + private[sql] class SQLConf extends Serializable with CatalystConf { import SQLConf._ @@ -480,11 +474,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 3fff79cd1b281..2f29067f5646a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -55,18 +55,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of codegenEnabled/unsafeEnabled will be set by the desserializer after the - // constructor has run. + // the value of codegenEnabled will be set by the desserializer after the constructor has run. val codegenEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.codegenEnabled } else { false } - val unsafeEnabled: Boolean = if (sqlContext != null) { - sqlContext.conf.unsafeEnabled - } else { - false - } /** * Whether the "prepare" method is called. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 22d46d1c3e3b7..5e9cd9fd2345a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,8 +44,7 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && self.unsafeEnabled - && UnsafeProjection.canSupport(buildKeys) + (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 701bd3cd86372..346337e64245c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -67,7 +67,7 @@ trait HashOuterJoin { } protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && self.unsafeEnabled && joinType != FullOuter + (self.codegenEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 82dd6eb7e7ed0..47a7d370f5415 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,8 +33,7 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output protected[this] def supportUnsafe: Boolean = { - (self.codegenEnabled && self.unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) + (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(rightKeys) && UnsafeProjection.canSupport(left.schema) && UnsafeProjection.canSupport(right.schema)) From b87825310ac87485672868bf6a9ed01d154a3626 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 6 Aug 2015 18:25:38 -0700 Subject: [PATCH 206/340] [SPARK-9692] Remove SqlNewHadoopRDD's generated Tuple2 and InterruptibleIterator. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A small performance optimization – we don't need to generate a Tuple2 and then immediately discard the key. We also don't need an extra wrapper from InterruptibleIterator. Author: Reynold Xin Closes #8000 from rxin/SPARK-9692 and squashes the following commits: 1d4d0b3 [Reynold Xin] [SPARK-9692] Remove SqlNewHadoopRDD's generated Tuple2 and InterruptibleIterator. 
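A minimal sketch of the shape change as seen by a caller. The method names below are illustrative, not identifiers from this patch: previously the Hadoop RDD produced (key, value) pairs and ParquetRelation immediately dropped the key; now values come out directly, saving a Tuple2 allocation and an extra iterator wrapper per record.

```scala
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// Before: records arrive as (key, value) pairs and the key is thrown away,
// paying for one Tuple2 per record plus an extra map stage.
def beforeShape[K, V: ClassTag](pairs: RDD[(K, V)]): RDD[V] = pairs.map(_._2)

// After: the RDD already yields values, so callers use it as-is.
def afterShape[V](values: RDD[V]): RDD[V] = values
```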
--- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 44 +++++++------------ .../spark/sql/parquet/ParquetRelation.scala | 3 +- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 35e44cb59c1be..6a95e44c57fec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -26,14 +26,12 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -60,18 +58,16 @@ private[spark] class SqlNewHadoopPartition( * and the executor side to the shared Hadoop Configuration. * * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with - * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be - * folded into core. + * changes based on [[org.apache.spark.rdd.HadoopRDD]]. */ -private[spark] class SqlNewHadoopRDD[K, V]( +private[spark] class SqlNewHadoopRDD[V: ClassTag]( @transient sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], + inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[(K, V)](sc, Nil) + extends RDD[V](sc, Nil) with SparkHadoopMapReduceUtil with Logging { @@ -120,8 +116,8 @@ private[spark] class SqlNewHadoopRDD[K, V]( override def compute( theSplit: SparkPartition, - context: TaskContext): InterruptibleIterator[(K, V)] = { - val iter = new Iterator[(K, V)] { + context: TaskContext): Iterator[V] = { + val iter = new Iterator[V] { val split = theSplit.asInstanceOf[SqlNewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = getConf(isDriverSide = false) @@ -154,17 +150,20 @@ private[spark] class SqlNewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - private var reader = format.createRecordReader( + private[this] var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) // Register an on-task-completion callback to close the input stream. 
context.addTaskCompletionListener(context => close()) - var havePair = false - var finished = false - var recordsSinceMetricsUpdate = 0 + + private[this] var havePair = false + private[this] var finished = false override def hasNext: Boolean = { + if (context.isInterrupted) { + throw new TaskKilledException + } if (!finished && !havePair) { finished = !reader.nextKeyValue if (finished) { @@ -178,7 +177,7 @@ private[spark] class SqlNewHadoopRDD[K, V]( !finished } - override def next(): (K, V) = { + override def next(): V = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -186,7 +185,7 @@ private[spark] class SqlNewHadoopRDD[K, V]( if (!finished) { inputMetrics.incRecordsRead(1) } - (reader.getCurrentKey, reader.getCurrentValue) + reader.getCurrentValue } private def close() { @@ -212,23 +211,14 @@ private[spark] class SqlNewHadoopRDD[K, V]( } } } catch { - case e: Exception => { + case e: Exception => if (!Utils.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } - } } } } - new InterruptibleIterator(context, iter) - } - - /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ - @DeveloperApi - def mapPartitionsWithInputSplit[U: ClassTag]( - f: (InputSplit, Iterator[(K, V)]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = { - new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + iter } override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index b4337a48dbd80..29c388c22ef93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -291,7 +291,6 @@ private[sql] class ParquetRelation( initDriverSideJobFuncOpt = Some(setInputPaths), initLocalJobFuncOpt = Some(initLocalJobFuncOpt), inputFormatClass = classOf[ParquetInputFormat[InternalRow]], - keyClass = classOf[Void], valueClass = classOf[InternalRow]) { val cacheMetadata = useMetadataCache @@ -328,7 +327,7 @@ private[sql] class ParquetRelation( new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } } - }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] + }.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] } } From 014a9f9d8c9521180f7a448cc7cc96cc00537d5c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 6 Aug 2015 19:04:57 -0700 Subject: [PATCH 207/340] [SPARK-9709] [SQL] Avoid starving unsafe operators that use sort The issue is that a task may run multiple sorts, and the sorts run by the child operator (i.e. parent RDD) may acquire all available memory such that other sorts in the same task do not have enough to proceed. This manifests itself in an `IOException("Unable to acquire X bytes of memory")` thrown by `UnsafeExternalSorter`. The solution is to reserve a page in each sorter in the chain before computing the child operator's (parent RDD's) partitions. This requires us to use a new special RDD that does some preparation before computing the parent's partitions. 
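A condensed sketch of the pattern this introduces; the full implementation, MapPartitionsWithPreparationRDD, appears in the diff below, and the class name PrepareFirstRDD here is purely illustrative. The key point is that the preparation closure runs before the parent iterator is touched, so the resource it reserves (for example a sorter's first page) cannot be starved by sorts run while computing the parent.

```scala
import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD

// Illustrative sketch only; the real class lives in org.apache.spark.rdd and uses
// firstParent instead of holding on to `prev` directly.
private class PrepareFirstRDD[U: ClassTag, T: ClassTag, M: ClassTag](
    prev: RDD[T],
    preparePartition: () => M,                                    // e.g. build a sorter and reserve its first page
    executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U])
  extends RDD[U](prev) {

  override def getPartitions: Array[Partition] = prev.partitions

  override def compute(split: Partition, context: TaskContext): Iterator[U] = {
    val prepared = preparePartition()                             // runs before the parent is computed
    val parentIterator = prev.iterator(split, context)            // only now is the child operator pulled
    executePartition(context, split.index, prepared, parentIterator)
  }
}
```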
Author: Andrew Or Closes #8011 from andrewor14/unsafe-starve-memory and squashes the following commits: 35b69a4 [Andrew Or] Simplify test 0b07782 [Andrew Or] Minor: update comments 5d5afdf [Andrew Or] Merge branch 'master' of github.com:apache/spark into unsafe-starve-memory 254032e [Andrew Or] Add tests 234acbd [Andrew Or] Reserve a page in sorter when preparing each partition b889e08 [Andrew Or] MapPartitionsWithPreparationRDD --- .../unsafe/sort/UnsafeExternalSorter.java | 43 ++++++++----- .../apache/spark/rdd/MapPartitionsRDD.scala | 3 + .../rdd/MapPartitionsWithPreparationRDD.scala | 49 +++++++++++++++ .../spark/shuffle/ShuffleMemoryManager.scala | 2 +- .../sort/UnsafeExternalSorterSuite.java | 19 +++++- ...MapPartitionsWithPreparationRDDSuite.scala | 60 +++++++++++++++++++ .../spark/sql/execution/SparkPlan.scala | 2 +- .../org/apache/spark/sql/execution/sort.scala | 28 +++++++-- 8 files changed, 184 insertions(+), 22 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 8f78fc5a41629..4c54ba4bce408 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -138,6 +138,11 @@ private UnsafeExternalSorter( this.inMemSorter = existingInMemorySorter; } + // Acquire a new page as soon as we construct the sorter to ensure that we have at + // least one page to work with. Otherwise, other operators in the same task may starve + // this sorter (SPARK-9709). + acquireNewPage(); + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). @@ -343,22 +348,32 @@ private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquired < pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquired); - spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquiredAfterSpilling != pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); - } - } - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); - currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = pageSizeBytes; - allocatedPages.add(currentPage); + acquireNewPage(); + } + } + } + + /** + * Acquire a new page from the {@link ShuffleMemoryManager}. + * + * If there is not enough space to allocate the new page, spill all existing ones + * and try again. If there is still not enough space, report error to the caller. 
+ */ + private void acquireNewPage() throws IOException { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = pageSizeBytes; + allocatedPages.add(currentPage); } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index a838aac6e8d1a..4312d3a417759 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -21,6 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} +/** + * An RDD that applies the provided function to every partition of the parent RDD. + */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala new file mode 100644 index 0000000000000..b475bd8d79f85 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, Partitioner, TaskContext} + +/** + * An RDD that applies a user provided function to every partition of the parent RDD, and + * additionally allows the user to prepare each partition before computing the parent partition. + */ +private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag]( + prev: RDD[T], + preparePartition: () => M, + executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner: Option[Partitioner] = { + if (preservesPartitioning) firstParent[T].partitioner else None + } + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + /** + * Prepare a partition before computing it from its parent. 
+ */ + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + val preparedArgument = preparePartition() + val parentIterator = firstParent[T].iterator(partition, context) + executePartition(context, partition.index, preparedArgument, parentIterator) + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 00c1e078a441c..e3d229cc99821 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -124,7 +124,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { } } -private object ShuffleMemoryManager { +private[spark] object ShuffleMemoryManager { /** * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction * of the memory pool and a safety factor since collections can sometimes grow bigger than diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 117745f9a9c00..f5300373d87ea 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -340,7 +340,8 @@ public void testPeakMemoryUsed() throws Exception { for (int i = 0; i < numRecordsPerPage * 10; i++) { insertNumber(sorter, i); newPeakMemory = sorter.getPeakMemoryUsedBytes(); - if (i % numRecordsPerPage == 0) { + // The first page is pre-allocated on instantiation + if (i % numRecordsPerPage == 0 && i > 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); } else { @@ -364,5 +365,21 @@ public void testPeakMemoryUsed() throws Exception { } } + @Test + public void testReservePageOnInstantiation() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + try { + assertEquals(1, sorter.getNumberOfAllocatedPages()); + // Inserting a new record doesn't allocate more memory since we already have a page + long peakMemory = sorter.getPeakMemoryUsedBytes(); + insertNumber(sorter, 100); + assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes()); + assertEquals(1, sorter.getNumberOfAllocatedPages()); + } finally { + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + } + } diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala new file mode 100644 index 0000000000000..c16930e7d6491 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.collection.mutable + +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext} + +class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext { + + test("prepare called before parent partition is computed") { + sc = new SparkContext("local", "test") + + // Have the parent partition push a number to the list + val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter => + TestObject.things.append(20) + iter + } + + // Push a different number during the prepare phase + val preparePartition = () => { TestObject.things.append(10) } + + // Push yet another number during the execution phase + val executePartition = ( + taskContext: TaskContext, + partitionIndex: Int, + notUsed: Unit, + parentIterator: Iterator[Int]) => { + TestObject.things.append(30) + TestObject.things.iterator + } + + // Verify that the numbers are pushed in the order expected + val result = { + new MapPartitionsWithPreparationRDD[Int, Int, Unit]( + parent, preparePartition, executePartition).collect() + } + assert(result === Array(10, 20, 30)) + } + +} + +private object TestObject { + val things = new mutable.ListBuffer[Int] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2f29067f5646a..490428965a61d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -158,7 +158,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ final def prepare(): Unit = { if (prepareCalled.compareAndSet(false, true)) { - doPrepare + doPrepare() children.foreach(_.prepare()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 3192b6ebe9075..7f69cdb08aa78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.{InternalAccumulator, TaskContext} -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -123,7 +123,12 @@ case class TungstenSort( val schema = child.schema val childOutput = child.output val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") - child.execute().mapPartitions({ iter => + + /** + * Set up the sorter in each partition before computing the parent partition. + * This makes sure our sorter is not starved by other sorters used in the same task. 
+ */ + def preparePartition(): UnsafeExternalRowSorter = { val ordering = newOrdering(sortOrder, childOutput) // The comparator for comparing prefix @@ -143,12 +148,25 @@ case class TungstenSort( if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) - val taskContext = TaskContext.get() + sorter + } + + /** Compute a partition using the sorter already set up previously. */ + def executePartition( + taskContext: TaskContext, + partitionIndex: Int, + sorter: UnsafeExternalRowSorter, + parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]]) taskContext.internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) sortedIterator - }, preservesPartitioning = true) + } + + // Note: we need to set up the external sorter in each partition before computing + // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709). + new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter]( + child.execute(), preparePartition, executePartition, preservesPartitioning = true) } } From 17284db314f52bdb2065482b8a49656f7683d30a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 17:30:31 -0700 Subject: [PATCH 208/340] [SPARK-9228] [SQL] use tungsten.enabled in public for both of codegen/unsafe spark.sql.tungsten.enabled will be the default value for both codegen and unsafe, they are kept internally for debug/testing. cc marmbrus rxin Author: Davies Liu Closes #7998 from davies/tungsten and squashes the following commits: c1c16da [Davies Liu] update doc 1a47be1 [Davies Liu] use tungsten.enabled for both of codegen/unsafe (cherry picked from commit 4e70e8256ce2f45b438642372329eac7b1e9e8cf) Signed-off-by: Reynold Xin --- docs/sql-programming-guide.md | 6 +++--- .../scala/org/apache/spark/sql/SQLConf.scala | 20 ++++++++++++------- .../spark/sql/execution/SparkPlan.scala | 8 +++++++- .../spark/sql/execution/joins/HashJoin.scala | 3 ++- .../sql/execution/joins/HashOuterJoin.scala | 2 +- .../sql/execution/joins/HashSemiJoin.scala | 3 ++- 6 files changed, 28 insertions(+), 14 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3ea77e82422fb..6c317175d3278 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1884,11 +1884,11 @@ that these options will be deprecated in future release as more optimizations ar - spark.sql.codegen + spark.sql.tungsten.enabled true - When true, code will be dynamically generated at runtime for expression evaluation in a specific - query. For some queries with complicated expression this option can lead to significant speed-ups. + When true, use the optimized Tungsten physical execution backend which explicitly manages memory + and dynamically generates bytecode for expression evaluation. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f836122b3e0e4..ef35c133d9cc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -223,14 +223,21 @@ private[spark] object SQLConf { defaultValue = Some(200), doc = "The default number of partitions to use when shuffling data for joins or aggregations.") - val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled", defaultValue = Some(true), + doc = "When true, use the optimized Tungsten physical execution backend which explicitly " + + "manages memory and dynamically generates bytecode for expression evaluation.") + + val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + defaultValue = Some(true), // use TUNGSTEN_ENABLED as default doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query.") + " a specific query.", + isPublic = false) val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(true), - doc = "When true, use the new optimized Tungsten physical execution backend.") + defaultValue = Some(true), // use TUNGSTEN_ENABLED as default + doc = "When true, use the new optimized Tungsten physical execution backend.", + isPublic = false) val DIALECT = stringConf( "spark.sql.dialect", @@ -427,7 +434,6 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ - private[sql] class SQLConf extends Serializable with CatalystConf { import SQLConf._ @@ -474,11 +480,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 490428965a61d..719ad432e2fe0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -55,12 +55,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of codegenEnabled will be set by the desserializer after the constructor has run. + // the value of codegenEnabled/unsafeEnabled will be set by the desserializer after the + // constructor has run. val codegenEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.codegenEnabled } else { false } + val unsafeEnabled: Boolean = if (sqlContext != null) { + sqlContext.conf.unsafeEnabled + } else { + false + } /** * Whether the "prepare" method is called. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 5e9cd9fd2345a..22d46d1c3e3b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,7 +44,8 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + (self.codegenEnabled && self.unsafeEnabled + && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 346337e64245c..701bd3cd86372 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -67,7 +67,7 @@ trait HashOuterJoin { } protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && joinType != FullOuter + (self.codegenEnabled && self.unsafeEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 47a7d370f5415..82dd6eb7e7ed0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,7 +33,8 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output protected[this] def supportUnsafe: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) + (self.codegenEnabled && self.unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(rightKeys) && UnsafeProjection.canSupport(left.schema) && UnsafeProjection.canSupport(right.schema)) From fe12277b40082585e40e1bdf6aa2ebcfe80ed83f Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 6 Aug 2015 21:03:47 -0700 Subject: [PATCH 209/340] Fix doc typo Straightforward fix on doc typo Author: Jeff Zhang Closes #8019 from zjffdu/master and squashes the following commits: aed6e64 [Jeff Zhang] Fix doc typo --- docs/tuning.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tuning.md b/docs/tuning.md index 572c7270e4999..6936912a6be54 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -240,7 +240,7 @@ worth optimizing. ## Data Locality Data locality can have a major impact on the performance of Spark jobs. If data and the code that -operates on it are together than computation tends to be fast. But if code and data are separated, +operates on it are together then computation tends to be fast. But if code and data are separated, one must move to the other. Typically it is faster to ship serialized code from place to place than a chunk of data because code size is much smaller than data. Spark builds its scheduling around this general principle of data locality. 
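Not part of this patch, but a small illustration of the knob most commonly tuned around this principle: how long the scheduler waits for a data-local slot before falling back to a less local one.

```scala
import org.apache.spark.SparkConf

// Illustrative only: spark.locality.wait defaults to 3s; raising it trades scheduling
// latency for better locality when input partitions are large relative to task runtime.
val conf = new SparkConf()
  .setAppName("locality-tuning")
  .set("spark.locality.wait", "6s")
```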
From 672f467668da1cf20895ee57652489c306120288 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 6 Aug 2015 21:42:42 -0700 Subject: [PATCH 210/340] [SPARK-8057][Core]Call TaskAttemptContext.getTaskAttemptID using Reflection Someone may use the Spark core jar in the maven repo with hadoop 1. SPARK-2075 has already resolved the compatibility issue to support it. But `SparkHadoopMapRedUtil.commitTask` broke it recently. This PR uses Reflection to call `TaskAttemptContext.getTaskAttemptID` to fix the compatibility issue. Author: zsxwing Closes #6599 from zsxwing/SPARK-8057 and squashes the following commits: f7a343c [zsxwing] Remove the redundant import 6b7f1af [zsxwing] Call TaskAttemptContext.getTaskAttemptID using Reflection --- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 14 ++++++++++++++ .../spark/mapred/SparkHadoopMapRedUtil.scala | 3 ++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e06b06e06fb4a..7e9dba42bebd8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -34,6 +34,8 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{TaskAttemptID => MapReduceTaskAttemptID} import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.annotation.DeveloperApi @@ -194,6 +196,18 @@ class SparkHadoopUtil extends Logging { method.invoke(context).asInstanceOf[Configuration] } + /** + * Using reflection to call `getTaskAttemptID` from TaskAttemptContext. If we directly + * call `TaskAttemptContext.getTaskAttemptID`, it will generate different byte codes + * for Hadoop 1.+ and Hadoop 2.+ because TaskAttemptContext is class in Hadoop 1.+ + * while it's interface in Hadoop 2.+. + */ + def getTaskAttemptIDFromTaskAttemptContext( + context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { + val method = context.getClass.getMethod("getTaskAttemptID") + method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] + } + /** * Get [[FileStatus]] objects for all leaf children (files) under the given base path. 
If the * given path points to a file, return a single-element collection containing [[FileStatus]] of diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 87df42748be44..f405b732e4725 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.util.{Utils => SparkUtils} @@ -93,7 +94,7 @@ object SparkHadoopMapRedUtil extends Logging { splitId: Int, attemptId: Int): Unit = { - val mrTaskAttemptID = mrTaskContext.getTaskAttemptID + val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) // Called after we have decided to commit def performCommit(): Unit = { From f0cda587fb80bf2f1ba53d35dc9dc87bf72ee338 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 6 Aug 2015 22:49:01 -0700 Subject: [PATCH 211/340] [SPARK-7550] [SQL] [MINOR] Fixes logs when persisting DataFrames Author: Cheng Lian Closes #8021 from liancheng/spark-7550/fix-logs and squashes the following commits: b7bd0ed [Cheng Lian] Fixes logs --- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 1523ebe9d5493..7198a32df4a02 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -317,19 +317,17 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => logWarning { - val paths = relation.paths.mkString(", ") "Persisting partitioned data source relation into Hive metastore in " + s"Spark SQL specific format, which is NOT compatible with Hive. Input path(s): " + - paths.mkString("\n", "\n", "") + relation.paths.mkString("\n", "\n", "") } newSparkSQLSpecificMetastoreTable() case (Some(serde), relation: HadoopFsRelation) => logWarning { - val paths = relation.paths.mkString(", ") "Persisting data source relation with multiple input paths into Hive metastore in " + s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + - paths.mkString("\n", "\n", "") + relation.paths.mkString("\n", "\n", "") } newSparkSQLSpecificMetastoreTable() From 7aaed1b114751a24835204b8c588533d5c5ffaf0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 6 Aug 2015 22:52:23 -0700 Subject: [PATCH 212/340] [SPARK-8862][SQL]Support multiple SQLContexts in Web UI This is a follow-up PR to solve the UI issue when there are multiple SQLContexts. Each SQLContext has a separate tab and contains queries which are executed by this SQLContext. 
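An illustrative sketch, not part of the patch, of what now produces two separate tabs: each SQLContext registers its own SQLTab, named "SQL", "SQL1", "SQL2", ... by the counter added below.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Setup code is illustrative; any application that creates two SQLContexts behaves the same.
val sc = new SparkContext(new SparkConf().setAppName("multi-sql-ui").setMaster("local[2]"))
val sqlContext1 = new SQLContext(sc)   // appears in the web UI as the "SQL" tab
val sqlContext2 = new SQLContext(sc)   // appears as the "SQL1" tab, showing only its own queries
```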
multiple sqlcontexts Author: zsxwing Closes #7962 from zsxwing/multi-sqlcontext-ui and squashes the following commits: cf661e1 [zsxwing] sql -> SQL 39b0c97 [zsxwing] Support multiple SQLContexts in Web UI --- .../org/apache/spark/sql/ui/AllExecutionsPage.scala | 2 +- .../main/scala/org/apache/spark/sql/ui/SQLTab.scala | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala index 727fc4b37fa48..cb7ca60b2fe48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala @@ -178,7 +178,7 @@ private[ui] abstract class ExecutionTable( "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) private def executionURL(executionID: Long): String = - "%s/sql/execution?id=%s".format(UIUtils.prependBaseUri(parent.basePath), executionID) + s"${UIUtils.prependBaseUri(parent.basePath)}/${parent.prefix}/execution?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala index a9e5226303978..3bba0afaf14eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.ui +import java.util.concurrent.atomic.AtomicInteger + import org.apache.spark.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) - extends SparkUITab(sparkUI, "sql") with Logging { - + extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { val parent = sparkUI val listener = sqlContext.listener @@ -38,4 +39,11 @@ private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) private[sql] object SQLTab { private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/ui/static" + + private val nextTabId = new AtomicInteger(0) + + private def nextTabName: String = { + val nextId = nextTabId.getAndIncrement() + if (nextId == 0) "SQL" else s"SQL${nextId}" + } } From 4309262ec9146d7158ee9957a128bb152289d557 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 6 Aug 2015 23:18:29 -0700 Subject: [PATCH 213/340] [SPARK-9700] Pick default page size more intelligently. Previously, we use 64MB as the default page size, which was way too big for a lot of Spark applications (especially for single node). This patch changes it so that the default page size, if unset by the user, is determined by the number of cores available and the total execution memory available. Author: Reynold Xin Closes #8012 from rxin/pagesize and squashes the following commits: 16f4756 [Reynold Xin] Fixed failing test. 5afd570 [Reynold Xin] private... 0d5fb98 [Reynold Xin] Update default value. 674a6cd [Reynold Xin] Address review feedback. dc00e05 [Reynold Xin] Merge with master. 73ebdb6 [Reynold Xin] [SPARK-9700] Pick default page size more intelligently. 
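A back-of-the-envelope sketch of the new default. The authoritative logic is ShuffleMemoryManager.getPageSize in the diff below; this standalone function just mirrors it for illustration: take the execution memory per core, divide by a safety factor of 8, round up to a power of two, and clamp to [1 MB, 64 MB].

```scala
// Mirrors the heuristic added in this patch; standalone sketch for illustration only.
def defaultPageSize(maxExecutionMemory: Long, numCores: Int): Long = {
  val minPageSize = 1L * 1024 * 1024          // 1 MB floor
  val maxPageSize = 64L * minPageSize         // 64 MB ceiling (the old fixed default)
  val safetyFactor = 8
  val raw = maxExecutionMemory / numCores / safetyFactor
  val highBit = java.lang.Long.highestOneBit(raw)
  val powerOf2 = if (highBit == raw) raw else highBit << 1   // next power of two >= raw
  math.min(maxPageSize, math.max(minPageSize, powerOf2))
}

// e.g. ~512 MB of shuffle memory on 4 cores now yields 16 MB pages instead of 64 MB.
assert(defaultPageSize(512L * 1024 * 1024, 4) == 16L * 1024 * 1024)
```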
--- R/run-tests.sh | 2 +- .../unsafe/UnsafeShuffleExternalSorter.java | 3 +- .../spark/unsafe/map/BytesToBytesMap.java | 8 +-- .../unsafe/sort/UnsafeExternalSorter.java | 1 - .../scala/org/apache/spark/SparkConf.scala | 7 +++ .../scala/org/apache/spark/SparkContext.scala | 2 +- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../spark/shuffle/ShuffleMemoryManager.scala | 53 +++++++++++++++++-- .../unsafe/UnsafeShuffleWriterSuite.java | 5 +- .../map/AbstractBytesToBytesMapSuite.java | 6 +-- .../sort/UnsafeExternalSorterSuite.java | 4 +- .../shuffle/ShuffleMemoryManagerSuite.scala | 14 ++--- python/pyspark/java_gateway.py | 1 - .../TungstenAggregationIterator.scala | 2 +- .../sql/execution/joins/HashedRelation.scala | 16 +++--- .../org/apache/spark/sql/execution/sort.scala | 4 +- ...ypes.scala => ParquetTypesConverter.scala} | 0 .../execution/TestShuffleMemoryManager.scala | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 1 - .../spark/unsafe/array/ByteArrayMethods.java | 6 +++ 20 files changed, 93 insertions(+), 46 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/parquet/{ParquetTypes.scala => ParquetTypesConverter.scala} (100%) diff --git a/R/run-tests.sh b/R/run-tests.sh index 18a1e13bdc655..e82ad0ba2cd06 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index bf4eaa59ff589..f6e0913a7a0b3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -115,8 +115,7 @@ public UnsafeShuffleExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.pageSizeBytes = (int) Math.min( - PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, - conf.getSizeAsBytes("spark.buffer.pageSize", "64m")); + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes()); this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5ac3736ac62aa..0636ae7c8df1a 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -642,7 +642,7 @@ public boolean putNewKey( private void allocate(int capacity) { assert (capacity >= 0); // The capacity needs to be divisible by 64 so that our bit set can be sized properly - capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64); + capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64); assert (capacity <= MAX_CAPACITY); longArray = new 
LongArray(MemoryBlock.fromLongArray(new long[capacity * 2])); bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); @@ -770,10 +770,4 @@ void growAndRehash() { timeSpentResizingNs += System.nanoTime() - resizeStartTime; } } - - /** Returns the next number greater or equal num that is power of 2. */ - private static long nextPowerOf2(long num) { - final long highBit = Long.highestOneBit(num); - return (highBit == num) ? num : highBit << 1; - } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 4c54ba4bce408..5ebbf9b068fd6 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -127,7 +127,6 @@ private UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.fileBufferSizeBytes = 32 * 1024; - // this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m"); this.pageSizeBytes = pageSizeBytes; this.writeMetrics = new ShuffleWriteMetrics(); diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 08bab4bf2739f..8ff154fb5e334 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -249,6 +249,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.byteStringAsBytes(get(key, defaultValue)) } + /** + * Get a size parameter as bytes, falling back to a default if not set. + */ + def getSizeAsBytes(key: String, defaultValue: Long): Long = { + Utils.byteStringAsBytes(get(key, defaultValue + "B")) + } + /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0c0705325b169..5662686436900 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -629,7 +629,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * [[org.apache.spark.SparkContext.setLocalProperty]]. */ def getLocalProperty(key: String): String = - Option(localProperties.get).map(_.getProperty(key)).getOrElse(null) + Option(localProperties.get).map(_.getProperty(key)).orNull /** Set a human readable description of the current job. 
*/ def setJobDescription(value: String) { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index adfece4d6e7c0..a796e72850191 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -324,7 +324,7 @@ object SparkEnv extends Logging { val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val shuffleMemoryManager = ShuffleMemoryManager.create(conf, numUsableCores) val blockTransferService = conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index e3d229cc99821..8c3a72644c38a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -19,6 +19,9 @@ package org.apache.spark.shuffle import scala.collection.mutable +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** @@ -34,11 +37,19 @@ import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever * this set changes. This is all done by synchronizing access on "this" to mutate state and using * wait() and notifyAll() to signal changes. + * + * Use `ShuffleMemoryManager.create()` factory method to create a new instance. + * + * @param maxMemory total amount of memory available for execution, in bytes. + * @param pageSizeBytes number of bytes for each page, by default. */ -private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { - private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes +private[spark] +class ShuffleMemoryManager protected ( + val maxMemory: Long, + val pageSizeBytes: Long) + extends Logging { - def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes private def currentTaskAttemptId(): Long = { // In case this is called on the driver, return an invalid task attempt id. @@ -124,15 +135,49 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { } } + private[spark] object ShuffleMemoryManager { + + def create(conf: SparkConf, numCores: Int): ShuffleMemoryManager = { + val maxMemory = ShuffleMemoryManager.getMaxMemory(conf) + val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores) + new ShuffleMemoryManager(maxMemory, pageSize) + } + + def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = { + new ShuffleMemoryManager(maxMemory, pageSizeBytes) + } + + @VisibleForTesting + def createForTesting(maxMemory: Long): ShuffleMemoryManager = { + new ShuffleMemoryManager(maxMemory, 4 * 1024 * 1024) + } + /** * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction * of the memory pool and a safety factor since collections can sometimes grow bigger than * the size we target before we estimate their sizes again. 
*/ - def getMaxMemory(conf: SparkConf): Long = { + private def getMaxMemory(conf: SparkConf): Long = { val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } + + /** + * Sets the page size, in bytes. + * + * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value + * by looking at the number of cores available to the process, and the total amount of memory, + * and then divide it by a factor of safety. + */ + private def getPageSize(conf: SparkConf, maxMemory: Long, numCores: Int): Long = { + val minPageSize = 1L * 1024 * 1024 // 1MB + val maxPageSize = 64L * minPageSize // 64MB + val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() + val safetyFactor = 8 + val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor) + val default = math.min(maxPageSize, math.max(minPageSize, size)) + conf.getSizeAsBytes("spark.buffer.pageSize", default) + } } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 98c32bbc298d7..c68354ba49a46 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -115,6 +115,7 @@ public void setUp() throws IOException { taskMetrics = new TaskMetrics(); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( @@ -549,14 +550,14 @@ public void testPeakMemoryUsed() throws Exception { final long recordLengthBytes = 8; final long pageSizeBytes = 256; final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; - final SparkConf conf = new SparkConf().set("spark.buffer.pageSize", pageSizeBytes + "b"); + when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( blockManager, shuffleBlockResolver, taskMemoryManager, shuffleMemoryManager, - new UnsafeShuffleHandle(0, 1, shuffleDep), + new UnsafeShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 3c5003380162f..0b11562980b8e 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -48,7 +48,7 @@ public abstract class AbstractBytesToBytesMapSuite { @Before public void setup() { - shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); + shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, PAGE_SIZE_BYTES); taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); // Mocked memory manager for tests that check the maximum array size, since actually allocating // such large arrays will cause us to run out of memory in our tests. 
@@ -441,7 +441,7 @@ public void randomizedTestWithRecordsLargerThanPageSize() { @Test public void failureToAllocateFirstPage() { - shuffleMemoryManager = new ShuffleMemoryManager(1024); + shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024); BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); try { @@ -461,7 +461,7 @@ public void failureToAllocateFirstPage() { @Test public void failureToGrow() { - shuffleMemoryManager = new ShuffleMemoryManager(1024 * 10); + shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024 * 10); BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024); try { boolean success = true; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index f5300373d87ea..83049b8a21fcf 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -102,7 +102,7 @@ public void setUp() { MockitoAnnotations.initMocks(this); sparkConf = new SparkConf(); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); - shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); + shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, pageSizeBytes); spillFilesCreated.clear(); taskContext = mock(TaskContext.class); when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); @@ -237,7 +237,7 @@ public void testSortingEmptyArrays() throws Exception { @Test public void spillingOccursInResponseToMemoryPressure() throws Exception { - shuffleMemoryManager = new ShuffleMemoryManager(pageSizeBytes * 2); + shuffleMemoryManager = ShuffleMemoryManager.create(pageSizeBytes * 2, pageSizeBytes); final UnsafeExternalSorter sorter = newSorter(); final int numRecords = (int) pageSizeBytes / 4; for (int i = 0; i <= numRecords; i++) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index f495b6a037958..6d45b1a101be6 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -24,7 +24,7 @@ import org.mockito.Mockito._ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkFunSuite, TaskContext} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { @@ -50,7 +50,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } test("single task requesting memory") { - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) assert(manager.tryToAcquire(100L) === 100L) assert(manager.tryToAcquire(400L) === 400L) @@ -72,7 +72,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { // Two threads request 500 bytes first, wait for each other to get it, and then request // 500 more; we should immediately return 0 as both are now at 1 / N - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Result1 = -1L @@ -124,7 +124,7 @@ class 
ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { // Two tasks request 250 bytes first, wait for each other to get it, and then request // 500 more; we should only grant 250 bytes to each of them on this second request - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Result1 = -1L @@ -176,7 +176,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { // for a bit and releases 250 bytes, which should then be granted to t2. Further requests // by t2 will return false right away because it now has 1 / 2N of the memory. - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Requested = false @@ -241,7 +241,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases all its memory. t2 should now be able to grab all the memory. - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Requested = false @@ -307,7 +307,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } test("tasks should not be granted a negative size") { - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) manager.tryToAcquire(700L) val latch = new CountDownLatch(1) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 60be85e53e2aa..cd4c55f79f18c 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -54,7 +54,6 @@ def launch_gateway(): if os.environ.get("SPARK_TESTING"): submit_args = ' '.join([ "--conf spark.ui.enabled=false", - "--conf spark.buffer.pageSize=4mb", submit_args ]) command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index b9d44aace1009..4d5e98a3e90c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -342,7 +342,7 @@ class TungstenAggregationIterator( TaskContext.get.taskMemoryManager(), SparkEnv.get.shuffleMemoryManager, 1024 * 16, // initial capacity - SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), + SparkEnv.get.shuffleMemoryManager.pageSizeBytes, false // disable tracking of performance metrics ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 3f257ecdd156c..953abf409f220 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -282,17 +282,15 @@ private[joins] final class UnsafeHashedRelation( // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + val pageSizeBytes = Option(SparkEnv.get).map(_.shuffleMemoryManager.pageSizeBytes) + .getOrElse(new 
SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) + // Dummy shuffle memory manager which always grants all memory allocation requests. // We use this because it doesn't make sense count shared broadcast variables' memory usage // towards individual tasks' quotas. In the future, we should devise a better way of handling // this. - val shuffleMemoryManager = new ShuffleMemoryManager(new SparkConf()) { - override def tryToAcquire(numBytes: Long): Long = numBytes - override def release(numBytes: Long): Unit = {} - } - - val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - .getSizeAsBytes("spark.buffer.pageSize", "64m") + val shuffleMemoryManager = + ShuffleMemoryManager.create(maxMemory = Long.MaxValue, pageSizeBytes = pageSizeBytes) binaryMap = new BytesToBytesMap( taskMemoryManager, @@ -306,11 +304,11 @@ private[joins] final class UnsafeHashedRelation( while (i < nKeys) { val keySize = in.readInt() val valuesSize = in.readInt() - if (keySize > keyBuffer.size) { + if (keySize > keyBuffer.length) { keyBuffer = new Array[Byte](keySize) } in.readFully(keyBuffer, 0, keySize) - if (valuesSize > valuesBuffer.size) { + if (valuesSize > valuesBuffer.length) { valuesBuffer = new Array[Byte](valuesSize) } in.readFully(valuesBuffer, 0, valuesSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 7f69cdb08aa78..e316930470127 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -122,7 +122,7 @@ case class TungstenSort( protected override def doExecute(): RDD[InternalRow] = { val schema = child.schema val childOutput = child.output - val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") + val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes /** * Set up the sorter in each partition before computing the parent partition. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala rename to sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala index 53de2d0f0771f..48c3938ff87ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -22,7 +22,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager /** * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. 
*/ -class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue) { +class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1024 * 1024) { private var oom = false override def tryToAcquire(numBytes: Long): Long = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 167086db5bfe2..296cc5c5e0b04 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -52,7 +52,6 @@ object TestHive .set("spark.sql.test", "") .set("spark.sql.hive.metastore.barrierPrefixes", "org.apache.spark.sql.hive.execution.PairSerDe") - .set("spark.buffer.pageSize", "4m") // SPARK-8910 .set("spark.ui.enabled", "false"))) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index cf693d01a4f5b..70b81ce015ddc 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -25,6 +25,12 @@ private ByteArrayMethods() { // Private constructor, since this class only contains static methods. } + /** Returns the next number greater or equal num that is power of 2. */ + public static long nextPowerOf2(long num) { + final long highBit = Long.highestOneBit(num); + return (highBit == num) ? num : highBit << 1; + } + public static int roundNumberOfBytesToNearestWord(int numBytes) { int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { From 15bd6f338dff4bcab4a1a3a2c568655022e49c32 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 23:40:38 -0700 Subject: [PATCH 214/340] [SPARK-9453] [SQL] support records larger than page size in UnsafeShuffleExternalSorter This patch follows exactly #7891 (except testing) Author: Davies Liu Closes #8005 from davies/larger_record and squashes the following commits: f9c4aff [Davies Liu] address comments 9de5c72 [Davies Liu] support records larger than page size in UnsafeShuffleExternalSorter --- .../unsafe/UnsafeShuffleExternalSorter.java | 143 +++++++++++------- .../shuffle/unsafe/UnsafeShuffleWriter.java | 10 +- .../unsafe/UnsafeShuffleWriterSuite.java | 60 ++------ 3 files changed, 103 insertions(+), 110 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index f6e0913a7a0b3..925b60a145886 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -17,10 +17,10 @@ package org.apache.spark.shuffle.unsafe; +import javax.annotation.Nullable; import java.io.File; import java.io.IOException; import java.util.LinkedList; -import javax.annotation.Nullable; import scala.Tuple2; @@ -34,8 +34,11 @@ import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.storage.*; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.PlatformDependent; +import 
org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; @@ -68,7 +71,7 @@ final class UnsafeShuffleExternalSorter { private final int pageSizeBytes; @VisibleForTesting final int maxRecordSizeBytes; - private final TaskMemoryManager memoryManager; + private final TaskMemoryManager taskMemoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; @@ -91,7 +94,7 @@ final class UnsafeShuffleExternalSorter { private long peakMemoryUsedBytes; // These variables are reset after spilling: - @Nullable private UnsafeShuffleInMemorySorter sorter; + @Nullable private UnsafeShuffleInMemorySorter inMemSorter; @Nullable private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; @@ -105,7 +108,7 @@ public UnsafeShuffleExternalSorter( int numPartitions, SparkConf conf, ShuffleWriteMetrics writeMetrics) throws IOException { - this.memoryManager = memoryManager; + this.taskMemoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; this.taskContext = taskContext; @@ -133,7 +136,7 @@ private void initializeForWriting() throws IOException { throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } - this.sorter = new UnsafeShuffleInMemorySorter(initialSize); + this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize); } /** @@ -160,7 +163,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // This call performs the actual sort. final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = - sorter.getSortedIterator(); + inMemSorter.getSortedIterator(); // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this // after SPARK-5581 is fixed. @@ -206,8 +209,8 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); - final Object recordPage = memoryManager.getPage(recordPointer); - final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer); + final Object recordPage = taskMemoryManager.getPage(recordPointer); + final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { @@ -269,9 +272,9 @@ void spill() throws IOException { spills.size() > 1 ? " times" : " time"); writeSortedFile(false); - final long sorterMemoryUsage = sorter.getMemoryUsage(); - sorter = null; - shuffleMemoryManager.release(sorterMemoryUsage); + final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + shuffleMemoryManager.release(inMemSorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); @@ -283,7 +286,7 @@ private long getMemoryUsage() { for (MemoryBlock page : allocatedPages) { totalPageSize += page.size(); } - return ((sorter == null) ? 0 : sorter.getMemoryUsage()) + totalPageSize; + return ((inMemSorter == null) ? 
0 : inMemSorter.getMemoryUsage()) + totalPageSize; } private void updatePeakMemoryUsed() { @@ -305,7 +308,7 @@ private long freeMemory() { updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { - memoryManager.freePage(block); + taskMemoryManager.freePage(block); shuffleMemoryManager.release(block.size()); memoryFreed += block.size(); } @@ -319,54 +322,53 @@ private long freeMemory() { /** * Force all memory and spill files to be deleted; called by shuffle error-handling code. */ - public void cleanupAfterError() { + public void cleanupResources() { freeMemory(); for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { logger.error("Unable to delete spill file {}", spill.file.getPath()); } } - if (sorter != null) { - shuffleMemoryManager.release(sorter.getMemoryUsage()); - sorter = null; + if (inMemSorter != null) { + shuffleMemoryManager.release(inMemSorter.getMemoryUsage()); + inMemSorter = null; } } /** - * Checks whether there is enough space to insert a new record into the sorter. - * - * @param requiredSpace the required space in the data page, in bytes, including space for storing - * the record size. - - * @return true if the record can be inserted without requiring more allocations, false otherwise. - */ - private boolean haveSpaceForRecord(int requiredSpace) { - assert (requiredSpace > 0); - return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); - } - - /** - * Allocates more memory in order to insert an additional record. This will request additional - * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be - * obtained. - * - * @param requiredSpace the required space in the data page, in bytes, including space for storing - * the record size. + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. */ - private void allocateSpaceForRecord(int requiredSpace) throws IOException { - if (!sorter.hasSpaceForAnotherRecord()) { + private void growPointerArrayIfNecessary() throws IOException { + assert(inMemSorter != null); + if (!inMemSorter.hasSpaceForAnotherRecord()) { logger.debug("Attempting to expand sort pointer array"); - final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); if (memoryAcquired < memoryToGrowPointerArray) { shuffleMemoryManager.release(memoryAcquired); spill(); } else { - sorter.expandPointerArray(); + inMemSorter.expandPointerArray(); shuffleMemoryManager.release(oldPointerArrayMemoryUsage); } } + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. This must be less than or equal to the page size (records + * that exceed the page size are handled via a different code path which uses + * special overflow pages). 
+ */ + private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { + growPointerArrayIfNecessary(); if (requiredSpace > freeSpaceInCurrentPage) { logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, freeSpaceInCurrentPage); @@ -387,7 +389,7 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(pageSizeBytes); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); @@ -403,27 +405,58 @@ public void insertRecord( long recordBaseOffset, int lengthInBytes, int partitionId) throws IOException { + + growPointerArrayIfNecessary(); // Need 4 bytes to store the record length. final int totalSpaceRequired = lengthInBytes + 4; - if (!haveSpaceForRecord(totalSpaceRequired)) { - allocateSpaceForRecord(totalSpaceRequired); + + // --- Figure out where to insert the new record ---------------------------------------------- + + final MemoryBlock dataPage; + long dataPagePosition; + boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; + if (useOverflowPage) { + long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. + final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGranted != overflowPageSize) { + shuffleMemoryManager.release(memoryGranted); + spill(); + final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGrantedAfterSpill != overflowPageSize) { + shuffleMemoryManager.release(memoryGrantedAfterSpill); + throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); + } + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + allocatedPages.add(overflowPage); + dataPage = overflowPage; + dataPagePosition = overflowPage.getBaseOffset(); + } else { + // The record is small enough to fit in a regular data page, but the current page might not + // have enough space to hold it (or no pages have been allocated yet). 
+ acquireNewPageIfNecessary(totalSpaceRequired); + dataPage = currentPage; + dataPagePosition = currentPagePosition; + // Update bookkeeping information + freeSpaceInCurrentPage -= totalSpaceRequired; + currentPagePosition += totalSpaceRequired; } + final Object dataPageBaseObject = dataPage.getBaseObject(); final long recordAddress = - memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); - final Object dataPageBaseObject = currentPage.getBaseObject(); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); - currentPagePosition += 4; - freeSpaceInCurrentPage -= 4; + taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + dataPagePosition += 4; PlatformDependent.copyMemory( recordBaseObject, recordBaseOffset, dataPageBaseObject, - currentPagePosition, + dataPagePosition, lengthInBytes); - currentPagePosition += lengthInBytes; - freeSpaceInCurrentPage -= lengthInBytes; - sorter.insertRecord(recordAddress, partitionId); + assert(inMemSorter != null); + inMemSorter.insertRecord(recordAddress, partitionId); } /** @@ -435,14 +468,14 @@ public void insertRecord( */ public SpillInfo[] closeAndGetSpills() throws IOException { try { - if (sorter != null) { + if (inMemSorter != null) { // Do not count the final file towards the spill count. writeSortedFile(true); freeMemory(); } return spills.toArray(new SpillInfo[spills.size()]); } catch (IOException e) { - cleanupAfterError(); + cleanupResources(); throw e; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 6e2eeb37c86f1..02084f9122e00 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -17,17 +17,17 @@ package org.apache.spark.shuffle.unsafe; +import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; import java.util.Iterator; -import javax.annotation.Nullable; import scala.Option; import scala.Product2; import scala.collection.JavaConversions; +import scala.collection.immutable.Map; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; -import scala.collection.immutable.Map; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; @@ -38,10 +38,10 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.LZFCompressionCodec; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -178,7 +178,7 @@ public void write(scala.collection.Iterator> records) throws IOEx } finally { if (sorter != null) { try { - sorter.cleanupAfterError(); + sorter.cleanupResources(); } catch (Exception e) { // Only throw this error if we won't be masking another // error. 
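The core of the insertRecord change above is the record-placement rule: a record (plus its 4-byte length prefix) that cannot fit in a normal page gets a dedicated overflow page sized to the record, instead of failing the write. A minimal standalone sketch of that rule, not taken from the patch (Page, roundToWord and requestPage are stand-ins for Spark's MemoryBlock, ByteArrayMethods.roundNumberOfBytesToNearestWord and the acquire-spill-retry sequence):

object OverflowPageSketch {
  final case class Page(sizeInBytes: Long)

  // Round a byte count up to the next multiple of 8, mirroring the word alignment
  // applied to overflow pages.
  def roundToWord(numBytes: Long): Long = {
    val remainder = numBytes % 8
    if (remainder == 0) numBytes else numBytes + (8 - remainder)
  }

  // Decide which page a record of `lengthInBytes` should be copied into.
  // `requestPage` models "try to acquire, spill and retry if the first attempt fails".
  def placeRecord(
      lengthInBytes: Int,
      pageSizeBytes: Long,
      freeSpaceInCurrentPage: Long,
      currentPage: Page,
      requestPage: Long => Page): Page = {
    val required = lengthInBytes + 4L // 4 extra bytes hold the record length
    if (required > pageSizeBytes) {
      // Oversized record: give it its own word-aligned overflow page rather than rejecting it.
      requestPage(roundToWord(required))
    } else if (required > freeSpaceInCurrentPage) {
      // Regular record, but the current page cannot hold it: start a fresh page.
      requestPage(pageSizeBytes)
    } else {
      currentPage
    }
  }
}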
@@ -482,7 +482,7 @@ public Option stop(boolean success) { if (sorter != null) { // If sorter is non-null, then this implies that we called stop() in response to an error, // so we need to clean up memory and spill files created by the sorter - sorter.cleanupAfterError(); + sorter.cleanupResources(); } } } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index c68354ba49a46..94650be536b5f 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -475,62 +475,22 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception @Test public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { - // Use a custom serializer so that we have exact control over the size of serialized data. - final Serializer byteArraySerializer = new Serializer() { - @Override - public SerializerInstance newInstance() { - return new SerializerInstance() { - @Override - public SerializationStream serializeStream(final OutputStream s) { - return new SerializationStream() { - @Override - public void flush() { } - - @Override - public SerializationStream writeObject(T t, ClassTag ev1) { - byte[] bytes = (byte[]) t; - try { - s.write(bytes); - } catch (IOException e) { - throw new RuntimeException(e); - } - return this; - } - - @Override - public void close() { } - }; - } - public ByteBuffer serialize(T t, ClassTag ev1) { return null; } - public DeserializationStream deserializeStream(InputStream s) { return null; } - public T deserialize(ByteBuffer b, ClassLoader l, ClassTag ev1) { return null; } - public T deserialize(ByteBuffer bytes, ClassTag ev1) { return null; } - }; - } - }; - when(shuffleDep.serializer()).thenReturn(Option.apply(byteArraySerializer)); final UnsafeShuffleWriter writer = createWriter(false); - // Insert a record and force a spill so that there's something to clean up: - writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); - writer.forceSorterToSpill(); + final ArrayList> dataToWrite = new ArrayList>(); + dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(new byte[1]))); // We should be able to write a record that's right _at_ the max record size final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()]; new Random(42).nextBytes(atMaxRecordSize); - writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); - writer.forceSorterToSpill(); - // Inserting a record that's larger than the max record size should fail: + dataToWrite.add(new Tuple2(2, ByteBuffer.wrap(atMaxRecordSize))); + // Inserting a record that's larger than the max record size final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1]; new Random(42).nextBytes(exceedsMaxRecordSize); - Product2 hugeRecord = - new Tuple2(new byte[0], exceedsMaxRecordSize); - try { - // Here, we write through the public `write()` interface instead of the test-only - // `insertRecordIntoSorter` interface: - writer.write(Collections.singletonList(hugeRecord).iterator()); - fail("Expected exception to be thrown"); - } catch (IOException e) { - // Pass - } + dataToWrite.add(new Tuple2(3, ByteBuffer.wrap(exceedsMaxRecordSize))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); } From 
e57d6b56137bf3557efe5acea3ad390c1987b257 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Aug 2015 00:00:43 -0700 Subject: [PATCH 215/340] [SPARK-9683] [SQL] copy UTF8String when convert unsafe array/map to safe When we convert unsafe row to safe row, we will do copy if the column is struct or string type. However, the string inside unsafe array/map are not copied, which may cause problems. Author: Wenchen Fan Closes #7990 from cloud-fan/copy and squashes the following commits: c13d1e3 [Wenchen Fan] change test name fe36294 [Wenchen Fan] we should deep copy UTF8String when convert unsafe row to safe row --- .../sql/catalyst/expressions/FromUnsafe.scala | 3 ++ .../execution/RowFormatConvertersSuite.scala | 38 ++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala index 3caf0fb3410c4..9b960b136f984 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String case class FromUnsafe(child: Expression) extends UnaryExpression with ExpectsInputTypes with CodegenFallback { @@ -52,6 +53,8 @@ case class FromUnsafe(child: Expression) extends UnaryExpression } new GenericArrayData(result) + case StringType => value.asInstanceOf[UTF8String].clone() + case MapType(kt, vt, _) => val map = value.asInstanceOf[UnsafeMapData] val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 707cd9c6d939b..8208b25b5708c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.execution +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType} +import org.apache.spark.unsafe.types.UTF8String class RowFormatConvertersSuite extends SparkPlanTest { @@ -87,4 +91,36 @@ class RowFormatConvertersSuite extends SparkPlanTest { input.map(Row.fromTuple) ) } + + test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { + SparkPlan.currentContext.set(TestSQLContext) + val schema = ArrayType(StringType) + val rows = (1 to 100).map { i => + InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) + } + val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) + + val plan = + DummyPlan( + ConvertToSafe( + ConvertToUnsafe(relation))) + assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) + } +} + +case class DummyPlan(child: 
SparkPlan) extends UnaryNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + // cache all strings to make sure we have deep copied UTF8String inside incoming + // safe InternalRow. + val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] + iter.foreach { row => + strings += row.getArray(0).getUTF8String(0) + } + strings.map(InternalRow(_)).iterator + } + } + + override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) } From ebfd91c542aaead343cb154277fcf9114382fee7 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 7 Aug 2015 00:09:58 -0700 Subject: [PATCH 216/340] [SPARK-9467][SQL]Add SQLMetric to specialize accumulators to avoid boxing This PR adds SQLMetric/SQLMetricParam/SQLMetricValue to specialize accumulators to avoid boxing. All SQL metrics should use these classes rather than `Accumulator`. Author: zsxwing Closes #7996 from zsxwing/sql-accu and squashes the following commits: 14a5f0a [zsxwing] Address comments 367ca23 [zsxwing] Use localValue directly to avoid changing Accumulable 42f50c3 [zsxwing] Add SQLMetric to specialize accumulators to avoid boxing --- .../scala/org/apache/spark/Accumulators.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 15 -- .../spark/sql/execution/SparkPlan.scala | 33 ++-- .../spark/sql/execution/basicOperators.scala | 11 +- .../apache/spark/sql/metric/SQLMetrics.scala | 149 ++++++++++++++++++ .../org/apache/spark/sql/ui/SQLListener.scala | 17 +- .../apache/spark/sql/ui/SparkPlanGraph.scala | 8 +- .../spark/sql/metric/SQLMetricsSuite.scala | 145 +++++++++++++++++ .../spark/sql/ui/SQLListenerSuite.scala | 5 +- 9 files changed, 338 insertions(+), 47 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 462d5c96d480b..064246dfa7fc3 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -257,7 +257,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa */ class Accumulator[T] private[spark] ( @transient private[spark] val initialValue: T, - private[spark] val param: AccumulatorParam[T], + param: AccumulatorParam[T], name: Option[String], internal: Boolean) extends Accumulable[T, T](initialValue, param, name, internal) { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5662686436900..9ced44131b0d9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1238,21 +1238,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli acc } - /** - * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display - * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the - * driver can access the accumulator's `value`. The latest local value of such accumulator will be - * sent back to the driver via heartbeats. 
- * - * @tparam T type that can be added to the accumulator, must be thread safe - */ - private[spark] def internalAccumulator[T](initialValue: T, name: String)( - implicit param: AccumulatorParam[T]): Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name), internal = true) - cleaner.foreach(_.registerAccumulatorForCleanup(acc)) - acc - } - /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values * with `+=`. Only the driver can access the accumuable's `value`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 719ad432e2fe0..1915496d16205 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Accumulator, Logging} +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.SQLContext @@ -32,6 +32,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.metric.{IntSQLMetric, LongSQLMetric, SQLMetric, SQLMetrics} import org.apache.spark.sql.types.DataType object SparkPlan { @@ -84,22 +85,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ protected[sql] def trackNumOfRowsEnabled: Boolean = false - private lazy val numOfRowsAccumulator = sparkContext.internalAccumulator(0L, "number of rows") + private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] = + if (trackNumOfRowsEnabled) { + Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + } + else { + Map.empty + } /** - * Return all accumulators containing metrics of this SparkPlan. + * Return all metrics containing metrics of this SparkPlan. */ - private[sql] def accumulators: Map[String, Accumulator[_]] = if (trackNumOfRowsEnabled) { - Map("numRows" -> numOfRowsAccumulator) - } else { - Map.empty - } + private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics + + /** + * Return a IntSQLMetric according to the name. + */ + private[sql] def intMetric(name: String): IntSQLMetric = + metrics(name).asInstanceOf[IntSQLMetric] /** - * Return the accumulator according to the name. + * Return a LongSQLMetric according to the name. */ - private[sql] def accumulator[T](name: String): Accumulator[T] = - accumulators(name).asInstanceOf[Accumulator[T]] + private[sql] def longMetric(name: String): LongSQLMetric = + metrics(name).asInstanceOf[LongSQLMetric] // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. 
*/ @@ -148,7 +157,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() if (trackNumOfRowsEnabled) { - val numRows = accumulator[Long]("numRows") + val numRows = longMetric("numRows") doExecute().map { row => numRows += 1 row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f4677b4ee86bb..0680f31d40f6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator @@ -81,13 +82,13 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - private[sql] override lazy val accumulators = Map( - "numInputRows" -> sparkContext.internalAccumulator(0L, "number of input rows"), - "numOutputRows" -> sparkContext.internalAccumulator(0L, "number of output rows")) + private[sql] override lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val numInputRows = accumulator[Long]("numInputRows") - val numOutputRows = accumulator[Long]("numOutputRows") + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitions { iter => val predicate = newPredicate(condition, child.output) iter.filter { row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala new file mode 100644 index 0000000000000..3b907e5da7897 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala @@ -0,0 +1,149 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.metric + +import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} + +/** + * Create a layer for specialized metric. 
We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + * + * An implementation of SQLMetric should override `+=` and `add` to avoid boxing. + */ +private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( + name: String, val param: SQLMetricParam[R, T]) + extends Accumulable[R, T](param.zero, param, Some(name), true) + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + */ +private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] { + + def zero: R +} + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + */ +private[sql] trait SQLMetricValue[T] extends Serializable { + + def value: T + + override def toString: String = value.toString +} + +/** + * A wrapper of Long to avoid boxing and unboxing when using Accumulator + */ +private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] { + + def add(incr: Long): LongSQLMetricValue = { + _value += incr + this + } + + // Although there is a boxing here, it's fine because it's only called in SQLListener + override def value: Long = _value +} + +/** + * A wrapper of Int to avoid boxing and unboxing when using Accumulator + */ +private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] { + + def add(term: Int): IntSQLMetricValue = { + _value += term + this + } + + // Although there is a boxing here, it's fine because it's only called in SQLListener + override def value: Int = _value +} + +/** + * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's + * `+=` and `add`. + */ +private[sql] class LongSQLMetric private[metric](name: String) + extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) { + + override def +=(term: Long): Unit = { + localValue.add(term) + } + + override def add(term: Long): Unit = { + localValue.add(term) + } +} + +/** + * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's + * `+=` and `add`. 
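 *
 * An illustrative use (not part of this patch; `sc` stands for a SparkContext):
 *
 *   val numRows = SQLMetrics.createIntMetric(sc, "number of rows")
 *   numRows += 1   // dispatches to the specialized add(Int), so the increment is not boxed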
+ */ +private[sql] class IntSQLMetric private[metric](name: String) + extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) { + + override def +=(term: Int): Unit = { + localValue.add(term) + } + + override def add(term: Int): Unit = { + localValue.add(term) + } +} + +private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] { + + override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) + + override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue = + r1.add(r2.value) + + override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero + + override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L) +} + +private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] { + + override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t) + + override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue = + r1.add(r2.value) + + override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero + + override def zero: IntSQLMetricValue = new IntSQLMetricValue(0) +} + +private[sql] object SQLMetrics { + + def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = { + val acc = new IntSQLMetric(name) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } + + def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { + val acc = new LongSQLMetric(name) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala index e7b1dd1ffac68..2fd4fc658d068 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala @@ -21,11 +21,12 @@ import scala.collection.mutable import com.google.common.annotations.VisibleForTesting -import org.apache.spark.{AccumulatorParam, JobExecutionStatus, Logging} +import org.apache.spark.{JobExecutionStatus, Logging} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { @@ -36,8 +37,6 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit // Old data in the following fields must be removed in "trimExecutionsIfNecessary". // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data - - // VisibleForTesting private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]() /** @@ -270,9 +269,10 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield { accumulatorUpdate } - }.filter { case (id, _) => executionUIData.accumulatorMetrics.keySet(id) } + }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => - executionUIData.accumulatorMetrics(accumulatorId).accumulatorParam) + executionUIData.accumulatorMetrics(accumulatorId).metricParam). 
+ mapValues(_.asInstanceOf[SQLMetricValue[_]].value) case None => // This execution has been dropped Map.empty @@ -281,10 +281,11 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit private def mergeAccumulatorUpdates( accumulatorUpdates: Seq[(Long, Any)], - paramFunc: Long => AccumulatorParam[Any]): Map[Long, Any] = { + paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = { accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) => val param = paramFunc(accumulatorId) - (accumulatorId, values.map(_._2).reduceLeft(param.addInPlace)) + (accumulatorId, + values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace)) } } @@ -336,7 +337,7 @@ private[ui] class SQLExecutionUIData( private[ui] case class SQLPlanMetric( name: String, accumulatorId: Long, - accumulatorParam: AccumulatorParam[Any]) + metricParam: SQLMetricParam[SQLMetricValue[Any], Any]) /** * Store all accumulatorUpdates for all tasks in a Spark stage. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala index 7910c163ba453..1ba50b95becc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.AccumulatorParam import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} /** * A graph used for storing information of an executionPlan of DataFrame. @@ -61,9 +61,9 @@ private[sql] object SparkPlanGraph { nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = plan.accumulators.toSeq.map { case (key, accumulator) => - SQLPlanMetric(accumulator.name.getOrElse(key), accumulator.id, - accumulator.param.asInstanceOf[AccumulatorParam[Any]]) + val metrics = plan.metrics.toSeq.map { case (key, metric) => + SQLPlanMetric(metric.name.getOrElse(key), metric.id, + metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) } val node = SparkPlanGraphNode( nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala new file mode 100644 index 0000000000000..d22160f5384f4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala @@ -0,0 +1,145 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.metric + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.collection.mutable + +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.util.Utils + + +class SQLMetricsSuite extends SparkFunSuite { + + test("LongSQLMetric should not box Long") { + val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + } + + test("IntSQLMetric should not box Int") { + val l = SQLMetrics.createIntMetric(TestSQLContext.sparkContext, "Int") + val f = () => { l += 1 } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + } + + test("Normal accumulator should do boxing") { + // We need this test to make sure BoxingFinder works. + val l = TestSQLContext.sparkContext.accumulator(0L) + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") + } + } +} + +private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) + +/** + * If `method` is null, search all methods of this class recursively to find if they do some boxing. + * If `method` is specified, only search this method of the class to speed up the searching. + * + * This method will skip the methods in `visitedMethods` to avoid potential infinite cycles. + */ +private class BoxingFinder( + method: MethodIdentifier[_] = null, + val boxingInvokes: mutable.Set[String] = mutable.Set.empty, + visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) + extends ClassVisitor(ASM4) { + + private val primitiveBoxingClassName = + Set("java/lang/Long", + "java/lang/Double", + "java/lang/Integer", + "java/lang/Float", + "java/lang/Short", + "java/lang/Character", + "java/lang/Byte", + "java/lang/Boolean") + + override def visitMethod( + access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): + MethodVisitor = { + if (method != null && (method.name != name || method.desc != desc)) { + // If method is specified, skip other methods. 
+ return new MethodVisitor(ASM4) {} + } + + new MethodVisitor(ASM4) { + override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + if (op == INVOKESPECIAL && name == "" || op == INVOKESTATIC && name == "valueOf") { + if (primitiveBoxingClassName.contains(owner)) { + // Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l) + boxingInvokes.add(s"$owner.$name") + } + } else { + // scalastyle:off classforname + val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false, + Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + val m = MethodIdentifier(classOfMethodOwner, name, desc) + if (!visitedMethods.contains(m)) { + // Keep track of visited methods to avoid potential infinite cycles + visitedMethods += m + BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl => + visitedMethods += m + cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) + } + } + } + } + } + } +} + +private object BoxingFinder { + + def getClassReader(cls: Class[_]): Option[ClassReader] = { + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + val baos = new ByteArrayOutputStream(128) + // Copy data over, before delegating to ClassReader - + // else we can run out of open file handles. + Utils.copyStream(resourceStream, baos, true) + // ASM4 doesn't support Java 8 classes, which requires ASM5. + // So if the class is ASM5 (E.g., java.lang.Long when using JDK8 runtime to run these codes), + // then ClassReader will throw IllegalArgumentException, + // However, since this is only for testing, it's safe to skip these classes. + try { + Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray))) + } catch { + case _: IllegalArgumentException => None + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala index f1fcaf59532b8..69a561e16aa17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala @@ -21,6 +21,7 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.sql.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution @@ -65,9 +66,9 @@ class SQLListenerSuite extends SparkFunSuite { speculative = false ) - private def createTaskMetrics(accumulatorUpdates: Map[Long, Any]): TaskMetrics = { + private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { val metrics = new TaskMetrics - metrics.setAccumulatorsUpdater(() => accumulatorUpdates) + metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_))) metrics.updateAccumulators() metrics } From 76eaa701833a2ff23b50147d70ced41e85719572 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Aug 2015 11:02:53 -0700 Subject: [PATCH 217/340] [SPARK-9674][SPARK-9667] Remove SparkSqlSerializer2 It is now subsumed by various Tungsten operators. Author: Reynold Xin Closes #7981 from rxin/SPARK-9674 and squashes the following commits: 144f96e [Reynold Xin] Re-enable test 58b7332 [Reynold Xin] Disable failing list. fb797e3 [Reynold Xin] Match all UDTs. be9f243 [Reynold Xin] Updated if. 
71fc99c [Reynold Xin] [SPARK-9674][SPARK-9667] Remove GeneratedAggregate & SparkSqlSerializer2. --- .../scala/org/apache/spark/sql/SQLConf.scala | 6 - .../apache/spark/sql/execution/Exchange.scala | 48 +- .../sql/execution/SparkSqlSerializer2.scala | 426 ------------------ .../execution/SparkSqlSerializer2Suite.scala | 221 --------- 4 files changed, 24 insertions(+), 677 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ef35c133d9cc3..45d3d8c863512 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -416,10 +416,6 @@ private[spark] object SQLConf { val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", defaultValue = Some(true), doc = "") - val USE_SQL_SERIALIZER2 = booleanConf( - "spark.sql.useSerializer2", - defaultValue = Some(true), isPublic = false) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -488,8 +484,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) - private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6ea5eeedf1bbe..60087f2ca4a3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} @@ -39,21 +40,34 @@ import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEn @DeveloperApi case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { - override def outputPartitioning: Partitioning = newPartitioning - - override def output: Seq[Attribute] = child.output - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def nodeName: String = if (tungstenMode) "TungstenExchange" else "Exchange" - override def canProcessSafeRows: Boolean = true - - override def canProcessUnsafeRows: Boolean = { + /** + * Returns true iff the children outputs aggregate UDTs that are not part of the SQL type. + * This only happens with the old aggregate implementation and should be removed in 1.6. + */ + private lazy val tungstenMode: Boolean = { + val unserializableUDT = child.schema.exists(_.dataType match { + case _: UserDefinedType[_] => true + case _ => false + }) // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to // ClassCastExceptions at runtime. 
This check can be removed after SPARK-9054 is fixed. - !newPartitioning.isInstanceOf[RangePartitioning] + !unserializableUDT && !newPartitioning.isInstanceOf[RangePartitioning] } + override def outputPartitioning: Partitioning = newPartitioning + + override def output: Seq[Attribute] = child.output + + // This setting is somewhat counterintuitive: + // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row, + // so the planner inserts a converter to convert data into UnsafeRow if needed. + override def outputsUnsafeRows: Boolean = tungstenMode + override def canProcessSafeRows: Boolean = !tungstenMode + override def canProcessUnsafeRows: Boolean = tungstenMode + /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -124,23 +138,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val serializer: Serializer = { val rowDataTypes = child.output.map(_.dataType).toArray - // It is true when there is no field that needs to be write out. - // For now, we will not use SparkSqlSerializer2 when noField is true. - val noField = rowDataTypes == null || rowDataTypes.length == 0 - - val useSqlSerializer2 = - child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. - SparkSqlSerializer2.support(rowDataTypes) && // The schema of row is supported. - !noField - - if (child.outputsUnsafeRows) { - logInfo("Using UnsafeRowSerializer.") + if (tungstenMode) { new UnsafeRowSerializer(child.output.size) - } else if (useSqlSerializer2) { - logInfo("Using SparkSqlSerializer2.") - new SparkSqlSerializer2(rowDataTypes) } else { - logInfo("Using SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala deleted file mode 100644 index e811f1de3e6dd..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ /dev/null @@ -1,426 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.io._ -import java.math.{BigDecimal, BigInteger} -import java.nio.ByteBuffer - -import scala.reflect.ClassTag - -import org.apache.spark.Logging -import org.apache.spark.serializer._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * The serialization stream for [[SparkSqlSerializer2]]. 
It assumes that the object passed in - * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the - * [[Product2]] are constructed based on their schemata. - * The benefit of this serialization stream is that compared with general-purpose serializers like - * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower - * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are: - * 1. It does not support complex types, i.e. Map, Array, and Struct. - * 2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when - * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because - * the objects passed in the serializer are not in the type of [[Product2]]. Also also see - * the comment of the `serializer` method in [[Exchange]] for more information on it. - */ -private[sql] class Serializer2SerializationStream( - rowSchema: Array[DataType], - out: OutputStream) - extends SerializationStream with Logging { - - private val rowOut = new DataOutputStream(new BufferedOutputStream(out)) - private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut) - - override def writeObject[T: ClassTag](t: T): SerializationStream = { - val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]] - writeKey(kv._1) - writeValue(kv._2) - - this - } - - override def writeKey[T: ClassTag](t: T): SerializationStream = { - // No-op. - this - } - - override def writeValue[T: ClassTag](t: T): SerializationStream = { - writeRowFunc(t.asInstanceOf[InternalRow]) - this - } - - def flush(): Unit = { - rowOut.flush() - } - - def close(): Unit = { - rowOut.close() - } -} - -/** - * The corresponding deserialization stream for [[Serializer2SerializationStream]]. - */ -private[sql] class Serializer2DeserializationStream( - rowSchema: Array[DataType], - in: InputStream) - extends DeserializationStream with Logging { - - private val rowIn = new DataInputStream(new BufferedInputStream(in)) - - private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = { - if (schema == null) { - () => null - } else { - // It is safe to reuse the mutable row. - val mutableRow = new SpecificMutableRow(schema) - () => mutableRow - } - } - - // Functions used to return rows for key and value. - private val getRow = rowGenerator(rowSchema) - // Functions used to read a serialized row from the InputStream and deserialize it. - private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn) - - override def readObject[T: ClassTag](): T = { - readValue() - } - - override def readKey[T: ClassTag](): T = { - null.asInstanceOf[T] // intentionally left blank. 
- } - - override def readValue[T: ClassTag](): T = { - readRowFunc(getRow()).asInstanceOf[T] - } - - override def close(): Unit = { - rowIn.close() - } -} - -private[sql] class SparkSqlSerializer2Instance( - rowSchema: Array[DataType]) - extends SerializerInstance { - - def serialize[T: ClassTag](t: T): ByteBuffer = - throw new UnsupportedOperationException("Not supported.") - - def deserialize[T: ClassTag](bytes: ByteBuffer): T = - throw new UnsupportedOperationException("Not supported.") - - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = - throw new UnsupportedOperationException("Not supported.") - - def serializeStream(s: OutputStream): SerializationStream = { - new Serializer2SerializationStream(rowSchema, s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new Serializer2DeserializationStream(rowSchema, s) - } -} - -/** - * SparkSqlSerializer2 is a special serializer that creates serialization function and - * deserialization function based on the schema of data. It assumes that values passed in - * are Rows. - */ -private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType]) - extends Serializer - with Logging - with Serializable{ - - def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema) - - override def supportsRelocationOfSerializedObjects: Boolean = { - // SparkSqlSerializer2 is stateless and writes no stream headers - true - } -} - -private[sql] object SparkSqlSerializer2 { - - final val NULL = 0 - final val NOT_NULL = 1 - - /** - * Check if rows with the given schema can be serialized with ShuffleSerializer. - * Right now, we do not support a schema having complex types or UDTs, or all data types - * of fields are NullTypes. - */ - def support(schema: Array[DataType]): Boolean = { - if (schema == null) return true - - var allNullTypes = true - var i = 0 - while (i < schema.length) { - schema(i) match { - case NullType => // Do nothing - case udt: UserDefinedType[_] => - allNullTypes = false - return false - case array: ArrayType => - allNullTypes = false - return false - case map: MapType => - allNullTypes = false - return false - case struct: StructType => - allNullTypes = false - return false - case _ => - allNullTypes = false - } - i += 1 - } - - // If types of fields are all NullTypes, we return false. - // Otherwise, we return true. - return !allNullTypes - } - - /** - * The util function to create the serialization function based on the given schema. - */ - def createSerializationFunction(schema: Array[DataType], out: DataOutputStream) - : InternalRow => Unit = { - (row: InternalRow) => - // If the schema is null, the returned function does nothing when it get called. - if (schema != null) { - var i = 0 - while (i < schema.length) { - schema(i) match { - // When we write values to the underlying stream, we also first write the null byte - // first. Then, if the value is not null, we write the contents out. - - case NullType => // Write nothing. 
- - case BooleanType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeBoolean(row.getBoolean(i)) - } - - case ByteType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeByte(row.getByte(i)) - } - - case ShortType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeShort(row.getShort(i)) - } - - case IntegerType | DateType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeInt(row.getInt(i)) - } - - case LongType | TimestampType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeLong(row.getLong(i)) - } - - case FloatType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeFloat(row.getFloat(i)) - } - - case DoubleType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeDouble(row.getDouble(i)) - } - - case StringType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val bytes = row.getUTF8String(i).getBytes - out.writeInt(bytes.length) - out.write(bytes) - } - - case BinaryType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val bytes = row.getBinary(i) - out.writeInt(bytes.length) - out.write(bytes) - } - - case decimal: DecimalType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val value = row.getDecimal(i, decimal.precision, decimal.scale) - val javaBigDecimal = value.toJavaBigDecimal - // First, write out the unscaled value. - val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray - out.writeInt(bytes.length) - out.write(bytes) - // Then, write out the scale. - out.writeInt(javaBigDecimal.scale()) - } - } - i += 1 - } - } - } - - /** - * The util function to create the deserialization function based on the given schema. - */ - def createDeserializationFunction( - schema: Array[DataType], - in: DataInputStream): (MutableRow) => InternalRow = { - if (schema == null) { - (mutableRow: MutableRow) => null - } else { - (mutableRow: MutableRow) => { - var i = 0 - while (i < schema.length) { - schema(i) match { - // When we read values from the underlying stream, we also first read the null byte - // first. Then, if the value is not null, we update the field of the mutable row. - - case NullType => mutableRow.setNullAt(i) // Read nothing. 
- - case BooleanType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setBoolean(i, in.readBoolean()) - } - - case ByteType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setByte(i, in.readByte()) - } - - case ShortType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setShort(i, in.readShort()) - } - - case IntegerType | DateType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setInt(i, in.readInt()) - } - - case LongType | TimestampType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setLong(i, in.readLong()) - } - - case FloatType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setFloat(i, in.readFloat()) - } - - case DoubleType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setDouble(i, in.readDouble()) - } - - case StringType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - mutableRow.update(i, UTF8String.fromBytes(bytes)) - } - - case BinaryType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - mutableRow.update(i, bytes) - } - - case decimal: DecimalType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - // First, read in the unscaled value. - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - val unscaledVal = new BigInteger(bytes) - // Then, read the scale. - val scale = in.readInt() - // Finally, create the Decimal object and set it in the row. - mutableRow.update(i, - Decimal(new BigDecimal(unscaledVal, scale), decimal.precision, decimal.scale)) - } - } - i += 1 - } - - mutableRow - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala deleted file mode 100644 index 7978ed57a937e..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import java.sql.{Timestamp, Date} - -import org.apache.spark.sql.test.TestSQLContext -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.serializer.Serializer -import org.apache.spark.{ShuffleDependency, SparkFunSuite} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row -import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} - -class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { - // Make sure that we will not use serializer2 for unsupported data types. - def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { - val testName = - s"${if (dataType == null) null else dataType.toString} is " + - s"${if (isSupported) "supported" else "unsupported"}" - - test(testName) { - assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) - } - } - - checkSupported(null, isSupported = true) - checkSupported(BooleanType, isSupported = true) - checkSupported(ByteType, isSupported = true) - checkSupported(ShortType, isSupported = true) - checkSupported(IntegerType, isSupported = true) - checkSupported(LongType, isSupported = true) - checkSupported(FloatType, isSupported = true) - checkSupported(DoubleType, isSupported = true) - checkSupported(DateType, isSupported = true) - checkSupported(TimestampType, isSupported = true) - checkSupported(StringType, isSupported = true) - checkSupported(BinaryType, isSupported = true) - checkSupported(DecimalType(10, 5), isSupported = true) - checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true) - - // If NullType is the only data type in the schema, we do not support it. - checkSupported(NullType, isSupported = false) - // For now, ArrayType, MapType, and StructType are not supported. - checkSupported(ArrayType(DoubleType, true), isSupported = false) - checkSupported(ArrayType(StringType, false), isSupported = false) - checkSupported(MapType(IntegerType, StringType, true), isSupported = false) - checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) - checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) - // UDTs are not supported right now. - checkSupported(new MyDenseVectorUDT, isSupported = false) -} - -abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { - var allColumns: String = _ - val serializerClass: Class[Serializer] = - classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] - var numShufflePartitions: Int = _ - var useSerializer2: Boolean = _ - - protected lazy val ctx = TestSQLContext - - override def beforeAll(): Unit = { - numShufflePartitions = ctx.conf.numShufflePartitions - useSerializer2 = ctx.conf.useSqlSerializer2 - - ctx.sql("set spark.sql.useSerializer2=true") - - val supportedTypes = - Seq(StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), - DateType, TimestampType) - - val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, true) - } - allColumns = fields.map(_.name).mkString(",") - val schema = StructType(fields) - - // Create a RDD with all data types supported by SparkSqlSerializer2. 
- val rdd = - ctx.sparkContext.parallelize((1 to 1000), 10).map { i => - Row( - s"str${i}: test serializer2.", - s"binary${i}: test serializer2.".getBytes("UTF-8"), - null, - i % 2 == 0, - i.toByte, - i.toShort, - i, - Long.MaxValue - i.toLong, - (i + 0.25).toFloat, - (i + 0.75), - BigDecimal(Long.MaxValue.toString + ".12345"), - new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), - new Date(i), - new Timestamp(i)) - } - - ctx.createDataFrame(rdd, schema).registerTempTable("shuffle") - - super.beforeAll() - } - - override def afterAll(): Unit = { - ctx.dropTempTable("shuffle") - ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") - ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2") - super.afterAll() - } - - def checkSerializer[T <: Serializer]( - executedPlan: SparkPlan, - expectedSerializerClass: Class[T]): Unit = { - executedPlan.foreach { - case exchange: Exchange => - val shuffledRDD = exchange.execute() - val dependency = shuffledRDD.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - val serializerNotSetMessage = - s"Expected $expectedSerializerClass as the serializer of Exchange. " + - s"However, the serializer was not set." - val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) - val isExpectedSerializer = - serializer.getClass == expectedSerializerClass || - serializer.getClass == classOf[UnsafeRowSerializer] - val wrongSerializerErrorMessage = - s"Expected ${expectedSerializerClass.getCanonicalName} or " + - s"${classOf[UnsafeRowSerializer].getCanonicalName}. But " + - s"${serializer.getClass.getCanonicalName} is used." - assert(isExpectedSerializer, wrongSerializerErrorMessage) - case _ => // Ignore other nodes. - } - } - - test("key schema and value schema are not nulls") { - val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - ctx.table("shuffle").collect()) - } - - test("key schema is null") { - val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") - val df = ctx.sql(s"SELECT $aggregations FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) - } - - test("value schema is null") { - val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - assert(df.map(r => r.getString(0)).collect().toSeq === - ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) - } - - test("no map output field") { - val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - } - - test("types of fields are all NullTypes") { - // Test range partitioning code path. - val nulls = ctx.sql(s"SELECT null as a, null as b, null as c") - val df = nulls.unionAll(nulls).sort("a") - checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - checkAnswer( - df, - Row(null, null, null) :: Row(null, null, null) :: Nil) - - // Test hash partitioning code path. - val oneRow = ctx.sql(s"SELECT DISTINCT null, null, null FROM shuffle") - checkSerializer(oneRow.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - checkAnswer( - oneRow, - Row(null, null, null)) - } -} - -/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. 
*/ -class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { - override def beforeAll(): Unit = { - super.beforeAll() - // Sort merge will not be triggered. - val bypassMergeThreshold = - ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") - } -} - -/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ -class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { - - override def beforeAll(): Unit = { - super.beforeAll() - // To trigger the sort merge. - val bypassMergeThreshold = - ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") - } -} From 2432c2e239f66049a7a7d7e0591204abcc993f1a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Aug 2015 11:28:43 -0700 Subject: [PATCH 218/340] [SPARK-8382] [SQL] Improve Analysis Unit test framework Author: Wenchen Fan Closes #8025 from cloud-fan/analysis and squashes the following commits: 51461b1 [Wenchen Fan] move test file to test folder ec88ace [Wenchen Fan] Improve Analysis Unit test framework --- .../analysis/AnalysisErrorSuite.scala | 48 +++++----------- .../sql/catalyst/analysis/AnalysisSuite.scala | 55 +------------------ .../sql/catalyst/analysis/AnalysisTest.scala | 33 +---------- .../sql/catalyst/analysis/TestRelations.scala | 51 +++++++++++++++++ .../BooleanSimplificationSuite.scala | 19 ++++--- 5 files changed, 79 insertions(+), 127 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 26935c6e3b24f..63b475b6366c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -42,8 +42,8 @@ case class UnresolvedTestPlan() extends LeafNode { override def output: Seq[Attribute] = Nil } -class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { - import AnalysisSuite._ +class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter { + import TestRelations._ def errorTest( name: String, @@ -51,15 +51,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { errorMessages: Seq[String], caseSensitive: Boolean = true): Unit = { test(name) { - val error = intercept[AnalysisException] { - if (caseSensitive) { - caseSensitiveAnalyze(plan) - } else { - caseInsensitiveAnalyze(plan) - } - } - - errorMessages.foreach(m => assert(error.getMessage.toLowerCase.contains(m.toLowerCase))) + assertAnalysisError(plan, errorMessages, caseSensitive) } } @@ -69,21 +61,21 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" :: - "'null' is of date type" ::Nil) + "'null' is of date type" :: Nil) errorTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" :: - "'null' is of date type" ::Nil) + 
"'null' is of date type" :: Nil) errorTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: - "requires int type" :: "'null' is of date type" ::Nil) + "requires int type" :: "'null' is of date type" :: Nil) errorTest( "unresolved window function", @@ -169,11 +161,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { assert(plan.resolved) - val message = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - }.getMessage - - assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil) } test("error test for self-join") { @@ -194,10 +182,8 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - val error = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - } - assert(error.message.contains("binary type expression a cannot be used in grouping expression")) + assertAnalysisError(plan, + "binary type expression a cannot be used in grouping expression" :: Nil) val plan2 = Aggregate( @@ -207,10 +193,8 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - val error2 = intercept[AnalysisException] { - caseSensitiveAnalyze(plan2) - } - assert(error2.message.contains("map type expression a cannot be used in grouping expression")) + assertAnalysisError(plan2, + "map type expression a cannot be used in grouping expression" :: Nil) } test("Join can't work on binary and map types") { @@ -226,10 +210,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("c", BinaryType)(exprId = ExprId(4))))) - val error = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - } - assert(error.message.contains("binary type expression a cannot be used in join conditions")) + assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil) val plan2 = Join( @@ -243,9 +224,6 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4))))) - val error2 = intercept[AnalysisException] { - caseSensitiveAnalyze(plan2) - } - assert(error2.message.contains("map type expression a cannot be used in join conditions")) + assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 221b4e92f086c..c944bc69e25b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -24,61 +24,8 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -// todo: remove this and use AnalysisTest instead. 
-object AnalysisSuite { - val caseSensitiveConf = new SimpleCatalystConf(true) - val caseInsensitiveConf = new SimpleCatalystConf(false) - - val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) - val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) - - val caseSensitiveAnalyzer = - new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { - override val extendedResolutionRules = EliminateSubQueries :: Nil - } - val caseInsensitiveAnalyzer = - new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { - override val extendedResolutionRules = EliminateSubQueries :: Nil - } - - def caseSensitiveAnalyze(plan: LogicalPlan): Unit = - caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer.execute(plan)) - - def caseInsensitiveAnalyze(plan: LogicalPlan): Unit = - caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer.execute(plan)) - - val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) - val testRelation2 = LocalRelation( - AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType(10, 2))(), - AttributeReference("e", ShortType)()) - - val nestedRelation = LocalRelation( - AttributeReference("top", StructType( - StructField("duplicateField", StringType) :: - StructField("duplicateField", StringType) :: - StructField("differentCase", StringType) :: - StructField("differentcase", StringType) :: Nil - ))()) - - val nestedRelation2 = LocalRelation( - AttributeReference("top", StructType( - StructField("aField", StringType) :: - StructField("bField", StringType) :: - StructField("cField", StringType) :: Nil - ))()) - - val listRelation = LocalRelation( - AttributeReference("list", ArrayType(IntegerType))()) - - caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) -} - - class AnalysisSuite extends AnalysisTest { + import TestRelations._ test("union project *") { val plan = (1 to 100) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index fdb4f28950daf..ee1f8f54251e0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -17,40 +17,11 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.SimpleCatalystConf -import org.apache.spark.sql.types._ trait AnalysisTest extends PlanTest { - val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) - - val testRelation2 = LocalRelation( - AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType(10, 2))(), - AttributeReference("e", ShortType)()) - - val nestedRelation = LocalRelation( - AttributeReference("top", StructType( - StructField("duplicateField", StringType) :: - StructField("duplicateField", StringType) :: - StructField("differentCase", StringType) :: - StructField("differentcase", StringType) :: Nil - ))()) - - val 
nestedRelation2 = LocalRelation( - AttributeReference("top", StructType( - StructField("aField", StringType) :: - StructField("bField", StringType) :: - StructField("cField", StringType) :: Nil - ))()) - - val listRelation = LocalRelation( - AttributeReference("list", ArrayType(IntegerType))()) val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = { val caseSensitiveConf = new SimpleCatalystConf(true) @@ -59,8 +30,8 @@ trait AnalysisTest extends PlanTest { val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) - caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseSensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation) new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { override val extendedResolutionRules = EliminateSubQueries :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala new file mode 100644 index 0000000000000..05b870705e7ea --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types._ + +object TestRelations { + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType(10, 2))(), + AttributeReference("e", ShortType)()) + + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index d4916ea8d273a..1877cff1334bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{AnalysisSuite, EliminateSubQueries} +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -88,20 +89,24 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) } - private def caseInsensitiveAnalyse(plan: LogicalPlan) = - AnalysisSuite.caseInsensitiveAnalyzer.execute(plan) + private val caseInsensitiveAnalyzer = + new Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(false)) test("(a && b) || (a && c) => a && (b || c) when case insensitive") { - val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5))) + val plan = caseInsensitiveAnalyzer.execute( + testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5))) val actual = Optimize.execute(plan) - val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 && ('b > 3 || 'b < 5))) + val expected = caseInsensitiveAnalyzer.execute( + testRelation.where('a > 2 && ('b > 3 || 'b < 5))) comparePlans(actual, expected) } test("(a || b) && (a || c) => a || (b && c) when case insensitive") { - val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5))) + val plan = caseInsensitiveAnalyzer.execute( + testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5))) val actual = Optimize.execute(plan) - val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 || ('b > 3 && 'b < 5))) + val expected = caseInsensitiveAnalyzer.execute( + testRelation.where('a > 2 || ('b > 3 && 'b < 5))) comparePlans(actual, expected) } } From 
9897cc5e3d6c70f7e45e887e2c6fc24dfa1adada Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Aug 2015 11:29:13 -0700 Subject: [PATCH 219/340] [SPARK-9736] [SQL] JoinedRow.anyNull should delegate to the underlying rows. JoinedRow.anyNull currently loops through every field to check for null, which is inefficient if the underlying rows are UnsafeRows. It should just delegate to the underlying implementation. Author: Reynold Xin Closes #8027 from rxin/SPARK-9736 and squashes the following commits: 03a2e92 [Reynold Xin] Include all files. 90f1add [Reynold Xin] [SPARK-9736][SQL] JoinedRow.anyNull should delegate to the underlying rows. --- .../spark/sql/catalyst/InternalRow.scala | 10 +- .../sql/catalyst/expressions/JoinedRow.scala | 144 ++++++++++++++++++ .../sql/catalyst/expressions/Projection.scala | 119 --------------- .../spark/sql/catalyst/expressions/rows.scala | 12 +- 4 files changed, 156 insertions(+), 129 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 85b4bf3b6aef5..eba95c5c8b908 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -37,15 +37,7 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def copy(): InternalRow /** Returns true if there are any NULL values in this row. */ - def anyNull: Boolean = { - val len = numFields - var i = 0 - while (i < len) { - if (isNullAt(i)) { return true } - i += 1 - } - false - } + def anyNull: Boolean /* ---------------------- utility methods for Scala ---------------------- */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala new file mode 100644 index 0000000000000..b76757c93523d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + + +/** + * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to + * be instantiated once per thread and reused. 
+ */ +class JoinedRow extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ + + def this(left: InternalRow, right: InternalRow) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: InternalRow): InternalRow = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. */ + def withRight(newRight: InternalRow): InternalRow = { + row2 = newRight + this + } + + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + assert(fieldTypes.length == row1.numFields + row2.numFields) + val (left, right) = fieldTypes.splitAt(row1.numFields) + row1.toSeq(left) ++ row2.toSeq(right) + } + + override def numFields: Int = row1.numFields + row2.numFields + + override def get(i: Int, dt: DataType): AnyRef = + if (i < row1.numFields) row1.get(i, dt) else row2.get(i - row1.numFields, dt) + + override def isNullAt(i: Int): Boolean = + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) + + override def getBoolean(i: Int): Boolean = + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) + + override def getByte(i: Int): Byte = + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) + + override def getShort(i: Int): Short = + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) + + override def getInt(i: Int): Int = + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) + + override def getLong(i: Int): Long = + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) + + override def getFloat(i: Int): Float = + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) + + override def getDouble(i: Int): Double = + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) + + override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { + if (i < row1.numFields) { + row1.getDecimal(i, precision, scale) + } else { + row2.getDecimal(i - row1.numFields, precision, scale) + } + } + + override def getUTF8String(i: Int): UTF8String = + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) + + override def getBinary(i: Int): Array[Byte] = + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) + + override def getArray(i: Int): ArrayData = + if (i < row1.numFields) row1.getArray(i) else row2.getArray(i - row1.numFields) + + override def getInterval(i: Int): CalendarInterval = + if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields) + + override def getMap(i: Int): MapData = + if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields) + + override def getStruct(i: Int, numFields: Int): InternalRow = { + if (i < row1.numFields) { + row1.getStruct(i, numFields) + } else { + row2.getStruct(i - row1.numFields, numFields) + } + } + + override def anyNull: Boolean = row1.anyNull || row2.anyNull + + override def copy(): InternalRow = { + val copy1 = row1.copy() + val copy2 = row2.copy() + new JoinedRow(copy1, copy2) + } + + override def toString: String = { + // Make sure toString never throws NullPointerException. 
+ if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.toString + } else if (row2 eq null) { + row1.toString + } else { + s"{${row1.toString} + ${row2.toString}}" + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 59ce7fc4f2c63..796bc327a3db1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -169,122 +169,3 @@ object FromUnsafeProjection { GenerateSafeProjection.generate(exprs) } } - -/** - * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to - * be instantiated once per thread and reused. - */ -class JoinedRow extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { - assert(fieldTypes.length == row1.numFields + row2.numFields) - val (left, right) = fieldTypes.splitAt(row1.numFields) - row1.toSeq(left) ++ row2.toSeq(right) - } - - override def numFields: Int = row1.numFields + row2.numFields - - override def get(i: Int, dt: DataType): AnyRef = - if (i < row1.numFields) row1.get(i, dt) else row2.get(i - row1.numFields, dt) - - override def isNullAt(i: Int): Boolean = - if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) - - override def getBoolean(i: Int): Boolean = - if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) - - override def getByte(i: Int): Byte = - if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) - - override def getShort(i: Int): Short = - if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) - - override def getInt(i: Int): Int = - if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) - - override def getLong(i: Int): Long = - if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) - - override def getFloat(i: Int): Float = - if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) - - override def getDouble(i: Int): Double = - if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) - - override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { - if (i < row1.numFields) { - row1.getDecimal(i, precision, scale) - } else { - row2.getDecimal(i - row1.numFields, precision, scale) - } - } - - override def getUTF8String(i: Int): UTF8String = - if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) - - override def getBinary(i: Int): Array[Byte] = - if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) - - override def 
getArray(i: Int): ArrayData = - if (i < row1.numFields) row1.getArray(i) else row2.getArray(i - row1.numFields) - - override def getInterval(i: Int): CalendarInterval = - if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields) - - override def getMap(i: Int): MapData = - if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields) - - override def getStruct(i: Int, numFields: Int): InternalRow = { - if (i < row1.numFields) { - row1.getStruct(i, numFields) - } else { - row2.getStruct(i - row1.numFields, numFields) - } - } - - override def copy(): InternalRow = { - val copy1 = row1.copy() - val copy2 = row2.copy() - new JoinedRow(copy1, copy2) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. - if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.toString - } else if (row2 eq null) { - row1.toString - } else { - s"{${row1.toString} + ${row2.toString}}" - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 11d10b2d8a48b..017efd2a166a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -49,7 +49,17 @@ trait BaseGenericInternalRow extends InternalRow { override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - override def toString(): String = { + override def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } + + override def toString: String = { if (numFields == 0) { "[empty row]" } else { From aeddeafc03d77a5149d2c8f9489b0ca83e6b3e03 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Aug 2015 13:26:03 -0700 Subject: [PATCH 220/340] [SPARK-9667][SQL] followup: Use GenerateUnsafeProjection.canSupport to test Exchange supported data types. This way we recursively test the data types. cc chenghao-intel Author: Reynold Xin Closes #8036 from rxin/cansupport and squashes the following commits: f7302ff [Reynold Xin] Can GenerateUnsafeProjection.canSupport to test Exchange supported data types. 
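As a rough illustration of the recursive type check this change relies on: the sketch below is not part of the patch and does not reproduce GenerateUnsafeProjection.canSupport; the object name, the helper method, and the exact list of accepted types are assumptions made for the example. It only shows the idea of walking a schema recursively so that an unsupported type nested inside a struct, array, or map is still detected.

```
// Illustrative sketch only: mimics the idea of recursively testing whether a
// schema can be encoded as UnsafeRows. Names and the supported-type list are
// assumptions for this example, not the actual GenerateUnsafeProjection code.
import org.apache.spark.sql.types._

object UnsafeSupportSketch {
  // True only if every field (including nested ones) uses an encodable type.
  def canSupport(schema: StructType): Boolean =
    schema.fields.forall(field => supported(field.dataType))

  private def supported(dt: DataType): Boolean = dt match {
    case BooleanType | ByteType | ShortType | IntegerType | LongType |
         FloatType | DoubleType | DateType | TimestampType |
         StringType | BinaryType => true
    case _: DecimalType => true
    // Recurse into nested types so an unsupported type buried inside a
    // struct, array, or map is still caught.
    case StructType(fields) => fields.forall(f => supported(f.dataType))
    case ArrayType(elementType, _) => supported(elementType)
    case MapType(keyType, valueType, _) => supported(keyType) && supported(valueType)
    case _ => false // e.g. user-defined types fall back to the safe-row path
  }
}
```

In the patch itself this schema test is additionally combined with `!newPartitioning.isInstanceOf[RangePartitioning]`, so range-partitioned exchanges stay on the safe-row path.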
--- .../org/apache/spark/sql/execution/Exchange.scala | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 60087f2ca4a3e..49bb729800863 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} @@ -43,18 +43,11 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def nodeName: String = if (tungstenMode) "TungstenExchange" else "Exchange" /** - * Returns true iff the children outputs aggregate UDTs that are not part of the SQL type. - * This only happens with the old aggregate implementation and should be removed in 1.6. + * Returns true iff we can support the data type, and we are not doing range partitioning. */ private lazy val tungstenMode: Boolean = { - val unserializableUDT = child.schema.exists(_.dataType match { - case _: UserDefinedType[_] => true - case _ => false - }) - // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to - // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to - // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed. - !unserializableUDT && !newPartitioning.isInstanceOf[RangePartitioning] + GenerateUnsafeProjection.canSupport(child.schema) && + !newPartitioning.isInstanceOf[RangePartitioning] } override def outputPartitioning: Partitioning = newPartitioning From 05d04e10a8ea030bea840c3c5ba93ecac479a039 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Aug 2015 13:41:45 -0700 Subject: [PATCH 221/340] [SPARK-9733][SQL] Improve physical plan explain for data sources All data sources show up as "PhysicalRDD" in physical plan explain. It'd be better if we can show the name of the data source. 
Without this patch: ``` == Physical Plan == NewAggregate with UnsafeHybridAggregationIterator ArrayBuffer(date#0, cat#1) ArrayBuffer((sum(CAST((CAST(count#2, IntegerType) + 1), LongType))2,mode=Final,isDistinct=false)) Exchange hashpartitioning(date#0,cat#1) NewAggregate with UnsafeHybridAggregationIterator ArrayBuffer(date#0, cat#1) ArrayBuffer((sum(CAST((CAST(count#2, IntegerType) + 1), LongType))2,mode=Partial,isDistinct=false)) PhysicalRDD [date#0,cat#1,count#2], MapPartitionsRDD[3] at ``` With this patch: ``` == Physical Plan == TungstenAggregate(key=[date#0,cat#1], value=[(sum(CAST((CAST(count#2, IntegerType) + 1), LongType)),mode=Final,isDistinct=false)] Exchange hashpartitioning(date#0,cat#1) TungstenAggregate(key=[date#0,cat#1], value=[(sum(CAST((CAST(count#2, IntegerType) + 1), LongType)),mode=Partial,isDistinct=false)] ConvertToUnsafe Scan ParquetRelation[file:/scratch/rxin/spark/sales4][date#0,cat#1,count#2] ``` Author: Reynold Xin Closes #8024 from rxin/SPARK-9733 and squashes the following commits: 811b90e [Reynold Xin] Fixed Python test case. 52cab77 [Reynold Xin] Cast. eea9ccc [Reynold Xin] Fix test case. fcecb22 [Reynold Xin] [SPARK-9733][SQL] Improve explain message for data source scan node. --- python/pyspark/sql/dataframe.py | 4 +--- .../spark/sql/catalyst/expressions/Cast.scala | 4 ++-- .../expressions/aggregate/interfaces.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 4 ---- .../spark/sql/execution/ExistingRDD.scala | 15 ++++++++++++- .../spark/sql/execution/SparkStrategies.scala | 4 ++-- .../aggregate/TungstenAggregate.scala | 9 +++++--- .../datasources/DataSourceStrategy.scala | 22 +++++++++++++------ .../apache/spark/sql/sources/interfaces.scala | 2 +- .../execution/RowFormatConvertersSuite.scala | 4 ++-- .../sql/hive/execution/HiveExplainSuite.scala | 2 +- 11 files changed, 45 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0f3480c239187..47d5a6a43a84d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -212,8 +212,7 @@ def explain(self, extended=False): :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. >>> df.explain() - PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at applySchemaToPythonRDD at\ - NativeMethodAccessorImpl.java:... + Scan PhysicalRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == @@ -224,7 +223,6 @@ def explain(self, extended=False): ... == Physical Plan == ... 
- == RDD == """ if extended: print(self._jdf.queryExecution().toString()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 39f99700c8a26..946c5a9c04f14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -107,6 +107,8 @@ object Cast { case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with CodegenFallback { + override def toString: String = s"cast($child as ${dataType.simpleString})" + override def checkInputDataTypes(): TypeCheckResult = { if (Cast.canCast(child.dataType, dataType)) { TypeCheckResult.TypeCheckSuccess @@ -118,8 +120,6 @@ case class Cast(child: Expression, dataType: DataType) override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable - override def toString: String = s"CAST($child, $dataType)" - // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 4abfdfe87d5e9..576d8c7a3a68a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -93,7 +93,7 @@ private[sql] case class AggregateExpression2( AttributeSet(childReferences) } - override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" + override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" } abstract class AggregateFunction2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 075c0ea2544b2..832572571cabd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1011,9 +1011,6 @@ class SQLContext(@transient val sparkContext: SparkContext) def output = analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") - // TODO previously will output RDD details by run (${stringOrError(toRdd.toDebugString)}) - // however, the `toRdd` will cause the real execution, which is not what we want. - // We need to think about how to avoid the side effect. 
s"""== Parsed Logical Plan == |${stringOrError(logical)} |== Analyzed Logical Plan == @@ -1024,7 +1021,6 @@ class SQLContext(@transient val sparkContext: SparkContext) |== Physical Plan == |${stringOrError(executedPlan)} |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} - |== RDD == """.stripMargin.trim } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index fbaa8e276ddb7..cae7ca5cbdc88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} @@ -95,11 +96,23 @@ private[sql] case class LogicalRDD( /** Physical plan node for scanning data from an RDD. */ private[sql] case class PhysicalRDD( output: Seq[Attribute], - rdd: RDD[InternalRow]) extends LeafNode { + rdd: RDD[InternalRow], + extraInformation: String) extends LeafNode { override protected[sql] val trackNumOfRowsEnabled = true protected override def doExecute(): RDD[InternalRow] = rdd + + override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]") +} + +private[sql] object PhysicalRDD { + def createFromDataSource( + output: Seq[Attribute], + rdd: RDD[InternalRow], + relation: BaseRelation): PhysicalRDD = { + PhysicalRDD(output, rdd, relation.toString) + } } /** Logical plan node for scanning data from a local collection. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c5aaebe673225..c4b9b5acea4de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -363,12 +363,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Generate( generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => - execution.PhysicalRDD(Nil, singleRowRdd) :: Nil + execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil case logical.RepartitionByExpression(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil + case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil case BroadcastHint(child) => apply(child) case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 5a0b4d47d62f8..c3dcbd2b71ee8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -93,10 +93,13 @@ case class TungstenAggregate( val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions testFallbackStartsAt match { - case None => s"TungstenAggregate ${groupingExpressions} ${allAggregateExpressions}" + case None => + val keyString = groupingExpressions.mkString("[", ",", "]") + val valueString = allAggregateExpressions.mkString("[", ",", "]") + s"TungstenAggregate(key=$keyString, value=$valueString" case Some(fallbackStartsAt) => - s"TungstenAggregateWithControlledFallback ${groupingExpressions} " + - s"${allAggregateExpressions} fallbackStartsAt=$fallbackStartsAt" + s"TungstenAggregateWithControlledFallback $groupingExpressions " + + s"$allAggregateExpressions fallbackStartsAt=$fallbackStartsAt" } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index e5dc676b87841..5b5fa8c93ec52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -101,8 +101,9 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f, t.paths, confBroadcast))) :: Nil - case l @ LogicalRelation(t: TableScan) => - execution.PhysicalRDD(l.output, toCatalystRDD(l, t.buildScan())) :: Nil + case l @ LogicalRelation(baseRelation: TableScan) => + execution.PhysicalRDD.createFromDataSource( + l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => @@ -169,7 +170,10 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { new 
UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) } - execution.PhysicalRDD(projections.map(_.toAttribute), unionedRows) + execution.PhysicalRDD.createFromDataSource( + projections.map(_.toAttribute), + unionedRows, + logicalRelation.relation) } // TODO: refactor this thing. It is very complicated because it does projection internally. @@ -299,14 +303,18 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { projects.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. - val scan = execution.PhysicalRDD(projects.map(_.toAttribute), - scanBuilder(requestedColumns, pushedFilters)) + val scan = execution.PhysicalRDD.createFromDataSource( + projects.map(_.toAttribute), + scanBuilder(requestedColumns, pushedFilters), + relation.relation) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val scan = execution.PhysicalRDD(requestedColumns, - scanBuilder(requestedColumns, pushedFilters)) + val scan = execution.PhysicalRDD.createFromDataSource( + requestedColumns, + scanBuilder(requestedColumns, pushedFilters), + relation.relation) execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c04557e5a0818..0b2929661b657 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -383,7 +383,7 @@ private[sql] abstract class OutputWriterInternal extends OutputWriter { abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) extends BaseRelation with Logging { - logInfo("Constructing HadoopFsRelation") + override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") def this() = this(None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 8208b25b5708c..322966f423784 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -32,9 +32,9 @@ class RowFormatConvertersSuite extends SparkPlanTest { case c: ConvertToSafe => c } - private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 697211222b90c..8215dd6c2e711 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -36,7 +36,7 @@ class HiveExplainSuite extends QueryTest { "== Analyzed 
Logical Plan ==", "== Optimized Logical Plan ==", "== Physical Plan ==", - "Code Generation", "== RDD ==") + "Code Generation") } test("explain create table command") { From 881548ab20fa4c4b635c51d956b14bd13981e2f4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 7 Aug 2015 14:20:13 -0700 Subject: [PATCH 222/340] [SPARK-9674] Re-enable ignored test in SQLQuerySuite The original code that this test tests is removed in https://github.com/apache/spark/commit/9270bd06fd0b16892e3f37213b5bc7813ea11fdd. It was ignored shortly before that so we never caught it. This patch re-enables the test and adds the code necessary to make it pass. JoshRosen yhuai Author: Andrew Or Closes #8015 from andrewor14/SPARK-9674 and squashes the following commits: 225eac2 [Andrew Or] Merge branch 'master' of github.com:apache/spark into SPARK-9674 8c24209 [Andrew Or] Fix NPE e541d64 [Andrew Or] Track aggregation memory for both sort and hash 0be3a42 [Andrew Or] Fix test --- .../spark/unsafe/map/BytesToBytesMap.java | 37 +++++++++++++++++-- .../map/AbstractBytesToBytesMapSuite.java | 20 ++++++---- .../UnsafeFixedWidthAggregationMap.java | 7 ++-- .../sql/execution/UnsafeKVExternalSorter.java | 7 ++++ .../TungstenAggregationIterator.scala | 32 +++++++++++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 8 ++-- 6 files changed, 85 insertions(+), 26 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 0636ae7c8df1a..7f79cd13aab43 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -109,7 +109,7 @@ public final class BytesToBytesMap { * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. */ - private LongArray longArray; + @Nullable private LongArray longArray; // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode // and exploit word-alignment to use fewer bits to hold the address. This might let us store // only one long per map entry, increasing the chance that this array will fit in cache at the @@ -124,7 +124,7 @@ public final class BytesToBytesMap { * A {@link BitSet} used to track location of the map where the key is set. * Size of the bitset should be half of the size of the long array. 
*/ - private BitSet bitset; + @Nullable private BitSet bitset; private final double loadFactor; @@ -166,6 +166,8 @@ public final class BytesToBytesMap { private long numHashCollisions = 0; + private long peakMemoryUsedBytes = 0L; + public BytesToBytesMap( TaskMemoryManager taskMemoryManager, ShuffleMemoryManager shuffleMemoryManager, @@ -321,6 +323,9 @@ public Location lookup( Object keyBaseObject, long keyBaseOffset, int keyRowLengthBytes) { + assert(bitset != null); + assert(longArray != null); + if (enablePerfMetrics) { numKeyLookups++; } @@ -410,6 +415,7 @@ private void updateAddressesAndSizes(final Object page, final long offsetInPage) } private Location with(int pos, int keyHashcode, boolean isDefined) { + assert(longArray != null); this.pos = pos; this.isDefined = isDefined; this.keyHashcode = keyHashcode; @@ -525,6 +531,9 @@ public boolean putNewKey( assert (!isDefined) : "Can only set value once for a key"; assert (keyLengthBytes % 8 == 0); assert (valueLengthBytes % 8 == 0); + assert(bitset != null); + assert(longArray != null); + if (numElements == MAX_CAPACITY) { throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); } @@ -658,6 +667,7 @@ private void allocate(int capacity) { * This method is idempotent and can be called multiple times. */ public void free() { + updatePeakMemoryUsed(); longArray = null; bitset = null; Iterator dataPagesIterator = dataPages.iterator(); @@ -684,14 +694,30 @@ public long getPageSizeBytes() { /** * Returns the total amount of memory, in bytes, consumed by this map's managed structures. - * Note that this is also the peak memory used by this map, since the map is append-only. */ public long getTotalMemoryConsumption() { long totalDataPagesSize = 0L; for (MemoryBlock dataPage : dataPages) { totalDataPagesSize += dataPage.size(); } - return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size(); + return totalDataPagesSize + + ((bitset != null) ? bitset.memoryBlock().size() : 0L) + + ((longArray != null) ? longArray.memoryBlock().size() : 0L); + } + + private void updatePeakMemoryUsed() { + long mem = getTotalMemoryConsumption(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } /** @@ -731,6 +757,9 @@ int getNumDataPages() { */ @VisibleForTesting void growAndRehash() { + assert(bitset != null); + assert(longArray != null); + long resizeStartTime = -1; if (enablePerfMetrics) { resizeStartTime = System.nanoTime(); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 0b11562980b8e..e56a3f0b6d12c 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -525,7 +525,7 @@ public void resizingLargeMap() { } @Test - public void testTotalMemoryConsumption() { + public void testPeakMemoryUsed() { final long recordLengthBytes = 24; final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes; @@ -536,8 +536,8 @@ public void testTotalMemoryConsumption() { // monotonically increasing. More specifically, every time we allocate a new page it // should increase by exactly the size of the page. 
In this regard, the memory usage // at any given time is also the peak memory used. - long previousMemory = map.getTotalMemoryConsumption(); - long newMemory; + long previousPeakMemory = map.getPeakMemoryUsedBytes(); + long newPeakMemory; try { for (long i = 0; i < numRecordsPerPage * 10; i++) { final long[] value = new long[]{i}; @@ -548,15 +548,21 @@ public void testTotalMemoryConsumption() { value, PlatformDependent.LONG_ARRAY_OFFSET, 8); - newMemory = map.getTotalMemoryConsumption(); + newPeakMemory = map.getPeakMemoryUsedBytes(); if (i % numRecordsPerPage == 0) { // We allocated a new page for this record, so peak memory should change - assertEquals(previousMemory + pageSizeBytes, newMemory); + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); } else { - assertEquals(previousMemory, newMemory); + assertEquals(previousPeakMemory, newPeakMemory); } - previousMemory = newMemory; + previousPeakMemory = newPeakMemory; } + + // Freeing the map should not change the peak memory + map.free(); + newPeakMemory = map.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { map.free(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index efb33530dac86..b08a4a13a28be 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -210,11 +210,10 @@ public void close() { } /** - * The memory used by this map's managed structures, in bytes. - * Note that this is also the peak memory used by this map, since the map is append-only. + * Return the peak memory used so far, in bytes. */ - public long getMemoryUsage() { - return map.getTotalMemoryConsumption(); + public long getPeakMemoryUsedBytes() { + return map.getPeakMemoryUsedBytes(); } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 9a65c9d3a404a..69d6784713a24 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -159,6 +159,13 @@ public KVSorterIterator sortedIterator() throws IOException { } } + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + return sorter.getPeakMemoryUsedBytes(); + } + /** * Marks the current page as no-more-space-available, and as a result, either allocate a * new page or spill when we see the next record. 
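Both `BytesToBytesMap` and `UnsafeKVExternalSorter` now expose the same peak-memory pattern: keep a running maximum that is refreshed on every read of the metric and once more before the backing structures are released, so the value remains meaningful after `free()`. A minimal Scala sketch of that pattern, assuming a hypothetical `currentMemoryUsed()` measurement standing in for `getTotalMemoryConsumption()` (the class below is illustrative, not part of the patch):

```scala
// Illustrative only: a running-maximum tracker in the style of the changes above.
// PeakMemoryTracker and currentMemoryUsed() are stand-ins, not Spark classes.
class PeakMemoryTracker {
  private var peakMemoryUsedBytes: Long = 0L

  // Stand-in for the real measurement (e.g. summing data page sizes).
  private def currentMemoryUsed(): Long = 0L

  private def updatePeakMemoryUsed(): Unit = {
    val mem = currentMemoryUsed()
    if (mem > peakMemoryUsedBytes) {
      peakMemoryUsedBytes = mem
    }
  }

  /** Peak memory used so far, in bytes; stable even after free(). */
  def getPeakMemoryUsedBytes: Long = {
    updatePeakMemoryUsed()
    peakMemoryUsedBytes
  }

  /** Capture a final measurement before releasing the underlying buffers. */
  def free(): Unit = {
    updatePeakMemoryUsed()
    // release pages, drop references, etc.
  }
}
```

Taking the maximum of the map's and the sorter's peaks, as the `TungstenAggregationIterator` change below does, is safe because the map is destroyed before the sorter is created, so the two never hold memory at the same time.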
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 4d5e98a3e90c8..440bef32f4e9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner @@ -397,14 +397,20 @@ class TungstenAggregationIterator( private[this] var mapIteratorHasNext: Boolean = false /////////////////////////////////////////////////////////////////////////// - // Part 4: The function used to switch this iterator from hash-based - // aggregation to sort-based aggregation. + // Part 3: Methods and fields used by sort-based aggregation. /////////////////////////////////////////////////////////////////////////// + // This sorter is used for sort-based aggregation. It is initialized as soon as + // we switch from hash-based to sort-based aggregation. Otherwise, it is not used. + private[this] var externalSorter: UnsafeKVExternalSorter = null + + /** + * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. + */ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = { logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. - val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter() + externalSorter = hashMap.destructAndCreateExternalSorter() // Step 2: Free the memory used by the map. hashMap.free() @@ -601,7 +607,7 @@ class TungstenAggregationIterator( } /////////////////////////////////////////////////////////////////////////// - // Par 7: Iterator's public methods. + // Part 7: Iterator's public methods. /////////////////////////////////////////////////////////////////////////// override final def hasNext: Boolean = { @@ -610,7 +616,7 @@ class TungstenAggregationIterator( override final def next(): UnsafeRow = { if (hasNext) { - if (sortBased) { + val res = if (sortBased) { // Process the current group. processCurrentSortedGroup() // Generate output row for the current group. @@ -641,6 +647,19 @@ class TungstenAggregationIterator( result } } + + // If this is the last record, update the task's peak memory usage. Since we destroy + // the map to create the sorter, their memory usages should not overlap, so it is safe + // to just use the max of the two. 
+ if (!hasNext) { + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val peakMemory = Math.max(mapMemory, sorterMemory) + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) + } + + res } else { // no more result throw new NoSuchElementException @@ -651,6 +670,7 @@ class TungstenAggregationIterator( // Part 8: A utility function used to generate a output row when there is no // input and there is no grouping expression. /////////////////////////////////////////////////////////////////////////// + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { if (groupingExpressions.isEmpty) { sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c64aa7a07dc2b..b14ef9bab90cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -267,7 +267,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { if (!hasGeneratedAgg) { fail( s""" - |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |Codegen is enabled, but query $sqlText does not have TungstenAggregate in the plan. |${df.queryExecution.simpleString} """.stripMargin) } @@ -1602,10 +1602,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } - ignore("aggregation with codegen updates peak execution memory") { - withSQLConf( - (SQLConf.CODEGEN_ENABLED.key, "true"), - (SQLConf.USE_SQL_AGGREGATE2.key, "false")) { + test("aggregation with codegen updates peak execution memory") { + withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) { val sc = sqlContext.sparkContext AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") { testCodeGen( From e2fbbe73111d4624390f596a19a1799c86a05f6c Mon Sep 17 00:00:00 2001 From: Dariusz Kobylarz Date: Fri, 7 Aug 2015 14:51:03 -0700 Subject: [PATCH 223/340] [SPARK-8481] [MLLIB] GaussianMixtureModel predict accepting single vector Resubmit of [https://github.com/apache/spark/pull/6906] for adding single-vec predict to GMMs CC: dkobylarz mengxr To be merged with master and branch-1.5 Primary author: dkobylarz Author: Dariusz Kobylarz Closes #8039 from jkbradley/gmm-predict-vec and squashes the following commits: bfbedc4 [Dariusz Kobylarz] [SPARK-8481] [MLlib] GaussianMixtureModel predict accepting single vector --- .../mllib/clustering/GaussianMixtureModel.scala | 13 +++++++++++++ .../mllib/clustering/GaussianMixtureSuite.scala | 10 ++++++++++ 2 files changed, 23 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index cb807c8038101..76aeebd703d4e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -66,6 +66,12 @@ class GaussianMixtureModel( responsibilityMatrix.map(r => r.indexOf(r.max)) } + /** Maps given point to its cluster index. 
*/ + def predict(point: Vector): Int = { + val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + r.indexOf(r.max) + } + /** Java-friendly version of [[predict()]] */ def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] @@ -83,6 +89,13 @@ class GaussianMixtureModel( } } + /** + * Given the input vector, return the membership values to all mixture components. + */ + def predictSoft(point: Vector): Array[Double] = { + computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + } + /** * Compute the partial assignments for each vector */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index b218d72f1268a..b636d02f786e6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("model prediction, parallel and local") { + val data = sc.parallelize(GaussianTestData.data) + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + + val batchPredictions = gmm.predict(data) + batchPredictions.zip(data).collect().foreach { case (batchPred, datum) => + assert(batchPred === gmm.predict(datum)) + } + } + object GaussianTestData { val data = Array( From 902334fd55bbe40a57c1de2a9bdb25eddf1c8cf6 Mon Sep 17 00:00:00 2001 From: Bertrand Dechoux Date: Fri, 7 Aug 2015 16:07:24 -0700 Subject: [PATCH 224/340] [SPARK-9748] [MLLIB] Centriod typo in KMeansModel A minor typo (centriod -> centroid). Readable variable names help every users. Author: Bertrand Dechoux Closes #8037 from BertrandDechoux/kmeans-typo and squashes the following commits: 47632fe [Bertrand Dechoux] centriod typo --- .../apache/spark/mllib/clustering/KMeansModel.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 8ecb3df11d95e..96359024fa228 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -120,11 +120,11 @@ object KMeansModel extends Loader[KMeansModel] { assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val centriods = sqlContext.read.parquet(Loader.dataPath(path)) - Loader.checkSchema[Cluster](centriods.schema) - val localCentriods = centriods.map(Cluster.apply).collect() - assert(k == localCentriods.size) - new KMeansModel(localCentriods.sortBy(_.id).map(_.point)) + val centroids = sqlContext.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Cluster](centroids.schema) + val localCentroids = centroids.map(Cluster.apply).collect() + assert(k == localCentroids.size) + new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } } } From 49702bd738de681255a7177339510e0e1b25a8db Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 7 Aug 2015 16:24:50 -0700 Subject: [PATCH 225/340] [SPARK-8890] [SQL] Fallback on sorting when writing many dynamic partitions Previously, we would open a new file for each new dynamic written out using `HadoopFsRelation`. 
For formats like parquet this is very costly due to the buffers required to get good compression. In this PR I refactor the code allowing us to fall back on an external sort when many partitions are seen. As such each task will open no more than `spark.sql.sources.maxFiles` files. I also did the following cleanup: - Instead of keying the file HashMap on an expensive to compute string representation of the partition, we now use a fairly cheap UnsafeProjection that avoids heap allocations. - The control flow for instantiating and invoking a writer container has been simplified. Now instead of switching in two places based on the use of partitioning, the specific writer container must implement a single method `writeRows` that is invoked using `runJob`. - `InternalOutputWriter` has been removed. Instead we have a `private[sql]` method `writeInternal` that converts and calls the public method. This method can be overridden by internal datasources to avoid the conversion. This change remove a lot of code duplication and per-row `asInstanceOf` checks. - `commands.scala` has been split up. Author: Michael Armbrust Closes #8010 from marmbrus/fsWriting and squashes the following commits: 00804fe [Michael Armbrust] use shuffleMemoryManager.pageSizeBytes 775cc49 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into fsWriting 17b690e [Michael Armbrust] remove comment 40f0372 [Michael Armbrust] address comments f5675bd [Michael Armbrust] char -> string 7e2d0a4 [Michael Armbrust] make sure we close current writer 8100100 [Michael Armbrust] delete empty commands.scala 71cc717 [Michael Armbrust] update comment 8ec75ac [Michael Armbrust] [SPARK-8890][SQL] Fallback on sorting when writing many dynamic partitions --- .../scala/org/apache/spark/sql/SQLConf.scala | 8 +- .../datasources/InsertIntoDataSource.scala | 64 ++ .../InsertIntoHadoopFsRelation.scala | 165 +++++ .../datasources/WriterContainer.scala | 404 ++++++++++++ .../sql/execution/datasources/commands.scala | 606 ------------------ .../apache/spark/sql/json/JSONRelation.scala | 6 +- .../spark/sql/parquet/ParquetRelation.scala | 6 +- .../apache/spark/sql/sources/interfaces.scala | 17 +- .../sql/sources/PartitionedWriteSuite.scala | 56 ++ .../spark/sql/hive/orc/OrcRelation.scala | 6 +- 10 files changed, 715 insertions(+), 623 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 45d3d8c863512..e9de14f025502 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -366,17 +366,21 @@ private[spark] object SQLConf { "storing additional schema information in Hive's metastore.", isPublic = false) - // Whether to perform partition discovery when loading external data sources. Default to true. 
val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", defaultValue = Some(true), doc = "When true, automtically discover data partitions.") - // Whether to perform partition column type inference. Default to true. val PARTITION_COLUMN_TYPE_INFERENCE = booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", defaultValue = Some(true), doc = "When true, automatically infer the data types for partitioned columns.") + val PARTITION_MAX_FILES = + intConf("spark.sql.sources.maxConcurrentWrites", + defaultValue = Some(5), + doc = "The maximum number of concurent files to open before falling back on sorting when " + + "writing out files using dynamic partitioning.") + // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. // diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala new file mode 100644 index 0000000000000..6ccde7693bd34 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.IOException +import java.util.{Date, UUID} + +import scala.collection.JavaConversions.asScalaIterator + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} +import org.apache.spark._ +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StringType +import org.apache.spark.util.{Utils, SerializableConfiguration} + + +/** + * Inserts the results of `query` in to a relation that extends [[InsertableRelation]]. 
+ */ +private[sql] case class InsertIntoDataSource( + logicalRelation: LogicalRelation, + query: LogicalPlan, + overwrite: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] + val data = DataFrame(sqlContext, query) + // Apply the schema of the existing table to the new data. + val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) + relation.insert(df, overwrite) + + // Invalidate the cache. + sqlContext.cacheManager.invalidateCache(logicalRelation) + + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala new file mode 100644 index 0000000000000..735d52f808868 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.IOException + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.spark._ +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} +import org.apache.spark.sql.sources._ +import org.apache.spark.util.Utils + + +/** + * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. + * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a + * single write job, and owns a UUID that identifies this job. Each concrete implementation of + * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for + * each task output file. This UUID is passed to executor side via a property named + * `spark.sql.sources.writeJobUUID`. + * + * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] + * are used to write to normal tables and tables with dynamic partitions. + * + * Basic work flow of this command is: + * + * 1. Driver side setup, including output committer initialization and data source specific + * preparation work for the write job to be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. 
If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ +private[sql] case class InsertIntoHadoopFsRelation( + @transient relation: HadoopFsRelation, + @transient query: LogicalPlan, + mode: SaveMode) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + require( + relation.paths.length == 1, + s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") + + val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + val outputPath = new Path(relation.paths.head) + val fs = outputPath.getFileSystem(hadoopConf) + val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + + val pathExists = fs.exists(qualifiedOutputPath) + val doInsertion = (mode, pathExists) match { + case (SaveMode.ErrorIfExists, true) => + throw new AnalysisException(s"path $qualifiedOutputPath already exists.") + case (SaveMode.Overwrite, true) => + Utils.tryOrIOException { + if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { + throw new IOException(s"Unable to clear output " + + s"directory $qualifiedOutputPath prior to writing to it") + } + } + true + case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => + true + case (SaveMode.Ignore, exists) => + !exists + case (s, exists) => + throw new IllegalStateException(s"unsupported save mode $s ($exists)") + } + // If we are appending data to an existing dir. + val isAppend = pathExists && (mode == SaveMode.Append) + + if (doInsertion) { + val job = new Job(hadoopConf) + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[InternalRow]) + FileOutputFormat.setOutputPath(job, qualifiedOutputPath) + + // A partitioned relation schema's can be different from the input logicalPlan, since + // partition columns are all moved after data column. We Project to adjust the ordering. + // TODO: this belongs in the analyzer. + val project = Project( + relation.schema.map(field => UnresolvedAttribute.quoted(field.name)), query) + val queryExecution = DataFrame(sqlContext, project).queryExecution + + SQLExecution.withNewExecutionId(sqlContext, queryExecution) { + val df = sqlContext.internalCreateDataFrame(queryExecution.toRdd, relation.schema) + val partitionColumns = relation.partitionColumns.fieldNames + + // Some pre-flight checks. + require( + df.schema == relation.schema, + s"""DataFrame must have the same schema as the relation to which is inserted. + |DataFrame schema: ${df.schema} + |Relation schema: ${relation.schema} + """.stripMargin) + val partitionColumnsInSpec = relation.partitionColumns.fieldNames + require( + partitionColumnsInSpec.sameElements(partitionColumns), + s"""Partition columns mismatch. 
+ |Expected: ${partitionColumnsInSpec.mkString(", ")} + |Actual: ${partitionColumns.mkString(", ")} + """.stripMargin) + + val writerContainer = if (partitionColumns.isEmpty) { + new DefaultWriterContainer(relation, job, isAppend) + } else { + val output = df.queryExecution.executedPlan.output + val (partitionOutput, dataOutput) = + output.partition(a => partitionColumns.contains(a.name)) + + new DynamicPartitionWriterContainer( + relation, + job, + partitionOutput, + dataOutput, + output, + PartitioningUtils.DEFAULT_PARTITION_NAME, + sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES), + isAppend) + } + + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + writerContainer.driverSideSetup() + + try { + sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writerContainer.writeRows _) + writerContainer.commitJob() + relation.refresh() + } catch { case cause: Throwable => + logError("Aborting job.", cause) + writerContainer.abortJob() + throw new SparkException("Job aborted.", cause) + } + } + } else { + logInfo("Skipping insertion into a relation that already exists.") + } + + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala new file mode 100644 index 0000000000000..2f11f40422402 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.util.{Date, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} +import org.apache.spark._ +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.{StructType, StringType} +import org.apache.spark.util.SerializableConfiguration + + +private[sql] abstract class BaseWriterContainer( + @transient val relation: HadoopFsRelation, + @transient job: Job, + isAppend: Boolean) + extends SparkHadoopMapReduceUtil + with Logging + with Serializable { + + protected val dataSchema = relation.dataSchema + + protected val serializableConf = new SerializableConfiguration(job.getConfiguration) + + // This UUID is used to avoid output file name collision between different appending write jobs. + // These jobs may belong to different SparkContext instances. Concrete data source implementations + // may use this UUID to generate unique file names (e.g., `part-r--.parquet`). + // The reason why this ID is used to identify a job rather than a single task output file is + // that, speculative tasks must generate the same output file name as the original task. + private val uniqueWriteJobId = UUID.randomUUID() + + // This is only used on driver side. + @transient private val jobContext: JobContext = job + + // The following fields are initialized and used on both driver and executor side. + @transient protected var outputCommitter: OutputCommitter = _ + @transient private var jobId: JobID = _ + @transient private var taskId: TaskID = _ + @transient private var taskAttemptId: TaskAttemptID = _ + @transient protected var taskAttemptContext: TaskAttemptContext = _ + + protected val outputPath: String = { + assert( + relation.paths.length == 1, + s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") + relation.paths.head + } + + protected var outputWriterFactory: OutputWriterFactory = _ + + private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _ + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit + + def driverSideSetup(): Unit = { + setupIDs(0, 0, 0) + setupConf() + + // This UUID is sent to executor side together with the serialized `Configuration` object within + // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate + // unique task output files. + job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + + // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor + // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, + // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. + // + // Also, the `prepareJobForWrite` call must happen before initializing output format and output + // committer, since their initialization involve the job configuration, which can be potentially + // decorated in `prepareJobForWrite`. 
+ outputWriterFactory = relation.prepareJobForWrite(job) + taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + + outputFormatClass = job.getOutputFormatClass + outputCommitter = newOutputCommitter(taskAttemptContext) + outputCommitter.setupJob(jobContext) + } + + def executorSideSetup(taskContext: TaskContext): Unit = { + setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) + setupConf() + taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + outputCommitter = newOutputCommitter(taskAttemptContext) + outputCommitter.setupTask(taskAttemptContext) + } + + protected def getWorkPath: String = { + outputCommitter match { + // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. + case f: MapReduceFileOutputCommitter => f.getWorkPath.toString + case _ => outputPath + } + } + + private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) + + if (isAppend) { + // If we are appending data to an existing dir, we will only use the output committer + // associated with the file output format since it is not safe to use a custom + // committer for appending. For example, in S3, direct parquet output committer may + // leave partial data in the destination dir when the the appending job fails. + logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " + + "for appending.") + defaultOutputCommitter + } else { + val committerClass = context.getConfiguration.getClass( + SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) + + Option(committerClass).map { clazz => + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + ctor.newInstance(new Path(outputPath), context) + } else { + // The specified output committer is just a OutputCommitter. + // So, we will use the no-argument constructor. + val ctor = clazz.getDeclaredConstructor() + ctor.newInstance() + } + }.getOrElse { + // If output committer class is not set, we will use the one associated with the + // file output format. 
+ logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") + defaultOutputCommitter + } + } + } + + private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { + this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) + this.taskId = new TaskID(this.jobId, true, splitId) + this.taskAttemptId = new TaskAttemptID(taskId, attemptId) + } + + private def setupConf(): Unit = { + serializableConf.value.set("mapred.job.id", jobId.toString) + serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + serializableConf.value.set("mapred.task.id", taskAttemptId.toString) + serializableConf.value.setBoolean("mapred.task.is.map", true) + serializableConf.value.setInt("mapred.task.partition", 0) + } + + def commitTask(): Unit = { + SparkHadoopMapRedUtil.commitTask( + outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) + } + + def abortTask(): Unit = { + if (outputCommitter != null) { + outputCommitter.abortTask(taskAttemptContext) + } + logError(s"Task attempt $taskAttemptId aborted.") + } + + def commitJob(): Unit = { + outputCommitter.commitJob(jobContext) + logInfo(s"Job $jobId committed.") + } + + def abortJob(): Unit = { + if (outputCommitter != null) { + outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + } + logError(s"Job $jobId aborted.") + } +} + +/** + * A writer that writes all of the rows in a partition to a single file. + */ +private[sql] class DefaultWriterContainer( + @transient relation: HadoopFsRelation, + @transient job: Job, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + executorSideSetup(taskContext) + taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) + val writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) + writer.initConverter(dataSchema) + + // If anything below fails, we should abort the task. + try { + while (iterator.hasNext) { + val internalRow = iterator.next() + writer.writeInternal(internalRow) + } + + commitTask() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + + def commitTask(): Unit = { + try { + assert(writer != null, "OutputWriter instance should have been initialized") + writer.close() + super.commitTask() + } catch { + case cause: Throwable => + // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and + // will cause `abortTask()` to be invoked. + throw new RuntimeException("Failed to commit task", cause) + } + } + + def abortTask(): Unit = { + try { + writer.close() + } finally { + super.abortTask() + } + } + } +} + +/** + * A writer that dynamically opens files based on the given partition columns. Internally this is + * done by maintaining a HashMap of open files until `maxFiles` is reached. If this occurs, the + * writer externally sorts the remaining rows and then writes out them out one file at a time. 
+ */ +private[sql] class DynamicPartitionWriterContainer( + @transient relation: HadoopFsRelation, + @transient job: Job, + partitionColumns: Seq[Attribute], + dataColumns: Seq[Attribute], + inputSchema: Seq[Attribute], + defaultPartitionName: String, + maxOpenFiles: Int, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] + executorSideSetup(taskContext) + + // Returns the partition key given an input row + val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) + + // Expressions that given a partition key build a string like: col1=val/col2=val/... + val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => + val escaped = + ScalaUDF( + PartitioningUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) + val str = If(IsNull(c), Literal(defaultPartitionName), escaped) + val partitionName = Literal(c.name + "=") :: str :: Nil + if (i == 0) partitionName else Literal(Path.SEPARATOR_CHAR.toString) :: partitionName + } + + // Returns the partition path given a partition key. + val getPartitionString = + UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) + + // If anything below fails, we should abort the task. + try { + // This will be filled in if we have to fall back on sorting. + var sorter: UnsafeKVExternalSorter = null + while (iterator.hasNext && sorter == null) { + val inputRow = iterator.next() + val currentKey = getPartitionKey(inputRow) + var currentWriter = outputWriters.get(currentKey) + + if (currentWriter == null) { + if (outputWriters.size < maxOpenFiles) { + currentWriter = newOutputWriter(currentKey) + outputWriters.put(currentKey.copy(), currentWriter) + currentWriter.writeInternal(getOutputRow(inputRow)) + } else { + logInfo(s"Maximum partitions reached, falling back on sorting.") + sorter = new UnsafeKVExternalSorter( + StructType.fromAttributes(partitionColumns), + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + SparkEnv.get.shuffleMemoryManager, + SparkEnv.get.shuffleMemoryManager.pageSizeBytes) + sorter.insertKV(currentKey, getOutputRow(inputRow)) + } + } else { + currentWriter.writeInternal(getOutputRow(inputRow)) + } + } + + // If the sorter is not null that means that we reached the maxFiles above and need to finish + // using external sort. + if (sorter != null) { + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val sortedIterator = sorter.sortedIterator() + var currentKey: InternalRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + if (currentKey != sortedIterator.getKey) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + + // Either use an existing file from before, or open a new one. 
+ currentWriter = outputWriters.remove(currentKey) + if (currentWriter == null) { + currentWriter = newOutputWriter(currentKey) + } + } + + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } + } + } + + commitTask() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + + /** Open and returns a new OutputWriter given a partition key. */ + def newOutputWriter(key: InternalRow): OutputWriter = { + val partitionPath = getPartitionString(key).getString(0) + val path = new Path(getWorkPath, partitionPath) + taskAttemptContext.getConfiguration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) + newWriter.initConverter(dataSchema) + newWriter + } + + def clearOutputWriters(): Unit = { + outputWriters.asScala.values.foreach(_.close()) + outputWriters.clear() + } + + def commitTask(): Unit = { + try { + clearOutputWriters() + super.commitTask() + } catch { + case cause: Throwable => + throw new RuntimeException("Failed to commit task", cause) + } + } + + def abortTask(): Unit = { + try { + clearOutputWriters() + } finally { + super.abortTask() + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala deleted file mode 100644 index 42668979c9a32..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala +++ /dev/null @@ -1,606 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.datasources - -import java.io.IOException -import java.util.{Date, UUID} - -import scala.collection.JavaConversions.asScalaIterator - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.spark._ -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StringType -import org.apache.spark.util.{Utils, SerializableConfiguration} - - -private[sql] case class InsertIntoDataSource( - logicalRelation: LogicalRelation, - query: LogicalPlan, - overwrite: Boolean) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] - val data = DataFrame(sqlContext, query) - // Apply the schema of the existing table to the new data. - val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) - - // Invalidate the cache. - sqlContext.cacheManager.invalidateCache(logicalRelation) - - Seq.empty[Row] - } -} - -/** - * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. - * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a - * single write job, and owns a UUID that identifies this job. Each concrete implementation of - * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for - * each task output file. This UUID is passed to executor side via a property named - * `spark.sql.sources.writeJobUUID`. - * - * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] - * are used to write to normal tables and tables with dynamic partitions. - * - * Basic work flow of this command is: - * - * 1. Driver side setup, including output committer initialization and data source specific - * preparation work for the write job to be issued. - * 2. Issues a write job consists of one or more executor side tasks, each of which writes all - * rows within an RDD partition. - * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any - * exception is thrown during task commitment, also aborts that task. - * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is - * thrown during job commitment, also aborts the job. 
- */ -private[sql] case class InsertIntoHadoopFsRelation( - @transient relation: HadoopFsRelation, - @transient query: LogicalPlan, - mode: SaveMode) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - require( - relation.paths.length == 1, - s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") - - val hadoopConf = sqlContext.sparkContext.hadoopConfiguration - val outputPath = new Path(relation.paths.head) - val fs = outputPath.getFileSystem(hadoopConf) - val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - - val pathExists = fs.exists(qualifiedOutputPath) - val doInsertion = (mode, pathExists) match { - case (SaveMode.ErrorIfExists, true) => - throw new AnalysisException(s"path $qualifiedOutputPath already exists.") - case (SaveMode.Overwrite, true) => - Utils.tryOrIOException { - if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { - throw new IOException(s"Unable to clear output " + - s"directory $qualifiedOutputPath prior to writing to it") - } - } - true - case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => - true - case (SaveMode.Ignore, exists) => - !exists - case (s, exists) => - throw new IllegalStateException(s"unsupported save mode $s ($exists)") - } - // If we are appending data to an existing dir. - val isAppend = pathExists && (mode == SaveMode.Append) - - if (doInsertion) { - val job = new Job(hadoopConf) - job.setOutputKeyClass(classOf[Void]) - job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, qualifiedOutputPath) - - // We create a DataFrame by applying the schema of relation to the data to make sure. - // We are writing data based on the expected schema, - - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). We - // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can - // safely apply the schema of r.schema to the data. - val project = Project( - relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) - - val queryExecution = DataFrame(sqlContext, project).queryExecution - SQLExecution.withNewExecutionId(sqlContext, queryExecution) { - val df = sqlContext.internalCreateDataFrame(queryExecution.toRdd, relation.schema) - - val partitionColumns = relation.partitionColumns.fieldNames - if (partitionColumns.isEmpty) { - insert(new DefaultWriterContainer(relation, job, isAppend), df) - } else { - val writerContainer = new DynamicPartitionWriterContainer( - relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME, isAppend) - insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns) - } - } - } - - Seq.empty[Row] - } - - /** - * Inserts the content of the [[DataFrame]] into a table without any partitioning columns. - */ - private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = { - // Uses local vals for serialization - val needsConversion = relation.needConversion - val dataSchema = relation.dataSchema - - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
- writerContainer.driverSideSetup() - - try { - df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) - writerContainer.commitJob() - relation.refresh() - } catch { case cause: Throwable => - logError("Aborting job.", cause) - writerContainer.abortJob() - throw new SparkException("Job aborted.", cause) - } - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - // If anything below fails, we should abort the task. - try { - writerContainer.executorSideSetup(taskContext) - - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - .asInstanceOf[InternalRow => Row] - while (iterator.hasNext) { - val internalRow = iterator.next() - writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) - } - } else { - while (iterator.hasNext) { - val internalRow = iterator.next() - writerContainer.outputWriterForRow(internalRow) - .asInstanceOf[OutputWriterInternal].writeInternal(internalRow) - } - } - - writerContainer.commitTask() - } catch { case cause: Throwable => - logError("Aborting task.", cause) - writerContainer.abortTask() - throw new SparkException("Task failed while writing rows.", cause) - } - } - } - - /** - * Inserts the content of the [[DataFrame]] into a table with partitioning columns. - */ - private def insertWithDynamicPartitions( - sqlContext: SQLContext, - writerContainer: BaseWriterContainer, - df: DataFrame, - partitionColumns: Array[String]): Unit = { - // Uses a local val for serialization - val needsConversion = relation.needConversion - val dataSchema = relation.dataSchema - - require( - df.schema == relation.schema, - s"""DataFrame must have the same schema as the relation to which is inserted. - |DataFrame schema: ${df.schema} - |Relation schema: ${relation.schema} - """.stripMargin) - - val partitionColumnsInSpec = relation.partitionColumns.fieldNames - require( - partitionColumnsInSpec.sameElements(partitionColumns), - s"""Partition columns mismatch. - |Expected: ${partitionColumnsInSpec.mkString(", ")} - |Actual: ${partitionColumns.mkString(", ")} - """.stripMargin) - - val output = df.queryExecution.executedPlan.output - val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name)) - val codegenEnabled = df.sqlContext.conf.codegenEnabled - - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - writerContainer.driverSideSetup() - - try { - df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) - writerContainer.commitJob() - relation.refresh() - } catch { case cause: Throwable => - logError("Aborting job.", cause) - writerContainer.abortJob() - throw new SparkException("Job aborted.", cause) - } - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - // If anything below fails, we should abort the task. - try { - writerContainer.executorSideSetup(taskContext) - - // Projects all partition columns and casts them to strings to build partition directories. 
- val partitionCasts = partitionOutput.map(Cast(_, StringType)) - val partitionProj = newProjection(codegenEnabled, partitionCasts, output) - val dataProj = newProjection(codegenEnabled, dataOutput, output) - - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - .asInstanceOf[InternalRow => Row] - while (iterator.hasNext) { - val internalRow = iterator.next() - val partitionPart = partitionProj(internalRow) - val dataPart = converter(dataProj(internalRow)) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) - } - } else { - while (iterator.hasNext) { - val internalRow = iterator.next() - val partitionPart = partitionProj(internalRow) - val dataPart = dataProj(internalRow) - writerContainer.outputWriterForRow(partitionPart) - .asInstanceOf[OutputWriterInternal].writeInternal(dataPart) - } - } - - writerContainer.commitTask() - } catch { case cause: Throwable => - logError("Aborting task.", cause) - writerContainer.abortTask() - throw new SparkException("Task failed while writing rows.", cause) - } - } - } - - // This is copied from SparkPlan, probably should move this to a more general place. - private def newProjection( - codegenEnabled: Boolean, - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (sys.props.contains("spark.testing")) { - throw e - } else { - log.error("failed to generate projection, fallback to interpreted", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } else { - new InterpretedProjection(expressions, inputSchema) - } - } -} - -private[sql] abstract class BaseWriterContainer( - @transient val relation: HadoopFsRelation, - @transient job: Job, - isAppend: Boolean) - extends SparkHadoopMapReduceUtil - with Logging - with Serializable { - - protected val serializableConf = new SerializableConfiguration(job.getConfiguration) - - // This UUID is used to avoid output file name collision between different appending write jobs. - // These jobs may belong to different SparkContext instances. Concrete data source implementations - // may use this UUID to generate unique file names (e.g., `part-r--.parquet`). - // The reason why this ID is used to identify a job rather than a single task output file is - // that, speculative tasks must generate the same output file name as the original task. - private val uniqueWriteJobId = UUID.randomUUID() - - // This is only used on driver side. - @transient private val jobContext: JobContext = job - - // The following fields are initialized and used on both driver and executor side. 
- @transient protected var outputCommitter: OutputCommitter = _ - @transient private var jobId: JobID = _ - @transient private var taskId: TaskID = _ - @transient private var taskAttemptId: TaskAttemptID = _ - @transient protected var taskAttemptContext: TaskAttemptContext = _ - - protected val outputPath: String = { - assert( - relation.paths.length == 1, - s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") - relation.paths.head - } - - protected val dataSchema = relation.dataSchema - - protected var outputWriterFactory: OutputWriterFactory = _ - - private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _ - - def driverSideSetup(): Unit = { - setupIDs(0, 0, 0) - setupConf() - - // This UUID is sent to executor side together with the serialized `Configuration` object within - // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate - // unique task output files. - job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) - - // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor - // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, - // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. - // - // Also, the `prepareJobForWrite` call must happen before initializing output format and output - // committer, since their initialization involve the job configuration, which can be potentially - // decorated in `prepareJobForWrite`. - outputWriterFactory = relation.prepareJobForWrite(job) - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - - outputFormatClass = job.getOutputFormatClass - outputCommitter = newOutputCommitter(taskAttemptContext) - outputCommitter.setupJob(jobContext) - } - - def executorSideSetup(taskContext: TaskContext): Unit = { - setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) - setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - outputCommitter = newOutputCommitter(taskAttemptContext) - outputCommitter.setupTask(taskAttemptContext) - initWriters() - } - - protected def getWorkPath: String = { - outputCommitter match { - // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. - case f: MapReduceFileOutputCommitter => f.getWorkPath.toString - case _ => outputPath - } - } - - private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) - - if (isAppend) { - // If we are appending data to an existing dir, we will only use the output committer - // associated with the file output format since it is not safe to use a custom - // committer for appending. For example, in S3, direct parquet output committer may - // leave partial data in the destination dir when the the appending job fails. - logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " + - "for appending.") - defaultOutputCommitter - } else { - val committerClass = context.getConfiguration.getClass( - SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) - - Option(committerClass).map { clazz => - logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") - - // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat - // has an associated output committer. 
To override this output committer, - // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. - // If a data source needs to override the output committer, it needs to set the - // output committer in prepareForWrite method. - if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { - // The specified output committer is a FileOutputCommitter. - // So, we will use the FileOutputCommitter-specified constructor. - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - ctor.newInstance(new Path(outputPath), context) - } else { - // The specified output committer is just a OutputCommitter. - // So, we will use the no-argument constructor. - val ctor = clazz.getDeclaredConstructor() - ctor.newInstance() - } - }.getOrElse { - // If output committer class is not set, we will use the one associated with the - // file output format. - logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") - defaultOutputCommitter - } - } - } - - private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { - this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) - this.taskId = new TaskID(this.jobId, true, splitId) - this.taskAttemptId = new TaskAttemptID(taskId, attemptId) - } - - private def setupConf(): Unit = { - serializableConf.value.set("mapred.job.id", jobId.toString) - serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - serializableConf.value.set("mapred.task.id", taskAttemptId.toString) - serializableConf.value.setBoolean("mapred.task.is.map", true) - serializableConf.value.setInt("mapred.task.partition", 0) - } - - // Called on executor side when writing rows - def outputWriterForRow(row: InternalRow): OutputWriter - - protected def initWriters(): Unit - - def commitTask(): Unit = { - SparkHadoopMapRedUtil.commitTask( - outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) - } - - def abortTask(): Unit = { - if (outputCommitter != null) { - outputCommitter.abortTask(taskAttemptContext) - } - logError(s"Task attempt $taskAttemptId aborted.") - } - - def commitJob(): Unit = { - outputCommitter.commitJob(jobContext) - logInfo(s"Job $jobId committed.") - } - - def abortJob(): Unit = { - if (outputCommitter != null) { - outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) - } - logError(s"Job $jobId aborted.") - } -} - -private[sql] class DefaultWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, - isAppend: Boolean) - extends BaseWriterContainer(relation, job, isAppend) { - - @transient private var writer: OutputWriter = _ - - override protected def initWriters(): Unit = { - taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) - writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) - } - - override def outputWriterForRow(row: InternalRow): OutputWriter = writer - - override def commitTask(): Unit = { - try { - assert(writer != null, "OutputWriter instance should have been initialized") - writer.close() - super.commitTask() - } catch { case cause: Throwable => - // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and will - // cause `abortTask()` to be invoked. 
- throw new RuntimeException("Failed to commit task", cause) - } - } - - override def abortTask(): Unit = { - try { - // It's possible that the task fails before `writer` gets initialized - if (writer != null) { - writer.close() - } - } finally { - super.abortTask() - } - } -} - -private[sql] class DynamicPartitionWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, - partitionColumns: Array[String], - defaultPartitionName: String, - isAppend: Boolean) - extends BaseWriterContainer(relation, job, isAppend) { - - // All output writers are created on executor side. - @transient protected var outputWriters: java.util.HashMap[String, OutputWriter] = _ - - override protected def initWriters(): Unit = { - outputWriters = new java.util.HashMap[String, OutputWriter] - } - - // The `row` argument is supposed to only contain partition column values which have been casted - // to strings. - override def outputWriterForRow(row: InternalRow): OutputWriter = { - val partitionPath = { - val partitionPathBuilder = new StringBuilder - var i = 0 - - while (i < partitionColumns.length) { - val col = partitionColumns(i) - val partitionValueString = { - val string = row.getUTF8String(i) - if (string.eq(null)) { - defaultPartitionName - } else { - PartitioningUtils.escapePathName(string.toString) - } - } - - if (i > 0) { - partitionPathBuilder.append(Path.SEPARATOR_CHAR) - } - - partitionPathBuilder.append(s"$col=$partitionValueString") - i += 1 - } - - partitionPathBuilder.toString() - } - - val writer = outputWriters.get(partitionPath) - if (writer.eq(null)) { - val path = new Path(getWorkPath, partitionPath) - taskAttemptContext.getConfiguration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) - outputWriters.put(partitionPath, newWriter) - newWriter - } else { - writer - } - } - - private def clearOutputWriters(): Unit = { - if (!outputWriters.isEmpty) { - asScalaIterator(outputWriters.values().iterator()).foreach(_.close()) - outputWriters.clear() - } - } - - override def commitTask(): Unit = { - try { - clearOutputWriters() - super.commitTask() - } catch { case cause: Throwable => - throw new RuntimeException("Failed to commit task", cause) - } - } - - override def abortTask(): Unit = { - try { - clearOutputWriters() - } finally { - super.abortTask() - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 5d371402877c6..10f1367e6984c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -152,7 +152,7 @@ private[json] class JsonOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriterInternal with SparkHadoopMapRedUtil with Logging { + extends OutputWriter with SparkHadoopMapRedUtil with Logging { val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records @@ -170,7 +170,9 @@ private[json] class JsonOutputWriter( }.getRecordWriter(context) } - override def writeInternal(row: InternalRow): Unit = { + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { JacksonGenerator(dataSchema, gen, row) gen.flush() diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 29c388c22ef93..48009b2fd007d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -62,7 +62,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { // NOTE: This class is instantiated and used on executor side only, no need to be serializable. private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriterInternal { + extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { val outputFormat = { @@ -87,7 +87,9 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext outputFormat.getRecordWriter(context) } - override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 0b2929661b657..c5b7ee73eb784 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -342,18 +342,17 @@ abstract class OutputWriter { * @since 1.4.0 */ def close(): Unit -} -/** - * This is an internal, private version of [[OutputWriter]] with an writeInternal method that - * accepts an [[InternalRow]] rather than an [[Row]]. Data sources that return this must have - * the conversion flag set to false. - */ -private[sql] abstract class OutputWriterInternal extends OutputWriter { + private var converter: InternalRow => Row = _ - override def write(row: Row): Unit = throw new UnsupportedOperationException + protected[sql] def initConverter(dataSchema: StructType) = { + converter = + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + } - def writeInternal(row: InternalRow): Unit + protected[sql] def writeInternal(row: InternalRow): Unit = { + write(converter(row)) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala new file mode 100644 index 0000000000000..c86ddd7c83e53 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.util.Utils + +class PartitionedWriteSuite extends QueryTest { + import TestSQLContext.implicits._ + + test("write many partitions") { + val path = Utils.createTempDir() + path.delete() + + val df = TestSQLContext.range(100).select($"id", lit(1).as("data")) + df.write.partitionBy("id").save(path.getCanonicalPath) + + checkAnswer( + TestSQLContext.read.load(path.getCanonicalPath), + (0 to 99).map(Row(1, _)).toSeq) + + Utils.deleteRecursively(path) + } + + test("write many partitions with repeats") { + val path = Utils.createTempDir() + path.delete() + + val base = TestSQLContext.range(100) + val df = base.unionAll(base).select($"id", lit(1).as("data")) + df.write.partitionBy("id").save(path.getCanonicalPath) + + checkAnswer( + TestSQLContext.read.load(path.getCanonicalPath), + (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) + + Utils.deleteRecursively(path) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 4a310ff4e9016..7c8704b47f286 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -66,7 +66,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriterInternal with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { private val serializer = { val table = new Properties() @@ -120,7 +120,9 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def writeInternal(row: InternalRow): Unit = { + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { var i = 0 while (i < row.numFields) { reusableOutputBuffer(i) = wrappers(i)(row.get(i, dataSchema(i).dataType)) From cd540c1e59561ad1fdac59af6170944c60e685d8 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 7 Aug 2015 17:19:48 -0700 Subject: [PATCH 226/340] [SPARK-9756] [ML] Make constructors in ML decision trees private These should be made private until there is a public constructor for providing `rootNode: Node` to use these constructors. 
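Not part of the patch itself: below is a minimal, hypothetical Scala sketch of the visibility pattern this commit applies, where the convenience constructor is narrowed to private[ml] so that only code inside the ml package can build a model directly from a root node. Every name in the sketch (org.example.ml, NodeSketch, TreeModelSketch) is made up for illustration.

    // Hypothetical sketch, not Spark code: package, class, and field names are invented purely
    // to illustrate narrowing a convenience constructor to private[ml].
    package org.example.ml

    import java.util.UUID

    final case class NodeSketch(prediction: Double)

    final class TreeModelSketch private[ml] (val uid: String, val rootNode: NodeSketch) {
      // Package-private auxiliary constructor, mirroring this patch: callers outside the ml
      // package cannot build a model directly from a root node until a public constructor exists.
      private[ml] def this(rootNode: NodeSketch) =
        this("dtc_" + UUID.randomUUID().toString.take(8), rootNode)
    }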
jkbradley Author: Feynman Liang Closes #8046 from feynmanliang/SPARK-9756 and squashes the following commits: 2cbdf08 [Feynman Liang] Make RFRegressionModel aux constructor private a06f596 [Feynman Liang] Make constructors in ML decision trees private --- .../spark/ml/classification/DecisionTreeClassifier.scala | 2 +- .../spark/ml/classification/RandomForestClassifier.scala | 5 ++++- .../apache/spark/ml/regression/DecisionTreeRegressor.scala | 2 +- .../apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index f2b992f8ba249..29598f3f05c2d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -117,7 +117,7 @@ final class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - def this(rootNode: Node, numClasses: Int) = + private[ml] def this(rootNode: Node, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numClasses) override protected def predict(features: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index b59826a59499a..156050aaf7a45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -136,7 +136,10 @@ final class RandomForestClassificationModel private[ml] ( * Construct a random forest classification model, with all trees weighted equally. * @param trees Component trees */ - def this(trees: Array[DecisionTreeClassificationModel], numFeatures: Int, numClasses: Int) = + private[ml] def this( + trees: Array[DecisionTreeClassificationModel], + numFeatures: Int, + numClasses: Int) = this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 4d30e4b5548aa..dc94a14014542 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -107,7 +107,7 @@ final class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. 
*/ - def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) + private[ml] def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) override protected def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 1ee43c8725732..db75c0d26392f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -125,7 +125,7 @@ final class RandomForestRegressionModel private[ml] ( * Construct a random forest regression model, with all trees weighted equally. * @param trees Component trees */ - def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = + private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = this(Identifiable.randomUID("rfr"), trees, numFeatures) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] From 85be65b39ce669f937a898195a844844d757666b Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 7 Aug 2015 17:21:12 -0700 Subject: [PATCH 227/340] [SPARK-9719] [ML] Clean up Naive Bayes doc Small documentation cleanups, including: * Adds documentation for `pi` and `theta` * setParam to `setModelType` Author: Feynman Liang Closes #8047 from feynmanliang/SPARK-9719 and squashes the following commits: b372438 [Feynman Liang] Clean up naive bayes doc --- .../scala/org/apache/spark/ml/classification/NaiveBayes.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index b46b676204e0e..97cbaf1fa8761 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -86,6 +86,7 @@ class NaiveBayes(override val uid: String) * Set the model type using a string (case-sensitive). * Supported options: "multinomial" and "bernoulli". * Default is "multinomial" + * @group setParam */ def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) @@ -101,6 +102,9 @@ class NaiveBayes(override val uid: String) /** * Model produced by [[NaiveBayes]] + * @param pi log of class priors, whose dimension is C (number of classes) + * @param theta log of class conditional probabilities, whose dimension is C (number of classes) + * by D (number of features) */ class NaiveBayesModel private[ml] ( override val uid: String, From 998f4ff94df1d9db1c9e32c04091017c25cd4e81 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Aug 2015 19:09:28 -0700 Subject: [PATCH 228/340] [SPARK-9754][SQL] Remove TypeCheck in debug package. TypeCheck no longer applies in the new "Tungsten" world. Author: Reynold Xin Closes #8043 from rxin/SPARK-9754 and squashes the following commits: 4ec471e [Reynold Xin] [SPARK-9754][SQL] Remove TypeCheck in debug package. 
--- .../spark/sql/execution/debug/package.scala | 104 +----------------- .../sql/execution/debug/DebuggingSuite.scala | 4 - 2 files changed, 4 insertions(+), 104 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dd3858ea2b520..74892e4e13fa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,21 +17,16 @@ package org.apache.spark.sql.execution -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.unsafe.types.UTF8String - import scala.collection.mutable.HashSet -import org.apache.spark.{AccumulatorParam, Accumulator, Logging} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.types._ +import org.apache.spark.{Accumulator, AccumulatorParam, Logging} /** - * :: DeveloperApi :: * Contains methods for debugging query execution. * * Usage: @@ -53,10 +48,8 @@ package object debug { } /** - * :: DeveloperApi :: * Augments [[DataFrame]]s with debug methods. */ - @DeveloperApi implicit class DebugQuery(query: DataFrame) extends Logging { def debug(): Unit = { val plan = query.queryExecution.executedPlan @@ -72,23 +65,6 @@ package object debug { case _ => } } - - def typeCheck(): Unit = { - val plan = query.queryExecution.executedPlan - val visited = new collection.mutable.HashSet[TreeNodeRef]() - val debugPlan = plan transform { - case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => - visited += new TreeNodeRef(s) - TypeCheck(s) - } - try { - logDebug(s"Results returned: ${debugPlan.execute().count()}") - } catch { - case e: Exception => - def unwrap(e: Throwable): Throwable = if (e.getCause == null) e else unwrap(e.getCause) - logDebug(s"Deepest Error: ${unwrap(e)}") - } - } } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { @@ -148,76 +124,4 @@ package object debug { } } } - - /** - * Helper functions for checking that runtime types match a given schema. 
- */ - private[sql] object TypeCheck { - def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { - case (null, _) => - - case (row: InternalRow, s: StructType) => - row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } - case (a: ArrayData, ArrayType(elemType, _)) => - a.foreach(elemType, (_, e) => { - typeCheck(e, elemType) - }) - case (m: MapData, MapType(keyType, valueType, _)) => - m.keyArray().foreach(keyType, (_, e) => { - typeCheck(e, keyType) - }) - m.valueArray().foreach(valueType, (_, e) => { - typeCheck(e, valueType) - }) - - case (_: Long, LongType) => - case (_: Int, IntegerType) => - case (_: UTF8String, StringType) => - case (_: Float, FloatType) => - case (_: Byte, ByteType) => - case (_: Short, ShortType) => - case (_: Boolean, BooleanType) => - case (_: Double, DoubleType) => - case (_: Int, DateType) => - case (_: Long, TimestampType) => - case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) - - case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") - } - } - - /** - * Augments [[DataFrame]]s with debug methods. - */ - private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { - import TypeCheck._ - - override def nodeName: String = "" - - /* Only required when defining this class in a REPL. - override def makeCopy(args: Array[Object]): this.type = - TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type] - */ - - def output: Seq[Attribute] = child.output - - def children: List[SparkPlan] = child :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - child.execute().map { row => - try typeCheck(row, child.schema) catch { - case e: Exception => - sys.error( - s""" - |ERROR WHEN TYPE CHECKING QUERY - |============================== - |$e - |======== BAD TREE ============ - |$child - """.stripMargin) - } - row - } - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 8ec3985e00360..239deb7973845 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -25,8 +25,4 @@ class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } - - test("DataFrame.typeCheck()") { - testData.typeCheck() - } } From c564b27447ed99e55b359b3df1d586d5766b85ea Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 7 Aug 2015 20:04:17 -0700 Subject: [PATCH 229/340] [SPARK-9753] [SQL] TungstenAggregate should also accept InternalRow instead of just UnsafeRow https://issues.apache.org/jira/browse/SPARK-9753 This PR makes TungstenAggregate to accept `InternalRow` instead of just `UnsafeRow`. Also, it adds an `getAggregationBufferFromUnsafeRow` method to `UnsafeFixedWidthAggregationMap`. It is useful when we already have grouping keys stored in `UnsafeRow`s. Finally, it wraps `InputStream` and `OutputStream` in `UnsafeRowSerializer` with `BufferedInputStream` and `BufferedOutputStream`, respectively. Author: Yin Huai Closes #8041 from yhuai/joinedRowForProjection and squashes the following commits: 7753e34 [Yin Huai] Use BufferedInputStream and BufferedOutputStream. d68b74e [Yin Huai] Use joinedRow instead of UnsafeRowJoiner. e93c009 [Yin Huai] Add getAggregationBufferFromUnsafeRow for cases that the given groupingKeyRow is already an UnsafeRow. 
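To make the buffering change described above concrete, here is a small self-contained sketch (illustrative only, not the actual UnsafeRowSerializer code; the names below are invented). Each record is written as a length prefix followed by its bytes, and wrapping the raw streams in BufferedOutputStream and BufferedInputStream lets the many small per-record reads and writes hit an in-memory buffer instead of the underlying stream.

    import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream,
      DataOutputStream, InputStream, OutputStream}

    // Hypothetical sketch of the stream wrapping described in this commit message; not a Spark API.
    object BufferedRecordStreams {
      // Writes length-prefixed records; the buffered wrapper batches the per-record
      // writeInt/write calls in memory rather than pushing each one to the raw OutputStream.
      def writeRecords(out: OutputStream, records: Iterator[Array[Byte]]): Unit = {
        val dOut = new DataOutputStream(new BufferedOutputStream(out))
        records.foreach { bytes =>
          dOut.writeInt(bytes.length)
          dOut.write(bytes)
        }
        dOut.writeInt(-1) // end-of-stream marker
        dOut.flush()
      }

      // Reads the records back through a buffered DataInputStream until the marker is seen.
      def readRecords(in: InputStream): Iterator[Array[Byte]] = {
        val dIn = new DataInputStream(new BufferedInputStream(in))
        Iterator.continually(dIn.readInt()).takeWhile(_ != -1).map { size =>
          val buffer = new Array[Byte](size)
          dIn.readFully(buffer)
          buffer
        }
      }
    }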
--- .../UnsafeFixedWidthAggregationMap.java | 4 ++ .../sql/execution/UnsafeRowSerializer.scala | 30 +++-------- .../aggregate/TungstenAggregate.scala | 4 +- .../TungstenAggregationIterator.scala | 51 +++++++++---------- 4 files changed, 39 insertions(+), 50 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index b08a4a13a28be..00218f213054b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -121,6 +121,10 @@ public UnsafeFixedWidthAggregationMap( public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey); + return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow); + } + + public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) { // Probe our map using the serialized key final BytesToBytesMap.Location loc = map.lookup( unsafeGroupingKeyRow.getBaseObject(), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 39f8f992a9f00..6c7e5cacc99e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -58,27 +58,14 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst */ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) - // When `out` is backed by ChainedBufferOutputStream, we will get an - // UnsupportedOperationException when we call dOut.writeInt because it internally calls - // ChainedBufferOutputStream's write(b: Int), which is not supported. - // To workaround this issue, we create an array for sorting the int value. - // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and - // run SparkSqlSerializer2SortMergeShuffleSuite. - private[this] var intBuffer: Array[Byte] = new Array[Byte](4) - private[this] val dOut: DataOutputStream = new DataOutputStream(out) + private[this] val dOut: DataOutputStream = + new DataOutputStream(new BufferedOutputStream(out)) override def writeValue[T: ClassTag](value: T): SerializationStream = { val row = value.asInstanceOf[UnsafeRow] - val size = row.getSizeInBytes - // This part is based on DataOutputStream's writeInt. - // It is for dOut.writeInt(row.getSizeInBytes). 
- intBuffer(0) = ((size >>> 24) & 0xFF).toByte - intBuffer(1) = ((size >>> 16) & 0xFF).toByte - intBuffer(2) = ((size >>> 8) & 0xFF).toByte - intBuffer(3) = ((size >>> 0) & 0xFF).toByte - dOut.write(intBuffer, 0, 4) - - row.writeToStream(out, writeBuffer) + + dOut.writeInt(row.getSizeInBytes) + row.writeToStream(dOut, writeBuffer) this } @@ -105,7 +92,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def close(): Unit = { writeBuffer = null - intBuffer = null dOut.writeInt(EOF) dOut.close() } @@ -113,7 +99,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def deserializeStream(in: InputStream): DeserializationStream = { new DeserializationStream { - private[this] val dIn: DataInputStream = new DataInputStream(in) + private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) // 1024 is a default buffer size; this buffer will grow to accommodate larger rows private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) private[this] var row: UnsafeRow = new UnsafeRow() @@ -129,7 +115,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst if (rowBuffer.length < rowSize) { rowBuffer = new Array[Byte](rowSize) } - ByteStreams.readFully(in, rowBuffer, 0, rowSize) + ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream @@ -163,7 +149,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst if (rowBuffer.length < rowSize) { rowBuffer = new Array[Byte](rowSize) } - ByteStreams.readFully(in, rowBuffer, 0, rowSize) + ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index c3dcbd2b71ee8..1694794a53d9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -39,7 +39,7 @@ case class TungstenAggregate( override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false + override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -77,7 +77,7 @@ case class TungstenAggregate( resultExpressions, newMutableProjection, child.output, - iter.asInstanceOf[Iterator[UnsafeRow]], + iter, testFallbackStartsAt) if (!hasInput && groupingExpressions.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 440bef32f4e9b..32160906c3bc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -22,6 +22,7 @@ import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.types.StructType @@ -46,8 +47,7 @@ import org.apache.spark.sql.types.StructType * processing input rows from inputIter, and generating output * rows. * - Part 3: Methods and fields used by hash-based aggregation. - * - Part 4: The function used to switch this iterator from hash-based - * aggregation to sort-based aggregation. + * - Part 4: Methods and fields used when we switch to sort-based aggregation. * - Part 5: Methods and fields used by sort-based aggregation. * - Part 6: Loads input and process input rows. * - Part 7: Public methods of this iterator. @@ -82,7 +82,7 @@ class TungstenAggregationIterator( resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], - inputIter: Iterator[UnsafeRow], + inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[Int]) extends Iterator[UnsafeRow] with Logging { @@ -174,13 +174,10 @@ class TungstenAggregationIterator( // Creates a function used to process a row based on the given inputAttributes. private def generateProcessRow( - inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = { + inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = { val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) - val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes) - val inputSchema = StructType.fromAttributes(inputAttributes) - val unsafeRowJoiner = - GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema) + val joinedRow = new JoinedRow() aggregationMode match { // Partial-only @@ -189,9 +186,9 @@ class TungstenAggregationIterator( val algebraicUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { + (currentBuffer: UnsafeRow, row: InternalRow) => { algebraicUpdateProjection.target(currentBuffer) - algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + algebraicUpdateProjection(joinedRow(currentBuffer, row)) } // PartialMerge-only or Final-only @@ -203,10 +200,10 @@ class TungstenAggregationIterator( mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { + (currentBuffer: UnsafeRow, row: InternalRow) => { // Process all algebraic aggregate functions. algebraicMergeProjection.target(currentBuffer) - algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row)) + algebraicMergeProjection(joinedRow(currentBuffer, row)) } // Final-Complete @@ -233,8 +230,8 @@ class TungstenAggregationIterator( val completeAlgebraicUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { - val input = unsafeRowJoiner.join(currentBuffer, row) + (currentBuffer: UnsafeRow, row: InternalRow) => { + val input = joinedRow(currentBuffer, row) // For all aggregate functions with mode Complete, update the given currentBuffer. 
completeAlgebraicUpdateProjection.target(currentBuffer)(input) @@ -253,14 +250,14 @@ class TungstenAggregationIterator( val completeAlgebraicUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { + (currentBuffer: UnsafeRow, row: InternalRow) => { completeAlgebraicUpdateProjection.target(currentBuffer) // For all aggregate functions with mode Complete, update the given currentBuffer. - completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + completeAlgebraicUpdateProjection(joinedRow(currentBuffer, row)) } // Grouping only. - case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {} + case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {} case other => throw new IllegalStateException( @@ -272,15 +269,16 @@ class TungstenAggregationIterator( private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { val groupingAttributes = groupingExpressions.map(_.toAttribute) - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) - val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) aggregationMode match { // Partial-only or PartialMerge-only: every output row is basically the values of // the grouping expressions and the corresponding aggregation buffer. case (Some(Partial), None) | (Some(PartialMerge), None) => + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { unsafeRowJoiner.join(currentGroupingKey, currentBuffer) } @@ -288,11 +286,12 @@ class TungstenAggregationIterator( // Final-only, Complete-only and Final-Complete: a output row is generated based on // resultExpressions. case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => + val joinedRow = new JoinedRow() val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer)) + resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } // Grouping-only: a output row is generated from values of grouping expressions. @@ -316,7 +315,7 @@ class TungstenAggregationIterator( // A function used to process a input row. Its first argument is the aggregation buffer // and the second argument is the input row. - private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit = + private[this] var processRow: (UnsafeRow, InternalRow) => Unit = generateProcessRow(originalInputAttributes) // A function used to generate output rows based on the grouping keys (first argument) @@ -354,7 +353,7 @@ class TungstenAggregationIterator( while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() val groupingKey = groupProjection.apply(newInput) - val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey) + val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { // buffer == null means that we could not allocate more memory. // Now, we need to spill the map and switch to sort-based aggregation. 
@@ -374,7 +373,7 @@ class TungstenAggregationIterator( val newInput = inputIter.next() val groupingKey = groupProjection.apply(newInput) val buffer: UnsafeRow = if (i < fallbackStartsAt) { - hashMap.getAggregationBuffer(groupingKey) + hashMap.getAggregationBufferFromUnsafeRow(groupingKey) } else { null } @@ -397,7 +396,7 @@ class TungstenAggregationIterator( private[this] var mapIteratorHasNext: Boolean = false /////////////////////////////////////////////////////////////////////////// - // Part 3: Methods and fields used by sort-based aggregation. + // Part 4: Methods and fields used when we switch to sort-based aggregation. /////////////////////////////////////////////////////////////////////////// // This sorter is used for sort-based aggregation. It is initialized as soon as @@ -407,7 +406,7 @@ class TungstenAggregationIterator( /** * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. */ - private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = { + private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. externalSorter = hashMap.destructAndCreateExternalSorter() From ef062c15992b0d08554495b8ea837bef3fabf6e9 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Fri, 7 Aug 2015 23:36:26 -0700 Subject: [PATCH 230/340] [SPARK-9731] Standalone scheduling incorrect cores if spark.executor.cores is not set The issue only happens if `spark.executor.cores` is not set and executor memory is set to a high value. For example, if we have a worker with 4G and 10 cores and we set `spark.executor.memory` to 3G, then only 1 core is assigned to the executor. The correct number should be 10 cores. I've added a unit test to illustrate the issue. Author: Carson Wang Closes #8017 from carsonwang/SPARK-9731 and squashes the following commits: d09ec48 [Carson Wang] Fix code style 86b651f [Carson Wang] Simplify the code 943cc4c [Carson Wang] fix scheduling correct cores to executors --- .../apache/spark/deploy/master/Master.scala | 26 ++++++++++--------- .../spark/deploy/master/MasterSuite.scala | 15 +++++++++++ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e38e437fe1c5a..9217202b69a66 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -581,20 +581,22 @@ private[deploy] class Master( /** Return whether the specified worker can launch an executor for this app. */ def canLaunchExecutor(pos: Int): Boolean = { + val keepScheduling = coresToAssign >= minCoresPerExecutor + val enoughCores = usableWorkers(pos).coresFree - assignedCores(pos) >= minCoresPerExecutor + // If we allow multiple executors per worker, then we can always launch new executors. - // Otherwise, we may have already started assigning cores to the executor on this worker. + // Otherwise, if there is already an executor on this worker, just give it more cores. 
val launchingNewExecutor = !oneExecutorPerWorker || assignedExecutors(pos) == 0 - val underLimit = - if (launchingNewExecutor) { - assignedExecutors.sum + app.executors.size < app.executorLimit - } else { - true - } - val assignedMemory = assignedExecutors(pos) * memoryPerExecutor - usableWorkers(pos).memoryFree - assignedMemory >= memoryPerExecutor && - usableWorkers(pos).coresFree - assignedCores(pos) >= minCoresPerExecutor && - coresToAssign >= minCoresPerExecutor && - underLimit + if (launchingNewExecutor) { + val assignedMemory = assignedExecutors(pos) * memoryPerExecutor + val enoughMemory = usableWorkers(pos).memoryFree - assignedMemory >= memoryPerExecutor + val underLimit = assignedExecutors.sum + app.executors.size < app.executorLimit + keepScheduling && enoughCores && enoughMemory && underLimit + } else { + // We're adding cores to an existing executor, so no need + // to check memory and executor limits + keepScheduling && enoughCores + } } // Keep launching executors until no more workers can accommodate any diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index ae0e037d822ea..20d0201a364ab 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -151,6 +151,14 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva basicScheduling(spreadOut = false) } + test("basic scheduling with more memory - spread out") { + basicSchedulingWithMoreMemory(spreadOut = true) + } + + test("basic scheduling with more memory - no spread out") { + basicSchedulingWithMoreMemory(spreadOut = false) + } + test("scheduling with max cores - spread out") { schedulingWithMaxCores(spreadOut = true) } @@ -214,6 +222,13 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva assert(scheduledCores === Array(10, 10, 10)) } + private def basicSchedulingWithMoreMemory(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(3072) + val scheduledCores = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + assert(scheduledCores === Array(10, 10, 10)) + } + private def schedulingWithMaxCores(spreadOut: Boolean): Unit = { val master = makeMaster() val appInfo1 = makeAppInfo(1024, maxCores = Some(8)) From 11caf1ce290b6931647c2f71268f847d1d48930e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 8 Aug 2015 18:09:48 +0800 Subject: [PATCH 231/340] [SPARK-4176] [SQL] [MINOR] Should use unscaled Long to write decimals for precision <= 18 rather than 8 This PR fixes a minor bug introduced in #7455: when writing decimals, we should use the unscaled Long for better performance when the precision <= 18 rather than 8 (should be a typo). This bug doesn't affect correctness, but hurts Parquet decimal writing performance. This PR also replaced similar magic numbers with newly defined constants. 
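For reference, the two new constants encode how many base-10 digits fit in a signed integer of the given byte width; a standalone Scala sketch of that calculation (illustrative only, the names mirror the constants introduced below):

    // Max decimal precision whose unscaled value fits in `numBytes` bytes, i.e. the
    // number of base-10 digits representable in a signed (8 * numBytes)-bit integer.
    def maxPrecisionForBytes(numBytes: Int): Int =
      Math.round(Math.floor(Math.log10(Math.pow(2, 8 * numBytes - 1) - 1))).toInt

    maxPrecisionForBytes(4)  // 9  -> precision <= 9  fits in an INT32
    maxPrecisionForBytes(8)  // 18 -> precision <= 18 fits in an unscaled Long (INT64)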
Author: Cheng Lian Closes #8031 from liancheng/spark-4176/minor-fix-for-writing-decimals and squashes the following commits: 10d4ea3 [Cheng Lian] Should use unscaled Long to write decimals for precision <= 18 rather than 8 --- .../sql/parquet/CatalystRowConverter.scala | 2 +- .../sql/parquet/CatalystSchemaConverter.scala | 29 +++++++++++-------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index 6938b071065cd..4fe8a39f20abd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -264,7 +264,7 @@ private[parquet] class CatalystRowConverter( val scale = decimalType.scale val bytes = value.getBytes - if (precision <= 8) { + if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { // Constructs a `Decimal` with an unscaled `Long` value if possible. var unscaled = 0L var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index d43ca95b4eea0..b12149dcf1c92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -25,6 +25,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.parquet.schema._ +import org.apache.spark.sql.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLConf} @@ -155,7 +156,7 @@ private[parquet] class CatalystSchemaConverter( case INT_16 => ShortType case INT_32 | null => IntegerType case DATE => DateType - case DECIMAL => makeDecimalType(maxPrecisionForBytes(4)) + case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32) case TIME_MILLIS => typeNotImplemented() case _ => illegalType() } @@ -163,7 +164,7 @@ private[parquet] class CatalystSchemaConverter( case INT64 => originalType match { case INT_64 | null => LongType - case DECIMAL => makeDecimalType(maxPrecisionForBytes(8)) + case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64) case TIMESTAMP_MILLIS => typeNotImplemented() case _ => illegalType() } @@ -405,7 +406,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec => + if precision <= MAX_PRECISION_FOR_INT32 && followParquetFormatSpec => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -415,7 +416,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec => + if precision <= MAX_PRECISION_FOR_INT64 && followParquetFormatSpec => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -534,14 +535,6 @@ private[parquet] class CatalystSchemaConverter( throw new AnalysisException(s"Unsupported data type $field.dataType") } } - - // Max precision of a decimal value stored in `numBytes` bytes - private def maxPrecisionForBytes(numBytes: Int): Int = { - Math.round( // convert double to long - 
Math.floor(Math.log10( // number of base-10 digits - Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes - .asInstanceOf[Int] - } } @@ -584,4 +577,16 @@ private[parquet] object CatalystSchemaConverter { computeMinBytesForPrecision(precision) } } + + val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) + + val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) + + // Max precision of a decimal value stored in `numBytes` bytes + def maxPrecisionForBytes(numBytes: Int): Int = { + Math.round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes + .asInstanceOf[Int] + } } From 106c0789d8c83c7081bc9a335df78ba728e95872 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 8 Aug 2015 08:33:14 -0700 Subject: [PATCH 232/340] [SPARK-9738] [SQL] remove FromUnsafe and add its codegen version to GenerateSafe In https://github.com/apache/spark/pull/7752 we added `FromUnsafe` to convert nexted unsafe data like array/map/struct to safe versions. It's a quick solution and we already have `GenerateSafe` to do the conversion which is codegened. So we should remove `FromUnsafe` and implement its codegen version in `GenerateSafe`. Author: Wenchen Fan Closes #8029 from cloud-fan/from-unsafe and squashes the following commits: ed40d8f [Wenchen Fan] add the copy back a93fd4b [Wenchen Fan] cogengen FromUnsafe --- .../sql/catalyst/expressions/FromUnsafe.scala | 70 ---------- .../sql/catalyst/expressions/Projection.scala | 8 +- .../codegen/GenerateSafeProjection.scala | 120 +++++++++++++----- .../execution/RowFormatConvertersSuite.scala | 4 +- 4 files changed, 95 insertions(+), 107 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala deleted file mode 100644 index 9b960b136f984..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -case class FromUnsafe(child: Expression) extends UnaryExpression - with ExpectsInputTypes with CodegenFallback { - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(ArrayType, StructType, MapType)) - - override def dataType: DataType = child.dataType - - private def convert(value: Any, dt: DataType): Any = dt match { - case StructType(fields) => - val row = value.asInstanceOf[UnsafeRow] - val result = new Array[Any](fields.length) - fields.map(_.dataType).zipWithIndex.foreach { case (dt, i) => - if (!row.isNullAt(i)) { - result(i) = convert(row.get(i, dt), dt) - } - } - new GenericInternalRow(result) - - case ArrayType(elementType, _) => - val array = value.asInstanceOf[UnsafeArrayData] - val length = array.numElements() - val result = new Array[Any](length) - var i = 0 - while (i < length) { - if (!array.isNullAt(i)) { - result(i) = convert(array.get(i, elementType), elementType) - } - i += 1 - } - new GenericArrayData(result) - - case StringType => value.asInstanceOf[UTF8String].clone() - - case MapType(kt, vt, _) => - val map = value.asInstanceOf[UnsafeMapData] - val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData] - val safeValueArray = convert(map.values, ArrayType(vt)).asInstanceOf[GenericArrayData] - new ArrayBasedMapData(safeKeyArray, safeValueArray) - - case _ => value - } - - override def nullSafeEval(input: Any): Any = { - convert(input, dataType) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 796bc327a3db1..afe52e6a667eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -152,13 +152,7 @@ object FromUnsafeProjection { */ def apply(fields: Seq[DataType]): Projection = { create(fields.zipWithIndex.map(x => { - val b = new BoundReference(x._2, x._1, true) - // todo: this is quite slow, maybe remove this whole projection after remove generic getter of - // InternalRow? 
- b.dataType match { - case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b) - case _ => b - } + new BoundReference(x._2, x._1, true) })) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f06ffc5449e76..ef08ddf041afc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.types.{StringType, StructType, DataType} +import org.apache.spark.sql.types._ /** @@ -36,34 +36,94 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) - private def genUpdater( + private def createCodeForStruct( ctx: CodeGenContext, - setter: String, - dataType: DataType, - ordinal: Int, - value: String): String = { - dataType match { - case struct: StructType => - val rowTerm = ctx.freshName("row") - val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) => - val colTerm = ctx.freshName("col") - s""" - if ($value.isNullAt($i)) { - $rowTerm.setNullAt($i); - } else { - ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")}; - ${genUpdater(ctx, rowTerm, dt, i, colTerm)}; - } - """ - }.mkString("\n") - s""" - $genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length}); - $updates - $setter.update($ordinal, $rowTerm.copy()); - """ - case _ => - ctx.setColumn(setter, dataType, ordinal, value) - } + input: String, + schema: StructType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeRow") + val values = ctx.freshName("values") + val rowClass = classOf[GenericInternalRow].getName + + val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => + val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt) + s""" + if (!$tmp.isNullAt($i)) { + ${converter.code} + $values[$i] = ${converter.primitive}; + } + """ + }.mkString("\n") + + val code = s""" + final InternalRow $tmp = $input; + final Object[] $values = new Object[${schema.length}]; + $fieldWriters + final InternalRow $output = new $rowClass($values); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def createCodeForArray( + ctx: CodeGenContext, + input: String, + elementType: DataType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeArray") + val values = ctx.freshName("values") + val numElements = ctx.freshName("numElements") + val index = ctx.freshName("index") + val arrayClass = classOf[GenericArrayData].getName + + val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType) + val code = s""" + final ArrayData $tmp = $input; + final int $numElements = $tmp.numElements(); + final Object[] $values = new Object[$numElements]; + for (int $index = 0; $index < $numElements; $index++) { + if (!$tmp.isNullAt($index)) { + ${elementConverter.code} + $values[$index] = ${elementConverter.primitive}; + } + } + final 
ArrayData $output = new $arrayClass($values); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def createCodeForMap( + ctx: CodeGenContext, + input: String, + keyType: DataType, + valueType: DataType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeMap") + val mapClass = classOf[ArrayBasedMapData].getName + + val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType) + val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType) + val code = s""" + final MapData $tmp = $input; + ${keyConverter.code} + ${valueConverter.code} + final MapData $output = new $mapClass(${keyConverter.primitive}, ${valueConverter.primitive}); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def convertToSafe( + ctx: CodeGenContext, + input: String, + dataType: DataType): GeneratedExpressionCode = dataType match { + case s: StructType => createCodeForStruct(ctx, input, s) + case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) + case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) + // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. + case StringType => GeneratedExpressionCode("", "false", s"$input.clone()") + case _ => GeneratedExpressionCode("", "false", input) } protected def create(expressions: Seq[Expression]): Projection = { @@ -72,12 +132,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) + val converter = convertToSafe(ctx, evaluationCode.primitive, e.dataType) evaluationCode.code + s""" if (${evaluationCode.isNull}) { mutableRow.setNullAt($i); } else { - ${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)}; + ${converter.code} + ${ctx.setColumn("mutableRow", e.dataType, i, converter.primitive)}; } """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 322966f423784..dd08e9025a927 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -112,7 +112,9 @@ case class DummyPlan(child: SparkPlan) extends UnaryNode { override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => - // cache all strings to make sure we have deep copied UTF8String inside incoming + // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some + // values gotten from the incoming rows. + // we cache all strings here to make sure we have deep copied UTF8String inside incoming // safe InternalRow. val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] iter.foreach { row => From 74a6541aa82bcd7a052b2e57b5ca55b7c316495b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 8 Aug 2015 08:36:14 -0700 Subject: [PATCH 233/340] [SPARK-4561] [PYSPARK] [SQL] turn Row into dict recursively Add an option `recursive` to `Row.asDict()`, when True (default is False), it will convert the nested Row into dict. 
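A short usage sketch of the new flag, mirroring the doctests added in the diff below (assumes a PySpark session where `pyspark.sql.Row` is available):

    from pyspark.sql import Row

    row = Row(key=1, value=Row(name='a', age=2))

    # Default: nested Rows are kept as Row objects.
    row.asDict()      # {'key': 1, 'value': Row(age=2, name='a')}

    # recursive=True also converts Rows nested inside values, lists and dicts.
    row.asDict(True)  # {'key': 1, 'value': {'name': 'a', 'age': 2}}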
Author: Davies Liu Closes #8006 from davies/as_dict and squashes the following commits: 922cc5a [Davies Liu] turn Row into dict recursively --- python/pyspark/sql/types.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 6f74b7162f7cc..e2e6f03ae9fd7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1197,13 +1197,36 @@ def __new__(self, *args, **kwargs): else: raise ValueError("No args or kwargs") - def asDict(self): + def asDict(self, recursive=False): """ Return as an dict + + :param recursive: turns the nested Row as dict (default: False). + + >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11} + True + >>> row = Row(key=1, value=Row(name='a', age=2)) + >>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')} + True + >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} + True """ if not hasattr(self, "__fields__"): raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__fields__, self)) + + if recursive: + def conv(obj): + if isinstance(obj, Row): + return obj.asDict(True) + elif isinstance(obj, list): + return [conv(o) for o in obj] + elif isinstance(obj, dict): + return dict((k, conv(v)) for k, v in obj.items()) + else: + return obj + return dict(zip(self.__fields__, (conv(o) for o in self))) + else: + return dict(zip(self.__fields__, self)) # let object acts like class def __call__(self, *args): From ac507a03c3371cd5404ca195ee0ba0306badfc23 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 8 Aug 2015 08:38:18 -0700 Subject: [PATCH 234/340] [SPARK-6902] [SQL] [PYSPARK] Row should be read-only Raise an read-only exception when user try to mutable a Row. 
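In practice the change makes attribute assignment on a Row raise instead of silently succeeding; a minimal illustration, mirroring the unit test added below:

    from pyspark.sql import Row

    row = Row(a=1, b=2)
    row.a      # 1 -- attribute reads are unchanged
    row.a = 3  # now raises Exception("Row is read-only")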
Author: Davies Liu Closes #8009 from davies/readonly_row and squashes the following commits: 8722f3f [Davies Liu] add tests 05a3d36 [Davies Liu] Row should be read-only --- python/pyspark/sql/tests.py | 15 +++++++++++++++ python/pyspark/sql/types.py | 5 +++++ 2 files changed, 20 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1e3444dd9e3b4..38c83c427a747 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -179,6 +179,21 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_row_should_be_read_only(self): + row = Row(a=1, b=2) + self.assertEqual(1, row.a) + + def foo(): + row.a = 3 + self.assertRaises(Exception, foo) + + row2 = self.sqlCtx.range(10).first() + self.assertEqual(0, row2.id) + + def foo2(): + row2.id = 2 + self.assertRaises(Exception, foo2) + def test_range(self): self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index e2e6f03ae9fd7..c083bf89905bf 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1246,6 +1246,11 @@ def __getattr__(self, item): except ValueError: raise AttributeError(item) + def __setattr__(self, key, value): + if key != '__fields__': + raise Exception("Row is read-only") + self.__dict__[key] = value + def __reduce__(self): """Returns a tuple so Python knows how to pickle Row.""" if hasattr(self, "__fields__"): From 23695f1d2d7ef9f3ea92cebcd96b1cf0e8904eb4 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 8 Aug 2015 11:01:25 -0700 Subject: [PATCH 235/340] [SPARK-9728][SQL]Support CalendarIntervalType in HiveQL This PR enables converting interval term in HiveQL to CalendarInterval Literal. 
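A sketch of the query forms this enables, taken from the test cases added further below (`hiveContext` is an assumed `HiveContext` handle, since the conversion lives in `HiveQl`):

    // HiveQL interval literals now parse into CalendarInterval literals, e.g.:
    hiveContext.sql("select interval '10-9' year to month")                 // 10 years 9 months
    hiveContext.sql("select interval '20 15:40:32.99899999' day to second")
    hiveContext.sql("select interval '30' year")                            // single-unit form
    hiveContext.sql("select interval '299.889987299' second")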
JIRA: https://issues.apache.org/jira/browse/SPARK-9728 Author: Yijie Shen Closes #8034 from yjshen/interval_hiveql and squashes the following commits: 7fe9a5e [Yijie Shen] declare throw exception and add unit test fce7795 [Yijie Shen] convert hiveql interval term into CalendarInterval literal --- .../org/apache/spark/sql/hive/HiveQl.scala | 25 +++ .../apache/spark/sql/hive/HiveQlSuite.scala | 15 ++ .../sql/hive/execution/SQLQuerySuite.scala | 22 +++ .../spark/unsafe/types/CalendarInterval.java | 156 ++++++++++++++++++ .../unsafe/types/CalendarIntervalSuite.java | 91 ++++++++++ 5 files changed, 309 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7d7b4b9167306..c3f29350101d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler /* Implicit conversions */ @@ -1519,6 +1520,30 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + Literal(CalendarInterval.fromYearMonthString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => + Literal(CalendarInterval.fromDayTimeString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index f765395e148af..79cf40aba4bf2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -175,4 +175,19 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) } + + test("Invalid interval term should throw 
AnalysisException") { + def assertError(sql: String, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + HiveQl.parseSql(sql) + } + assert(e.getMessage.contains(errorMessage)) + } + assertError("select interval '42-32' year to month", + "month 32 outside range [0, 11]") + assertError("select interval '5 49:12:15' day to second", + "hour 49 outside range [0, 23]") + assertError("select interval '.1111111111' second", + "nanosecond 1111111111 outside range") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1dff07a6de8ad..2fa7ae3fa2e12 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -1115,4 +1116,25 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("Convert hive interval term into Literal of CalendarIntervalType") { + checkAnswer(sql("select interval '10-9' year to month"), + Row(CalendarInterval.fromString("interval 10 years 9 months"))) + checkAnswer(sql("select interval '20 15:40:32.99899999' day to second"), + Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours 40 minutes " + + "32 seconds 99 milliseconds 899 microseconds"))) + checkAnswer(sql("select interval '30' year"), + Row(CalendarInterval.fromString("interval 30 years"))) + checkAnswer(sql("select interval '25' month"), + Row(CalendarInterval.fromString("interval 25 months"))) + checkAnswer(sql("select interval '-100' day"), + Row(CalendarInterval.fromString("interval -14 weeks -2 days"))) + checkAnswer(sql("select interval '40' hour"), + Row(CalendarInterval.fromString("interval 1 days 16 hours"))) + checkAnswer(sql("select interval '80' minute"), + Row(CalendarInterval.fromString("interval 1 hour 20 minutes"))) + checkAnswer(sql("select interval '299.889987299' second"), + Row(CalendarInterval.fromString( + "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 92a5e4f86f234..30e1758076361 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -50,6 +50,14 @@ private static String unitRegex(String unit) { unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") + unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond")); + private static Pattern yearMonthPattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$"); + + private static Pattern dayTimePattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+) (\\d+):(\\d+):(\\d+)(\\.(\\d+))?(?:['|\"])?$"); + + private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$"); + private static long toLong(String s) { if (s == null) { return 0; @@ -79,6 
+87,154 @@ public static CalendarInterval fromString(String s) { } } + public static long toLongWithRange(String fieldName, + String s, long minValue, long maxValue) throws IllegalArgumentException { + long result = 0; + if (s != null) { + result = Long.valueOf(s); + if (result < minValue || result > maxValue) { + throw new IllegalArgumentException(String.format("%s %d outside range [%d, %d]", + fieldName, result, minValue, maxValue)); + } + } + return result; + } + + /** + * Parse YearMonth string in form: [-]YYYY-MM + * + * adapted from HiveIntervalYearMonth.valueOf + */ + public static CalendarInterval fromYearMonthString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval year-month string was null"); + } + s = s.trim(); + Matcher m = yearMonthPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match year-month format of 'y-m': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; + int years = (int) toLongWithRange("year", m.group(2), 0, Integer.MAX_VALUE); + int months = (int) toLongWithRange("month", m.group(3), 0, 11); + result = new CalendarInterval(sign * (years * 12 + months), 0); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval year-month string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn + * + * adapted from HiveIntervalDayTime.valueOf + */ + public static CalendarInterval fromDayTimeString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval day-time string was null"); + } + s = s.trim(); + Matcher m = dayTimePattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? 
-1 : 1; + long days = toLongWithRange("day", m.group(2), 0, Integer.MAX_VALUE); + long hours = toLongWithRange("hour", m.group(3), 0, 23); + long minutes = toLongWithRange("minute", m.group(4), 0, 59); + long seconds = toLongWithRange("second", m.group(5), 0, 59); + // Hive allow nanosecond precision interval + long nanos = toLongWithRange("nanosecond", m.group(7), 0L, 999999999L); + result = new CalendarInterval(0, sign * ( + days * MICROS_PER_DAY + hours * MICROS_PER_HOUR + minutes * MICROS_PER_MINUTE + + seconds * MICROS_PER_SECOND + nanos / 1000L)); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval day-time string: " + e.getMessage(), e); + } + } + return result; + } + + public static CalendarInterval fromSingleUnitString(String unit, String s) + throws IllegalArgumentException { + + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException(String.format("Interval %s string was null", unit)); + } + s = s.trim(); + Matcher m = quoteTrimPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + if (unit.equals("year")) { + int year = (int) toLongWithRange("year", m.group(1), + Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12); + result = new CalendarInterval(year * 12, 0L); + + } else if (unit.equals("month")) { + int month = (int) toLongWithRange("month", m.group(1), + Integer.MIN_VALUE, Integer.MAX_VALUE); + result = new CalendarInterval(month, 0L); + + } else if (unit.equals("day")) { + long day = toLongWithRange("day", m.group(1), + Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY); + result = new CalendarInterval(0, day * MICROS_PER_DAY); + + } else if (unit.equals("hour")) { + long hour = toLongWithRange("hour", m.group(1), + Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR); + result = new CalendarInterval(0, hour * MICROS_PER_HOUR); + + } else if (unit.equals("minute")) { + long minute = toLongWithRange("minute", m.group(1), + Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE); + result = new CalendarInterval(0, minute * MICROS_PER_MINUTE); + + } else if (unit.equals("second")) { + long micros = parseSecondNano(m.group(1)); + result = new CalendarInterval(0, micros); + } + } catch (Exception e) { + throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse second_nano string in ss.nnnnnnnnn format to microseconds + */ + public static long parseSecondNano(String secondNano) throws IllegalArgumentException { + String[] parts = secondNano.split("\\."); + if (parts.length == 1) { + return toLongWithRange("second", parts[0], Long.MIN_VALUE / MICROS_PER_SECOND, + Long.MAX_VALUE / MICROS_PER_SECOND) * MICROS_PER_SECOND; + + } else if (parts.length == 2) { + long seconds = parts[0].equals("") ? 
0L : toLongWithRange("second", parts[0], + Long.MIN_VALUE / MICROS_PER_SECOND, Long.MAX_VALUE / MICROS_PER_SECOND); + long nanos = toLongWithRange("nanosecond", parts[1], 0L, 999999999L); + return seconds * MICROS_PER_SECOND + nanos / 1000L; + + } else { + throw new IllegalArgumentException( + "Interval string does not match second-nano format of ss.nnnnnnnnn"); + } + } + public final int months; public final long microseconds; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java index 6274b92b47dd4..80d4982c4b576 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -101,6 +101,97 @@ public void fromStringTest() { assertEquals(CalendarInterval.fromString(input), null); } + @Test + public void fromYearMonthStringTest() { + String input; + CalendarInterval i; + + input = "99-10"; + i = new CalendarInterval(99 * 12 + 10, 0L); + assertEquals(CalendarInterval.fromYearMonthString(input), i); + + input = "-8-10"; + i = new CalendarInterval(-8 * 12 - 10, 0L); + assertEquals(CalendarInterval.fromYearMonthString(input), i); + + try { + input = "99-15"; + CalendarInterval.fromYearMonthString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("month 15 outside range")); + } + } + + @Test + public void fromDayTimeStringTest() { + String input; + CalendarInterval i; + + input = "5 12:40:30.999999999"; + i = new CalendarInterval(0, 5 * MICROS_PER_DAY + 12 * MICROS_PER_HOUR + + 40 * MICROS_PER_MINUTE + 30 * MICROS_PER_SECOND + 999999L); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + input = "10 0:12:0.888"; + i = new CalendarInterval(0, 10 * MICROS_PER_DAY + 12 * MICROS_PER_MINUTE); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + input = "-3 0:0:0"; + i = new CalendarInterval(0, -3 * MICROS_PER_DAY); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + try { + input = "5 30:12:20"; + CalendarInterval.fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("hour 30 outside range")); + } + + try { + input = "5 30-12"; + CalendarInterval.fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("not match day-time format")); + } + } + + @Test + public void fromSingleUnitStringTest() { + String input; + CalendarInterval i; + + input = "12"; + i = new CalendarInterval(12 * 12, 0L); + assertEquals(CalendarInterval.fromSingleUnitString("year", input), i); + + input = "100"; + i = new CalendarInterval(0, 100 * MICROS_PER_DAY); + assertEquals(CalendarInterval.fromSingleUnitString("day", input), i); + + input = "1999.38888"; + i = new CalendarInterval(0, 1999 * MICROS_PER_SECOND + 38); + assertEquals(CalendarInterval.fromSingleUnitString("second", input), i); + + try { + input = String.valueOf(Integer.MAX_VALUE); + CalendarInterval.fromSingleUnitString("year", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + + try { + input = String.valueOf(Long.MAX_VALUE / MICROS_PER_HOUR + 1); + 
CalendarInterval.fromSingleUnitString("hour", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + } + @Test public void addTest() { String input = "interval 3 month 1 hour"; From a3aec918bed22f8e33cf91dc0d6e712e6653c7d2 Mon Sep 17 00:00:00 2001 From: Joseph Batchik Date: Sat, 8 Aug 2015 11:03:01 -0700 Subject: [PATCH 236/340] [SPARK-9486][SQL] Add data source aliasing for external packages Users currently have to provide the full class name for external data sources, like: `sqlContext.read.format("com.databricks.spark.avro").load(path)` This allows external data source packages to register themselves using a Service Loader so that they can add custom alias like: `sqlContext.read.format("avro").load(path)` This makes it so that using external data source packages uses the same format as the internal data sources like parquet, json, etc. Author: Joseph Batchik Author: Joseph Batchik Closes #7802 from JDrit/service_loader and squashes the following commits: 49a01ec [Joseph Batchik] fixed a couple of format / error bugs e5e93b2 [Joseph Batchik] modified rat file to only excluded added services 72b349a [Joseph Batchik] fixed error with orc data source actually 9f93ea7 [Joseph Batchik] fixed error with orc data source 87b7f1c [Joseph Batchik] fixed typo 101cd22 [Joseph Batchik] removing unneeded changes 8f3cf43 [Joseph Batchik] merged in changes b63d337 [Joseph Batchik] merged in master 95ae030 [Joseph Batchik] changed the new trait to be used as a mixin for data source to register themselves 74db85e [Joseph Batchik] reformatted class loader ac2270d [Joseph Batchik] removing some added test a6926db [Joseph Batchik] added test cases for data source loader 208a2a8 [Joseph Batchik] changes to do error catching if there are multiple data sources 946186e [Joseph Batchik] started working on service loader --- .rat-excludes | 1 + ...pache.spark.sql.sources.DataSourceRegister | 3 + .../spark/sql/execution/datasources/ddl.scala | 52 ++++++------ .../apache/spark/sql/jdbc/JDBCRelation.scala | 5 +- .../apache/spark/sql/json/JSONRelation.scala | 5 +- .../spark/sql/parquet/ParquetRelation.scala | 5 +- .../apache/spark/sql/sources/interfaces.scala | 21 +++++ ...pache.spark.sql.sources.DataSourceRegister | 3 + .../sql/sources/DDLSourceLoadSuite.scala | 85 +++++++++++++++++++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../spark/sql/hive/orc/OrcRelation.scala | 5 +- 11 files changed, 156 insertions(+), 30 deletions(-) create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala create mode 100644 sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister diff --git a/.rat-excludes b/.rat-excludes index 236c2db05367c..72771465846b8 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -93,3 +93,4 @@ INDEX .lintr gen-java.* .*avpr +org.apache.spark.sql.sources.DataSourceRegister diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..cc32d4b72748e --- /dev/null +++ 
b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.jdbc.DefaultSource +org.apache.spark.sql.json.DefaultSource +org.apache.spark.sql.parquet.DefaultSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 0cdb407ad57b9..8c2f297e42458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,7 +17,12 @@ package org.apache.spark.sql.execution.datasources +import java.util.ServiceLoader + +import scala.collection.Iterator +import scala.collection.JavaConversions._ import scala.language.{existentials, implicitConversions} +import scala.util.{Failure, Success, Try} import scala.util.matching.Regex import org.apache.hadoop.fs.Path @@ -190,37 +195,32 @@ private[sql] class DDLParser( } } -private[sql] object ResolvedDataSource { - - private val builtinSources = Map( - "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource", - "json" -> "org.apache.spark.sql.json.DefaultSource", - "parquet" -> "org.apache.spark.sql.parquet.DefaultSource", - "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource" - ) +private[sql] object ResolvedDataSource extends Logging { /** Given a provider name, look up the data source class definition. */ def lookupDataSource(provider: String): Class[_] = { + val provider2 = s"$provider.DefaultSource" val loader = Utils.getContextOrSparkClassLoader - - if (builtinSources.contains(provider)) { - return loader.loadClass(builtinSources(provider)) - } - - try { - loader.loadClass(provider) - } catch { - case cnf: java.lang.ClassNotFoundException => - try { - loader.loadClass(provider + ".DefaultSource") - } catch { - case cnf: java.lang.ClassNotFoundException => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - sys.error("The ORC data source must be used with Hive support enabled.") - } else { - sys.error(s"Failed to load class for data source: $provider") - } + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match { + /** the provider format did not match any given registered aliases */ + case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => dataSource + case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + throw new ClassNotFoundException( + s"Failed to load class for data source: $provider", error) } + } + /** there is exactly one registered alias */ + case head :: Nil => head.getClass + /** There are multiple registered aliases for the input */ + case sources => sys.error(s"Multiple sources found for $provider, " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 41d0ecb4bbfbf..48d97ced9ca0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -77,7 +77,10 @@ private[sql] object JDBCRelation { } 
} -private[sql] class DefaultSource extends RelationProvider { +private[sql] class DefaultSource extends RelationProvider with DataSourceRegister { + + def format(): String = "jdbc" + /** Returns a new base relation with the given parameters. */ override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 10f1367e6984c..b34a272ec547f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -37,7 +37,10 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "json" + override def createRelation( sqlContext: SQLContext, paths: Array[String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 48009b2fd007d..b6db71b5b8a62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -49,7 +49,10 @@ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "parquet" + override def createRelation( sqlContext: SQLContext, paths: Array[String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c5b7ee73eb784..4aafec0e2df27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -37,6 +37,27 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration +/** + * ::DeveloperApi:: + * Data sources should implement this trait so that they can register an alias to their data source. + * This allows users to give the data source alias as the format type over the fully qualified + * class name. + * + * ex: parquet.DefaultSource.format = "parquet". + * + * A new instance of this class with be instantiated each time a DDL call is made. + */ +@DeveloperApi +trait DataSourceRegister { + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source, + * ex: override def format(): String = "parquet" + */ + def format(): String +} + /** * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source. 
When diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..cfd7889b4ac2c --- /dev/null +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.sources.FakeSourceOne +org.apache.spark.sql.sources.FakeSourceTwo +org.apache.spark.sql.sources.FakeSourceThree diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala new file mode 100644 index 0000000000000..1a4d41b02ca68 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -0,0 +1,85 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +class FakeSourceOne extends RelationProvider with DataSourceRegister { + + def format(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceTwo extends RelationProvider with DataSourceRegister { + + def format(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceThree extends RelationProvider with DataSourceRegister { + + def format(): String = "gathering quorum" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} +// please note that the META-INF/services had to be modified for the test directory for this to work +class DDLSourceLoadSuite extends DataSourceTest { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("specify full classname 
with duplicate formats") { + caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("Loading Orc") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..4a774fbf1fdf8 --- /dev/null +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.orc.DefaultSource diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 7c8704b47f286..0c344c63fde3f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -47,7 +47,10 @@ import org.apache.spark.util.SerializableConfiguration /* Implicit conversions */ import scala.collection.JavaConversions._ -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "orc" + def createRelation( sqlContext: SQLContext, paths: Array[String], From 25c363e93bc79119c5ba5c228fcad620061cff62 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sat, 8 Aug 2015 18:22:46 -0700 Subject: [PATCH 237/340] [MINOR] inaccurate comments for showString() Author: CodingCat Closes #8050 from CodingCat/minor and squashes the following commits: 5bc4b89 [CodingCat] inaccurate comments --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 405b5a4a9a7f9..570b8b2d5928d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -168,7 +168,7 @@ class DataFrame private[sql]( } /** - * Internal API for Python + * Compose the string representing rows for output * @param _numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ From 3ca995b78f373251081f6877623649bfba3040b2 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 8 Aug 2015 21:05:50 -0700 Subject: [PATCH 238/340] [SPARK-6212] [SQL] The EXPLAIN output of CTAS only shows the analyzed plan JIRA: https://issues.apache.org/jira/browse/SPARK-6212 Author: Yijie Shen Closes #7986 from yjshen/ctas_explain and squashes the following commits: bb6fee5 [Yijie Shen] refine test f731041 [Yijie Shen] address comment b2cf8ab [Yijie Shen] bug fix bd7eb20 [Yijie Shen] ctas explain --- .../apache/spark/sql/execution/commands.scala | 2 ++ .../hive/execution/CreateTableAsSelect.scala | 4 ++- .../sql/hive/execution/HiveExplainSuite.scala | 35 +++++++++++++++++-- 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 6b83025d5a153..95209e6634519 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -69,6 +69,8 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan val converted = sideEffectResult.map(convert(_).asInstanceOf[InternalRow]) sqlContext.sparkContext.parallelize(converted, 1) } + + override def argString: String = cmd.toString } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 84358cb73c9e3..8422287e177e5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -40,6 +40,8 @@ case class CreateTableAsSelect( def database: String = tableDesc.database def tableName: String = tableDesc.name + override def children: Seq[LogicalPlan] = Seq(query) + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { @@ -91,6 +93,6 @@ case class CreateTableAsSelect( } override def argString: String = { - s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]\n" + query.toString + s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]" } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 8215dd6c2e711..44c5b80392fa5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -17,13 +17,18 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.test.SQLTestUtils /** * A set of tests that validates support for Hive Explain command. 
*/ -class HiveExplainSuite extends QueryTest { +class HiveExplainSuite extends QueryTest with SQLTestUtils { + + def sqlContext: SQLContext = TestHive + test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, "== Physical Plan ==") @@ -74,4 +79,30 @@ class HiveExplainSuite extends QueryTest { "Limit", "src") } + + test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { + withTempTable("jt") { + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + read.json(rdd).registerTempTable("jt") + val outputs = sql( + s""" + |EXPLAIN EXTENDED + |CREATE TABLE t1 + |AS + |SELECT * FROM jt + """.stripMargin).collect().map(_.mkString).mkString + + val shouldContain = + "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: + "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: + "CreateTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil + for (key <- shouldContain) { + assert(outputs.contains(key), s"$key doesn't exist in result") + } + + val physicalIndex = outputs.indexOf("== Physical Plan ==") + assert(!outputs.substring(physicalIndex).contains("Subquery"), + "Physical Plan should not contain Subquery since it's eliminated by optimizer") + } + } } From e9c36938ba972b6fe3c9f6228508e3c9f1c876b2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 9 Aug 2015 10:58:36 -0700 Subject: [PATCH 239/340] [SPARK-9752][SQL] Support UnsafeRow in Sample operator. In order for this to work, I had to disable gap sampling. Author: Reynold Xin Closes #8040 from rxin/SPARK-9752 and squashes the following commits: f9e248c [Reynold Xin] Fix the test case for real this time. adbccb3 [Reynold Xin] Fixed test case. 589fb23 [Reynold Xin] Merge branch 'SPARK-9752' of github.com:rxin/spark into SPARK-9752 55ccddc [Reynold Xin] Fixed core test. 78fa895 [Reynold Xin] [SPARK-9752][SQL] Support UnsafeRow in Sample operator. c9e7112 [Reynold Xin] [SPARK-9752][SQL] Support UnsafeRow in Sample operator. --- .../spark/util/random/RandomSampler.scala | 18 ++++++---- .../spark/sql/execution/basicOperators.scala | 18 +++++++--- .../apache/spark/sql/DataFrameStatSuite.scala | 35 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 17 --------- 4 files changed, 61 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 786b97ad7b9ec..c156b03cdb7c4 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -176,10 +176,15 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T * A sampler for sampling with replacement, based on values drawn from Poisson distribution. * * @param fraction the sampling fraction (with replacement) + * @param useGapSamplingIfPossible if true, use gap sampling when sampling ratio is low. * @tparam T item type */ @DeveloperApi -class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] { +class PoissonSampler[T: ClassTag]( + fraction: Double, + useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] { + + def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true) /** Epsilon slop to avoid failure from floating point jitter. 
*/ require( @@ -199,17 +204,18 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] override def sample(items: Iterator[T]): Iterator[T] = { if (fraction <= 0.0) { Iterator.empty - } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) { - new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) + } else if (useGapSamplingIfPossible && + fraction <= RandomSampler.defaultMaxGapSamplingFraction) { + new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) } else { - items.flatMap { item => { + items.flatMap { item => val count = rng.sample() if (count == 0) Iterator.empty else Iterator.fill(count)(item) - }} + } } } - override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction) + override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction, useGapSamplingIfPossible) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 0680f31d40f6d..c5d1ed0937b19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -30,6 +30,7 @@ import org.apache.spark.sql.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator +import org.apache.spark.util.random.PoissonSampler import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.{HashPartitioner, SparkEnv} @@ -130,12 +131,21 @@ case class Sample( { override def output: Seq[Attribute] = child.output - // TODO: How to pick seed? + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { - child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed) + // Disable gap sampling since the gap sampling method buffers two rows internally, + // requiring us to copy the row, which is more expensive than the random number generator. 
+ new PartitionwiseSampledRDD[InternalRow, InternalRow]( + child.execute(), + new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), + preservesPartitioning = true, + seed) } else { - child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed) + child.execute().randomSampleWithRange(lowerBound, upperBound, seed) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0e7659f443ecd..8f5984e4a8ce2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -30,6 +30,41 @@ class DataFrameStatSuite extends QueryTest { private def toLetter(i: Int): String = (i + 97).toChar.toString + test("sample with replacement") { + val n = 100 + val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + Seq(5, 10, 52, 73).map(Row(_)) + ) + } + + test("sample without replacement") { + val n = 100 + val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + Seq(16, 23, 88, 100).map(Row(_)) + ) + } + + test("randomSplit") { + val n = 600 + val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + for (seed <- 1 to 5) { + val splits = data.randomSplit(Array[Double](1, 2, 3), seed) + assert(splits.length == 3, "wrong number of splits") + + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.collect().toList, "incomplete or wrong split") + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 + } + } + test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.stat.corr("a", "b", "pearson") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f9cc6d1f3c250..0212637a829e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -415,23 +415,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } - test("randomSplit") { - val n = 600 - val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id") - for (seed <- 1 to 5) { - val splits = data.randomSplit(Array[Double](1, 2, 3), seed) - assert(splits.length == 3, "wrong number of splits") - - assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == - data.collect().toList, "incomplete or wrong split") - - val s = splits.map(_.count()) - assert(math.abs(s(0) - 100) < 50) // std = 9.13 - assert(math.abs(s(1) - 200) < 50) // std = 11.55 - assert(math.abs(s(2) - 300) < 50) // std = 12.25 - } - } - test("describe") { val describeTestData = Seq( ("Bob", 16, 176), From 68ccc6e184598822b19a880fdd4597b66a1c2d92 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 9 Aug 2015 11:44:51 -0700 Subject: [PATCH 240/340] [SPARK-8930] [SQL] Throw a AnalysisException with meaningful messages if DataFrame#explode takes a star in expressions Author: Yijie Shen Closes #8057 from yjshen/explode_star and squashes the following commits: eae181d 
[Yijie Shen] change explaination message 54c9d11 [Yijie Shen] meaning message for * in explode --- .../spark/sql/catalyst/analysis/Analyzer.scala | 4 +++- .../sql/catalyst/analysis/AnalysisTest.scala | 4 +++- .../org/apache/spark/sql/DataFrameSuite.scala | 15 +++++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 82158e61e3fb5..a684dbc3afa42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -408,7 +408,7 @@ class Analyzer( /** * Returns true if `exprs` contains a [[Star]]. */ - protected def containsStar(exprs: Seq[Expression]): Boolean = + def containsStar(exprs: Seq[Expression]): Boolean = exprs.exists(_.collect { case _: Star => true }.nonEmpty) } @@ -602,6 +602,8 @@ class Analyzer( */ object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case g: Generate if ResolveReferences.containsStar(g.generator.children) => + failAnalysis("Cannot explode *, explode can only be applied on a specific column.") case p: Generate if !p.child.resolved || !p.generator.resolved => p case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index ee1f8f54251e0..53b3695a86be5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -71,6 +71,8 @@ trait AnalysisTest extends PlanTest { val e = intercept[Exception] { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - expectedErrors.forall(e.getMessage.contains) + assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains), + s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " + + s"actually we get ${e.getMessage}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0212637a829e5..c49f256be5501 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -134,6 +134,21 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { ) } + test("SPARK-8930: explode should fail with a meaningful message if it takes a star") { + val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") + val e = intercept[AnalysisException] { + df.explode($"*") { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + assert(e.getMessage.contains( + "Cannot explode *, explode can only be applied on a specific column.")) + + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + test("explode alias and star") { val df = Seq((Array("a"), 1)).toDF("a", "b") From 86fa4ba6d13f909cb508b7cb3b153d586fe59bc3 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: 
Sun, 9 Aug 2015 19:54:05 +0100 Subject: [PATCH 241/340] [SPARK-9737] [YARN] Add the suggested configuration when required executor memory is above the max threshold of this cluster on YARN mode Author: Yadong Qi Closes #8028 from watermen/SPARK-9737 and squashes the following commits: 48bdf3d [Yadong Qi] Add suggested configuration. --- .../main/scala/org/apache/spark/deploy/yarn/Client.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index fc11bbf97e2ec..b4ba3f0221600 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -203,12 +203,14 @@ private[spark] class Client( val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + - s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + - s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( amMem, From a863348fd85848e0d4325c4de359da12e5f548d2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 9 Aug 2015 13:43:31 -0700 Subject: [PATCH 242/340] Disable JobGeneratorSuite "Do not clear received block data too soon". --- .../apache/spark/streaming/scheduler/JobGeneratorSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index a2dbae149f311..9b6cd4bc4e315 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -56,7 +56,8 @@ class JobGeneratorSuite extends TestSuiteBase { // 4. allow subsequent batches to be generated (to allow premature deletion of 3rd batch metadata) // 5. verify whether 3rd batch's block metadata still exists // - test("SPARK-6222: Do not clear received block data too soon") { + // TODO: SPARK-7420 enable this test + ignore("SPARK-6222: Do not clear received block data too soon") { import JobGeneratorSuite._ val checkpointDir = Utils.createTempDir() val testConf = conf From 23cf5af08d98da771c41571c00a2f5cafedfebdd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 9 Aug 2015 14:26:01 -0700 Subject: [PATCH 243/340] [SPARK-9703] [SQL] Refactor EnsureRequirements to avoid certain unnecessary shuffles This pull request refactors the `EnsureRequirements` planning rule in order to avoid the addition of certain unnecessary shuffles. As an example of how unnecessary shuffles can occur, consider SortMergeJoin, which requires clustered distribution and sorted ordering of its children's input rows. 
Say that both of SMJ's children produce unsorted output but are both SinglePartition. In this case, we will need to inject sort operators but should not need to inject Exchanges. Unfortunately, it looks like the EnsureRequirements unnecessarily repartitions using a hash partitioning. This patch solves this problem by refactoring `EnsureRequirements` to properly implement the `compatibleWith` checks that were broken in earlier implementations. See the significant inline comments for a better description of how this works. The majority of this PR is new comments and test cases, with few actual changes to the code. Author: Josh Rosen Closes #7988 from JoshRosen/exchange-fixes and squashes the following commits: 38006e7 [Josh Rosen] Rewrite EnsureRequirements _yet again_ to make things even simpler 0983f75 [Josh Rosen] More guarantees vs. compatibleWith cleanup; delete BroadcastPartitioning. 8784bd9 [Josh Rosen] Giant comment explaining compatibleWith vs. guarantees 1307c50 [Josh Rosen] Update conditions for requiring child compatibility. 18cddeb [Josh Rosen] Rename DummyPlan to DummySparkPlan. 2c7e126 [Josh Rosen] Merge remote-tracking branch 'origin/master' into exchange-fixes fee65c4 [Josh Rosen] Further refinement to comments / reasoning 642b0bb [Josh Rosen] Further expand comment / reasoning 06aba0c [Josh Rosen] Add more comments 8dbc845 [Josh Rosen] Add even more tests. 4f08278 [Josh Rosen] Fix the test by adding the compatibility check to EnsureRequirements a1c12b9 [Josh Rosen] Add failing test to demonstrate allCompatible bug 0725a34 [Josh Rosen] Small assertion cleanup. 5172ac5 [Josh Rosen] Add test for requiresChildrenToProduceSameNumberOfPartitions. 2e0f33a [Josh Rosen] Write a more generic test for EnsureRequirements. 752b8de [Josh Rosen] style fix c628daf [Josh Rosen] Revert accidental ExchangeSuite change. c9fb231 [Josh Rosen] Rewrite exchange to fix better handle this case. adcc742 [Josh Rosen] Move test to PlannerSuite. 0675956 [Josh Rosen] Preserving ordering and partitioning in row format converters also does not help. cc5669c [Josh Rosen] Adding outputPartitioning to Repartition does not fix the test. 2dfc648 [Josh Rosen] Add failing test illustrating bad exchange planning. --- .../plans/physical/partitioning.scala | 128 +++++++++++++-- .../apache/spark/sql/execution/Exchange.scala | 104 ++++++------ .../spark/sql/execution/basicOperators.scala | 5 + .../sql/execution/rowFormatConverters.scala | 5 + .../spark/sql/execution/PlannerSuite.scala | 151 ++++++++++++++++++ 5 files changed, 328 insertions(+), 65 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index ec659ce789c27..5a89a90b735a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -75,6 +75,37 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { def clustering: Set[Expression] = ordering.map(_.child).toSet } +/** + * Describes how an operator's output is split across partitions. The `compatibleWith`, + * `guarantees`, and `satisfies` methods describe relationships between child partitionings, + * target partitionings, and [[Distribution]]s. 
These relations are described more precisely in + * their individual method docs, but at a high level: + * + * - `satisfies` is a relationship between partitionings and distributions. + * - `compatibleWith` is relationships between an operator's child output partitionings. + * - `guarantees` is a relationship between a child's existing output partitioning and a target + * output partitioning. + * + * Diagrammatically: + * + * +--------------+ + * | Distribution | + * +--------------+ + * ^ + * | + * satisfies + * | + * +--------------+ +--------------+ + * | Child | | Target | + * +----| Partitioning |----guarantees--->| Partitioning | + * | +--------------+ +--------------+ + * | ^ + * | | + * | compatibleWith + * | | + * +------------+ + * + */ sealed trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int @@ -90,9 +121,66 @@ sealed trait Partitioning { /** * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] * guarantees the same partitioning scheme described by `other`. + * + * Compatibility of partitionings is only checked for operators that have multiple children + * and that require a specific child output [[Distribution]], such as joins. + * + * Intuitively, partitionings are compatible if they route the same partitioning key to the same + * partition. For instance, two hash partitionings are only compatible if they produce the same + * number of output partitionings and hash records according to the same hash function and + * same partitioning key schema. + * + * Put another way, two partitionings are compatible with each other if they satisfy all of the + * same distribution guarantees. */ - // TODO: Add an example once we have the `nullSafe` concept. - def guarantees(other: Partitioning): Boolean + def compatibleWith(other: Partitioning): Boolean + + /** + * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees + * the same partitioning scheme described by `other`. If a `A.guarantees(B)`, then repartitioning + * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance + * optimization to allow the exchange planner to avoid redundant repartitionings. By default, + * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number + * of partitions, same strategy (range or hash), etc). + * + * In order to enable more aggressive optimization, this strict equality check can be relaxed. + * For example, say that the planner needs to repartition all of an operator's children so that + * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children + * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens + * to be hash-partitioned with a single partition then we do not need to re-shuffle this child; + * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees` + * [[SinglePartition]]. + * + * The SinglePartition example given above is not particularly interesting; guarantees' real + * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion + * of null-safe partitionings, under which partitionings can specify whether rows whose + * partitioning keys contain null values will be grouped into the same partition or whether they + * will have an unknown / random distribution. 
If a partitioning does not require nulls to be + * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered + * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot + * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a + * symmetric relation. + * + * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows + * produced by `A` could have also been produced by `B`. + */ + def guarantees(other: Partitioning): Boolean = this == other +} + +object Partitioning { + def allCompatible(partitionings: Seq[Partitioning]): Boolean = { + // Note: this assumes transitivity + partitionings.sliding(2).map { + case Seq(a) => true + case Seq(a, b) => + if (a.numPartitions != b.numPartitions) { + assert(!a.compatibleWith(b) && !b.compatibleWith(a)) + false + } else { + a.compatibleWith(b) && b.compatibleWith(a) + } + }.forall(_ == true) + } } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -101,6 +189,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } + override def compatibleWith(other: Partitioning): Boolean = false + override def guarantees(other: Partitioning): Boolean = false } @@ -109,21 +199,9 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def guarantees(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } -} - -case object BroadcastPartitioning extends Partitioning { - val numPartitions = 1 + override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 - override def satisfies(required: Distribution): Boolean = true - - override def guarantees(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case _ => false - } + override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1 } /** @@ -147,6 +225,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + override def compatibleWith(other: Partitioning): Boolean = other match { + case o: HashPartitioning => + this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case _ => false + } + override def guarantees(other: Partitioning): Boolean = other match { case o: HashPartitioning => this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions @@ -185,6 +269,11 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } + override def compatibleWith(other: Partitioning): Boolean = other match { + case o: RangePartitioning => this == o + case _ => false + } + override def guarantees(other: Partitioning): Boolean = other match { case o: RangePartitioning => this == o case _ => false @@ -228,6 +317,13 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def satisfies(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) + /** + * Returns true if any `partitioning` of this collection is compatible with + * the given [[Partitioning]]. + */ + override def compatibleWith(other: Partitioning): Boolean = + partitionings.exists(_.compatibleWith(other)) + /** * Returns true if any `partitioning` of this collection guarantees * the given [[Partitioning]]. 
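A minimal sketch, separate from the patch itself, of the satisfies / compatibleWith / guarantees relations documented above; it mirrors the two-join-key scenario exercised in the PlannerSuite test added later in this commit, and assumes only the Partitioning classes defined in this file:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.physical._

// Two inputs, each hash-partitioned on only one of the two join keys.
val left  = HashPartitioning(Literal(1) :: Nil, 1)
val right = HashPartitioning(Literal(2) :: Nil, 1)
val required = ClusteredDistribution(Literal(1) :: Literal(2) :: Nil)

left.satisfies(required)    // true: clustering on a subset of the required keys
right.satisfies(required)   // true, for the same reason
left.compatibleWith(right)  // false: they do not route the same key to the same partition
left.guarantees(left)       // true: an identical hash partitioning needs no re-shuffle

Because both children satisfy the distribution but are not compatible with each other, EnsureRequirements still has to repartition both sides before a join.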
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 49bb729800863..b89e634761eb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -190,66 +190,72 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for * each operator by inserting [[Exchange]] Operators where required. Also ensure that the - * required input partition ordering requirements are met. + * input partition ordering requirements are met. */ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. - def numPartitions: Int = sqlContext.conf.numShufflePartitions + private def numPartitions: Int = sqlContext.conf.numShufflePartitions - def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator: SparkPlan => - // Adds Exchange or Sort operators as required - def addOperatorsIfNecessary( - partitioning: Partitioning, - rowOrdering: Seq[SortOrder], - child: SparkPlan): SparkPlan = { - - def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { - if (!child.outputPartitioning.guarantees(partitioning)) { - Exchange(partitioning, child) - } else { - child - } - } + /** + * Given a required distribution, returns a partitioning that satisfies that distribution. + */ + private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = { + requiredDistribution match { + case AllTuples => SinglePartition + case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) + case dist => sys.error(s"Do not know how to satisfy distribution $dist") + } + } - def addSortIfNecessary(child: SparkPlan): SparkPlan = { - - if (rowOrdering.nonEmpty) { - // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort. 
- val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min - if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) { - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) - } else { - child - } - } else { - child - } - } + private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { + val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution + val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering + var children: Seq[SparkPlan] = operator.children - addSortIfNecessary(addShuffleIfNecessary(child)) + // Ensure that the operator's children satisfy their output distribution requirements: + children = children.zip(requiredChildDistributions).map { case (child, distribution) => + if (child.outputPartitioning.satisfies(distribution)) { + child + } else { + Exchange(canonicalPartitioning(distribution), child) } + } - val requirements = - (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) - - val fixedChildren = requirements.zipped.map { - case (AllTuples, rowOrdering, child) => - addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (ClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) - case (OrderedDistribution(ordering), rowOrdering, child) => - addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) - - case (UnspecifiedDistribution, Seq(), child) => + // If the operator has multiple children and specifies child output distributions (e.g. join), + // then the children's output partitionings must be compatible: + if (children.length > 1 + && requiredChildDistributions.toSet != Set(UnspecifiedDistribution) + && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + children = children.zip(requiredChildDistributions).map { case (child, distribution) => + val targetPartitioning = canonicalPartitioning(distribution) + if (child.outputPartitioning.guarantees(targetPartitioning)) { child - case (UnspecifiedDistribution, rowOrdering, child) => - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + } else { + Exchange(targetPartitioning, child) + } + } + } - case (dist, ordering, _) => - sys.error(s"Don't know how to ensure $dist with ordering $ordering") + // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: + children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => + if (requiredOrdering.nonEmpty) { + // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. 
+ val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min + if (minSize == 0 || requiredOrdering.take(minSize) != child.outputOrdering.take(minSize)) { + sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, global = false, child) + } else { + child + } + } else { + child } + } - operator.withNewChildren(fixedChildren) + operator.withNewChildren(children) + } + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator: SparkPlan => ensureDistributionAndOrdering(operator) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index c5d1ed0937b19..24950f26061f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -256,6 +256,11 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = { + if (numPartitions == 1) SinglePartition + else UnknownPartitioning(numPartitions) + } + protected override def doExecute(): RDD[InternalRow] = { child.execute().map(_.copy()).coalesce(numPartitions, shuffle) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 29f3beb3cb3c8..855555dd1d4c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule /** @@ -33,6 +34,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false override def canProcessSafeRows: Boolean = true @@ -51,6 +54,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { @DeveloperApi case class ConvertToSafe(child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputsUnsafeRows: Boolean = false override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 18b0e54dc7c53..5582caa0d366e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,9 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite +import 
org.apache.spark.rdd.RDD import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} @@ -202,4 +206,151 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { } } } + + // --- Unit tests of EnsureRequirements --------------------------------------------------------- + + // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, + // there two dimensions that need to be considered: are the child partitionings compatible and + // do they satisfy the distribution requirements? As a result, we need at least four test cases. + + private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = { + if (outputPlan.children.length > 1 + && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) { + val childPartitionings = outputPlan.children.map(_.outputPartitioning) + if (!Partitioning.allCompatible(childPartitionings)) { + fail(s"Partitionings are not compatible: $childPartitionings") + } + } + outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach { + case (child, requiredDist) => + assert(child.outputPartitioning.satisfies(requiredDist), + s"$child output partitioning does not satisfy $requiredDist:\n$outputPlan") + } + } + + test("EnsureRequirements with incompatible child partitionings which satisfy distribution") { + // Consider an operator that requires inputs that are clustered by two expressions (e.g. 
+ // sort merge join where there are multiple columns in the equi-join condition) + val clusteringA = Literal(1) :: Nil + val clusteringB = Literal(2) :: Nil + val distribution = ClusteredDistribution(clusteringA ++ clusteringB) + // Say that the left and right inputs are each partitioned by _one_ of the two join columns: + val leftPartitioning = HashPartitioning(clusteringA, 1) + val rightPartitioning = HashPartitioning(clusteringB, 1) + // Individually, each input's partitioning satisfies the clustering distribution: + assert(leftPartitioning.satisfies(distribution)) + assert(rightPartitioning.satisfies(distribution)) + // However, these partitionings are not compatible with each other, so we still need to + // repartition both inputs prior to performing the join: + assert(!leftPartitioning.compatibleWith(rightPartitioning)) + assert(!rightPartitioning.compatibleWith(leftPartitioning)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = leftPartitioning), + DummySparkPlan(outputPartitioning = rightPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) { + fail(s"Exchange should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements with child partitionings with different numbers of output partitions") { + // This is similar to the previous test, except it checks that partitionings are not compatible + // unless they produce the same number of partitions. + val clustering = Literal(1) :: Nil + val distribution = ClusteredDistribution(clustering) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 1)), + DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 2)) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + } + + test("EnsureRequirements with compatible child partitionings that do not satisfy distribution") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + // The left and right inputs have compatible partitionings but they do not satisfy the + // distribution because they are clustered on different columns. Thus, we need to shuffle. + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 1) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = childPartitioning), + DummySparkPlan(outputPartitioning = childPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) { + fail(s"Exchange should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements with compatible child partitionings that satisfy distribution") { + // In this case, all requirements are satisfied and no exchange should be added. 
+ val distribution = ClusteredDistribution(Literal(1) :: Nil) + val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = childPartitioning), + DummySparkPlan(outputPartitioning = childPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { + fail(s"Exchange should not have been added:\n$outputPlan") + } + } + + // This is a regression test for SPARK-9703 + test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") { + // Consider an operator that imposes both output distribution and ordering requirements on its + // children, such as sort sort merge join. If the distribution requirements are satisfied but + // the output ordering requirements are unsatisfied, then the planner should only add sorts and + // should not need to add additional shuffles / exchanges. + val outputOrdering = Seq(SortOrder(Literal(1), Ascending)) + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = SinglePartition), + DummySparkPlan(outputPartitioning = SinglePartition) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(outputOrdering, outputOrdering) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { + fail(s"No Exchanges should have been added:\n$outputPlan") + } + } + + // --------------------------------------------------------------------------------------------- +} + +// Used for unit-testing EnsureRequirements +private case class DummySparkPlan( + override val children: Seq[SparkPlan] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, + override val outputPartitioning: Partitioning = UnknownPartitioning(0), + override val requiredChildDistribution: Seq[Distribution] = Nil, + override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil + ) extends SparkPlan { + override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError + override def output: Seq[Attribute] = Seq.empty } From 46025616b414eaf1da01fcc1255d8041ea1554bc Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 9 Aug 2015 14:30:30 -0700 Subject: [PATCH 244/340] [CORE] [SPARK-9760] Use Option instead of Some for Ivy repos This was introduced in #7599 cc rxin brkyvz Author: Shivaram Venkataraman Closes #8055 from shivaram/spark-packages-repo-fix and squashes the following commits: 890f306 [Shivaram Venkataraman] Remove test case 51d69ee [Shivaram Venkataraman] Add test case for --packages without --repository c02e0b4 [Shivaram Venkataraman] Use Option instead of Some for Ivy repos --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1186bed485250..7ac6cbce4cd1d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -284,7 +284,7 @@ object SparkSubmit { Nil } val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, - Some(args.repositories), Some(args.ivyRepoPath), exclusions = exclusions) + Option(args.repositories), Option(args.ivyRepoPath), exclusions = exclusions) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) if (args.isPython) { From be80def0d07ed0f45d60453f4f82500d8c4c9106 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 9 Aug 2015 22:33:53 -0700 Subject: [PATCH 245/340] [SPARK-9777] [SQL] Window operator can accept UnsafeRows https://issues.apache.org/jira/browse/SPARK-9777 Author: Yin Huai Closes #8064 from yhuai/windowUnsafe and squashes the following commits: 8fb3537 [Yin Huai] Set canProcessUnsafeRows to true. --- .../src/main/scala/org/apache/spark/sql/execution/Window.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index fe9f2c7028171..0269d6d4b7a1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -101,6 +101,8 @@ case class Window( override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def canProcessUnsafeRows: Boolean = true + /** * Create a bound ordering object for a given frame type and offset. A bound ordering object is * used to determine which input row lies within the frame boundaries of an output row. From e3fef0f9e17b1766a3869cb80ce7e4cd521cb7b6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 10 Aug 2015 09:07:08 -0700 Subject: [PATCH 246/340] [SPARK-9743] [SQL] Fixes JSONRelation refreshing PR #7696 added two `HadoopFsRelation.refresh()` calls ([this] [1], and [this] [2]) in `DataSourceStrategy` to make test case `InsertSuite.save directly to the path of a JSON table` pass. However, this forces every `HadoopFsRelation` table scan to do a refresh, which can be super expensive for tables with large number of partitions. The reason why the original test case fails without the `refresh()` calls is that, the old JSON relation builds the base RDD with the input paths, while `HadoopFsRelation` provides `FileStatus`es of leaf files. With the old JSON relation, we can create a temporary table based on a path, writing data to that, and then read newly written data without refreshing the table. This is no long true for `HadoopFsRelation`. This PR removes those two expensive refresh calls, and moves the refresh into `JSONRelation` to fix this issue. We might want to update `HadoopFsRelation` interface to provide better support for this use case. 
[1]: https://github.com/apache/spark/blob/ebfd91c542aaead343cb154277fcf9114382fee7/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala#L63 [2]: https://github.com/apache/spark/blob/ebfd91c542aaead343cb154277fcf9114382fee7/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala#L91 Author: Cheng Lian Closes #8035 from liancheng/spark-9743/fix-json-relation-refreshing and squashes the following commits: ec1957d [Cheng Lian] Fixes JSONRelation refreshing --- .../datasources/DataSourceStrategy.scala | 2 -- .../apache/spark/sql/json/JSONRelation.scala | 19 +++++++++++++++---- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../spark/sql/sources/InsertSuite.scala | 10 +++++----- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 5b5fa8c93ec52..78a4acdf4b1bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -60,7 +60,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Scanning partitioned HadoopFsRelation case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) if t.partitionSpec.partitionColumns.nonEmpty => - t.refresh() val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray logInfo { @@ -88,7 +87,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Scanning non-partitioned HadoopFsRelation case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) => - t.refresh() // See buildPartitionedTableScan for the reason that we need to create a shard // broadcast HadoopConf. 
val sharedHadoopConf = SparkHadoopUtil.get.conf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index b34a272ec547f..5bb9e62310a50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -22,20 +22,22 @@ import java.io.CharArrayWriter import com.fasterxml.jackson.core.JsonFactory import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{Text, LongWritable, NullWritable} +import org.apache.hadoop.io.{LongWritable, NullWritable, Text} import org.apache.hadoop.mapred.{JobConf, TextInputFormat} -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} + import org.apache.spark.Logging +import org.apache.spark.broadcast.Broadcast import org.apache.spark.mapred.SparkHadoopMapRedUtil - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.util.SerializableConfiguration private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { @@ -108,6 +110,15 @@ private[sql] class JSONRelation( jsonSchema } + override private[sql] def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[String], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + refresh() + super.buildScan(requiredColumns, filters, inputPaths, broadcastedConf) + } + override def buildScan( requiredColumns: Array[String], filters: Array[Filter], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 4aafec0e2df27..6bcabbab4f77b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -555,7 +555,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio }) } - private[sql] final def buildScan( + private[sql] def buildScan( requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 39d18d712ef8c..cdbfaf6455fe4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -32,9 +32,9 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { var path: File = null - override def beforeAll: Unit = { + override def beforeAll(): Unit = { path = Utils.createTempDir() - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") sql( s""" @@ -46,7 +46,7 @@ class 
InsertSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) } - override def afterAll: Unit = { + override def afterAll(): Unit = { caseInsensitiveContext.dropTempTable("jsonTable") caseInsensitiveContext.dropTempTable("jt") Utils.deleteRecursively(path) @@ -110,7 +110,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { ) // Writing the table to less part files. - val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) + val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 5) caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") sql( s""" @@ -122,7 +122,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { ) // Writing the table to more part files. - val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) + val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 10) caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") sql( s""" From 0f3366a4c740147a7a7519922642912e2dd238f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 10 Aug 2015 10:10:40 -0700 Subject: [PATCH 247/340] [SPARK-9710] [TEST] Fix RPackageUtilsSuite when R is not available. RUtils.isRInstalled throws an exception if R is not installed, instead of returning false. Fix that. Author: Marcelo Vanzin Closes #8008 from vanzin/SPARK-9710 and squashes the following commits: df72d8c [Marcelo Vanzin] [SPARK-9710] [test] Fix RPackageUtilsSuite when R is not available. --- core/src/main/scala/org/apache/spark/api/r/RUtils.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 93b3bea578676..427b2bc7cbcbb 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -67,7 +67,11 @@ private[spark] object RUtils { /** Check if R is installed before running tests that use R commands. */ def isRInstalled: Boolean = { - val builder = new ProcessBuilder(Seq("R", "--version")) - builder.start().waitFor() == 0 + try { + val builder = new ProcessBuilder(Seq("R", "--version")) + builder.start().waitFor() == 0 + } catch { + case e: Exception => false + } } } From 00b655cced637e1c3b750c19266086b9dcd7c158 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 10 Aug 2015 11:01:45 -0700 Subject: [PATCH 248/340] [SPARK-9755] [MLLIB] Add docs to MultivariateOnlineSummarizer methods Adds method documentations back to `MultivariateOnlineSummarizer`, which were present in 1.4 but disappeared somewhere along the way to 1.5. jkbradley Author: Feynman Liang Closes #8045 from feynmanliang/SPARK-9755 and squashes the following commits: af67fde [Feynman Liang] Add MultivariateOnlineSummarizer docs --- .../stat/MultivariateOnlineSummarizer.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 62da9f2ef22a3..64e4be0ebb97e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -153,6 +153,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Sample mean of each dimension. 
+ * * @since 1.1.0 */ override def mean: Vector = { @@ -168,6 +170,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Sample variance of each dimension. + * * @since 1.1.0 */ override def variance: Vector = { @@ -193,11 +197,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Sample size. + * * @since 1.1.0 */ override def count: Long = totalCnt /** + * Number of nonzero elements in each dimension. + * * @since 1.1.0 */ override def numNonzeros: Vector = { @@ -207,6 +215,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Maximum value of each dimension. + * * @since 1.1.0 */ override def max: Vector = { @@ -221,6 +231,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Minimum value of each dimension. + * * @since 1.1.0 */ override def min: Vector = { @@ -235,6 +247,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * L2 (Euclidian) norm of each dimension. + * * @since 1.2.0 */ override def normL2: Vector = { @@ -252,6 +266,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * L1 norm of each dimension. + * * @since 1.2.0 */ override def normL1: Vector = { From d285212756168200383bf4df2c951bd80a492a7c Mon Sep 17 00:00:00 2001 From: Mahmoud Lababidi Date: Mon, 10 Aug 2015 13:02:01 -0700 Subject: [PATCH 249/340] Fixed AtmoicReference<> Example Author: Mahmoud Lababidi Closes #8076 from lababidi/master and squashes the following commits: af4553b [Mahmoud Lababidi] Fixed AtmoicReference<> Example --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 775d508d4879b..7571e22575efd 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -152,7 +152,7 @@ Next, we discuss how to use this approach in your streaming application.
    // Hold a reference to the current offset ranges, so it can be used downstream - final AtomicReference offsetRanges = new AtomicReference(); + final AtomicReference offsetRanges = new AtomicReference<>(); directKafkaStream.transformToPair( new Function, JavaPairRDD>() { From 0fe66744f16854fc8cd8a72174de93a788e3cf6c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 10 Aug 2015 13:05:03 -0700 Subject: [PATCH 250/340] [SPARK-9784] [SQL] Exchange.isUnsafe should check whether codegen and unsafe are enabled Exchange.isUnsafe should check whether codegen and unsafe are enabled. Author: Josh Rosen Closes #8073 from JoshRosen/SPARK-9784 and squashes the following commits: 7a1019f [Josh Rosen] [SPARK-9784] Exchange.isUnsafe should check whether codegen and unsafe are enabled --- .../main/scala/org/apache/spark/sql/execution/Exchange.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index b89e634761eb1..029f2264a6a27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -46,7 +46,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * Returns true iff we can support the data type, and we are not doing range partitioning. */ private lazy val tungstenMode: Boolean = { - GenerateUnsafeProjection.canSupport(child.schema) && + unsafeEnabled && codegenEnabled && GenerateUnsafeProjection.canSupport(child.schema) && !newPartitioning.isInstanceOf[RangePartitioning] } From 40ed2af587cedadc6e5249031857a922b3b234ca Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 10 Aug 2015 13:49:23 -0700 Subject: [PATCH 251/340] [SPARK-9763][SQL] Minimize exposure of internal SQL classes. There are a few changes in this pull request: 1. Moved all data sources to execution.datasources, except the public JDBC APIs. 2. In order to maintain backward compatibility from 1, added a backward compatibility translation map in data source resolution. 3. Moved ui and metric package into execution. 4. Added more documentation on some internal classes. 5. Renamed DataSourceRegister.format -> shortName. 6. Added "override" modifier on shortName. 7. Removed IntSQLMetric. Author: Reynold Xin Closes #8056 from rxin/SPARK-9763 and squashes the following commits: 9df4801 [Reynold Xin] Removed hardcoded name in test cases. d9babc6 [Reynold Xin] Shorten. e484419 [Reynold Xin] Removed VisibleForTesting. 171b812 [Reynold Xin] MimaExcludes. 2041389 [Reynold Xin] Compile ... 79dda42 [Reynold Xin] Compile. 0818ba3 [Reynold Xin] Removed IntSQLMetric. c46884f [Reynold Xin] Two more fixes. f9aa88d [Reynold Xin] [SPARK-9763][SQL] Minimize exposure of internal SQL classes. 
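A minimal usage sketch of the two lookup paths this commit describes, assuming a placeholder JSON path: a data source now advertises an alias via DataSourceRegister.shortName() (registered in the META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file this patch updates), while old fully qualified provider names are translated by the backward-compatibility map in ResolvedDataSource before falling back to ordinary class loading.

    import org.apache.spark.sql.SQLContext

    def resolutionExamples(sqlContext: SQLContext): Unit = {
      // Resolved through the registered alias, i.e. the "json" value returned by
      // DataSourceRegister.shortName() on the JSON DefaultSource.
      val byAlias = sqlContext.read.format("json").load("/tmp/people.json")

      // The pre-1.5 package name still resolves: "org.apache.spark.sql.json" is
      // rewritten by ResolvedDataSource's backwardCompatibilityMap to
      // org.apache.spark.sql.execution.datasources.json.DefaultSource.
      val byOldName = sqlContext.read
        .format("org.apache.spark.sql.json")
        .load("/tmp/people.json")
    }

Both calls reach the same createRelation code path; the translation map exists only so that user code written against the old package layout keeps working after the move into execution.datasources.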
--- project/MimaExcludes.scala | 24 +- ...pache.spark.sql.sources.DataSourceRegister | 6 +- .../ui/static/spark-sql-viz.css | 0 .../ui/static/spark-sql-viz.js | 0 .../org/apache/spark/sql/DataFrame.scala | 2 +- .../apache/spark/sql/DataFrameReader.scala | 6 +- .../apache/spark/sql/DataFrameWriter.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/SQLExecution.scala | 2 +- .../spark/sql/execution/SparkPlan.scala | 8 +- .../spark/sql/execution/basicOperators.scala | 2 +- .../sql/execution/datasources/DDLParser.scala | 185 +++++++++ .../execution/datasources/DefaultSource.scala | 64 ++++ .../datasources/InsertIntoDataSource.scala | 23 +- .../datasources/ResolvedDataSource.scala | 204 ++++++++++ .../spark/sql/execution/datasources/ddl.scala | 352 +----------------- .../datasources/jdbc/DefaultSource.scala | 62 +++ .../datasources/jdbc/DriverRegistry.scala | 60 +++ .../datasources/jdbc/DriverWrapper.scala | 48 +++ .../datasources}/jdbc/JDBCRDD.scala | 9 +- .../datasources}/jdbc/JDBCRelation.scala | 41 +- .../datasources/jdbc/JdbcUtils.scala | 219 +++++++++++ .../datasources}/json/InferSchema.scala | 4 +- .../datasources}/json/JSONRelation.scala | 7 +- .../datasources}/json/JacksonGenerator.scala | 2 +- .../datasources}/json/JacksonParser.scala | 4 +- .../datasources}/json/JacksonUtils.scala | 2 +- .../parquet/CatalystReadSupport.scala | 2 +- .../parquet/CatalystRecordMaterializer.scala | 2 +- .../parquet/CatalystRowConverter.scala | 2 +- .../parquet/CatalystSchemaConverter.scala | 4 +- .../DirectParquetOutputCommitter.scala | 2 +- .../parquet/ParquetConverter.scala | 2 +- .../datasources}/parquet/ParquetFilters.scala | 2 +- .../parquet/ParquetRelation.scala | 4 +- .../parquet/ParquetTableSupport.scala | 2 +- .../parquet/ParquetTypesConverter.scala | 2 +- .../{ => execution}/metric/SQLMetrics.scala | 36 +- .../apache/spark/sql/execution/package.scala | 8 +- .../ui/AllExecutionsPage.scala | 2 +- .../{ => execution}/ui/ExecutionPage.scala | 2 +- .../sql/{ => execution}/ui/SQLListener.scala | 7 +- .../spark/sql/{ => execution}/ui/SQLTab.scala | 6 +- .../{ => execution}/ui/SparkPlanGraph.scala | 4 +- .../org/apache/spark/sql/jdbc/JdbcUtils.scala | 52 --- .../org/apache/spark/sql/jdbc/jdbc.scala | 250 ------------- .../apache/spark/sql/sources/interfaces.scala | 15 +- .../parquet/test/avro/CompatibilityTest.java | 4 +- .../parquet/test/avro/Nested.java | 30 +- .../parquet/test/avro/ParquetAvroCompat.java | 106 +++--- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../datasources}/json/JsonSuite.scala | 4 +- .../datasources}/json/TestJsonData.scala | 2 +- .../ParquetAvroCompatibilitySuite.scala | 4 +- .../parquet/ParquetCompatibilityTest.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 2 +- .../datasources}/parquet/ParquetIOSuite.scala | 4 +- .../ParquetPartitionDiscoverySuite.scala | 2 +- .../parquet/ParquetQuerySuite.scala | 2 +- .../parquet/ParquetSchemaSuite.scala | 2 +- .../datasources}/parquet/ParquetTest.scala | 2 +- .../ParquetThriftCompatibilitySuite.scala | 2 +- .../metric/SQLMetricsSuite.scala | 12 +- .../{ => execution}/ui/SQLListenerSuite.scala | 4 +- .../sources/CreateTableAsSelectSuite.scala | 20 +- .../sql/sources/DDLSourceLoadSuite.scala | 59 +-- .../sql/sources/ResolvedDataSourceSuite.scala | 39 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../spark/sql/hive/orc/OrcRelation.scala | 6 +- .../spark/sql/hive/HiveParquetSuite.scala | 2 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../hive/ParquetHiveCompatibilitySuite.scala 
| 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- .../apache/spark/sql/hive/parquetSuites.scala | 2 +- .../ParquetHadoopFsRelationSuite.scala | 4 +- .../SimpleTextHadoopFsRelationSuite.scala | 2 +- 76 files changed, 1114 insertions(+), 966 deletions(-) rename sql/core/src/main/resources/org/apache/spark/sql/{ => execution}/ui/static/spark-sql-viz.css (100%) rename sql/core/src/main/resources/org/apache/spark/sql/{ => execution}/ui/static/spark-sql-viz.js (100%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/jdbc/JDBCRDD.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/jdbc/JDBCRelation.scala (71%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/InferSchema.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JSONRelation.scala (97%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JacksonGenerator.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JacksonParser.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JacksonUtils.scala (95%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystReadSupport.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystRecordMaterializer.scala (96%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystRowConverter.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystSchemaConverter.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/DirectParquetOutputCommitter.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetConverter.scala (96%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetFilters.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetRelation.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetTableSupport.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetTypesConverter.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/metric/SQLMetrics.scala (77%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/AllExecutionsPage.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/ExecutionPage.scala (99%) rename 
sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/SQLListener.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/SQLTab.scala (90%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/SparkPlanGraph.scala (97%) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala rename sql/core/src/test/gen-java/org/apache/spark/sql/{ => execution/datasources}/parquet/test/avro/CompatibilityTest.java (93%) rename sql/core/src/test/gen-java/org/apache/spark/sql/{ => execution/datasources}/parquet/test/avro/Nested.java (78%) rename sql/core/src/test/gen-java/org/apache/spark/sql/{ => execution/datasources}/parquet/test/avro/ParquetAvroCompat.java (83%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/json/JsonSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/json/TestJsonData.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetAvroCompatibilitySuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetCompatibilityTest.scala (97%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetFilterSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetIOSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetPartitionDiscoverySuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetQuerySuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetSchemaSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetTest.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetThriftCompatibilitySuite.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/metric/SQLMetricsSuite.scala (92%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/ui/SQLListenerSuite.scala (99%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b60ae784c3798..90261ca3d61aa 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,8 +62,6 @@ object MimaExcludes { "org.apache.spark.ml.classification.LogisticCostFun.this"), // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), - // Parquet support is considered private. 
- excludePackage("org.apache.spark.sql.parquet"), // The old JSON RDD is removed in favor of streaming Jackson ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), @@ -155,7 +153,27 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), + // SPARK-9763 Minimize exposure of internal SQL classes + excludePackage("org.apache.spark.sql.parquet"), + excludePackage("org.apache.spark.sql.json"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") ) ++ Seq( // SPARK-4751 Dynamic allocation for standalone mode ProblemFilters.exclude[MissingMethodProblem]( diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index cc32d4b72748e..ca50000b4756e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,3 @@ -org.apache.spark.sql.jdbc.DefaultSource -org.apache.spark.sql.json.DefaultSource -org.apache.spark.sql.parquet.DefaultSource +org.apache.spark.sql.execution.datasources.jdbc.DefaultSource +org.apache.spark.sql.execution.datasources.json.DefaultSource +org.apache.spark.sql.execution.datasources.parquet.DefaultSource diff --git a/sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.css 
b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css similarity index 100% rename from sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.css rename to sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css diff --git a/sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.js b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js similarity index 100% rename from sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.js rename to sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 570b8b2d5928d..27b994f1f0caf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} -import org.apache.spark.sql.json.JacksonGenerator +import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 85f33c5e99523..9ea955b010017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -25,10 +25,10 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} +import org.apache.spark.sql.execution.datasources.json.JSONRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} -import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, Partition} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 2a4992db09bc2..5fa11da4c38cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,8 +23,8 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} -import 
org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} import org.apache.spark.sql.sources.HadoopFsRelation @@ -264,7 +264,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { // Create the table if the table didn't exist. if (!tableExists) { - val schema = JDBCWriteDetails.schemaString(df, url) + val schema = JdbcUtils.schemaString(df, url) val sql = s"CREATE TABLE $table ($schema)" conn.prepareStatement(sql).executeUpdate() } @@ -272,7 +272,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { conn.close() } - JDBCWriteDetails.saveTable(df, url, table, connectionProperties) + JdbcUtils.saveTable(df, url, table, connectionProperties) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 832572571cabd..f73bb0488c984 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.ui.{SQLListener, SQLTab} +import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 97f1323e97835..cee58218a885b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.util.Utils private[sql] object SQLExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 1915496d16205..9ba5cf2d2b39e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.metric.{IntSQLMetric, LongSQLMetric, SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric, SQLMetrics} import org.apache.spark.sql.types.DataType object SparkPlan { @@ -98,12 +98,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics - /** - * Return a IntSQLMetric according to the name. - */ - private[sql] def intMetric(name: String): IntSQLMetric = - metrics(name).asInstanceOf[IntSQLMetric] - /** * Return a LongSQLMetric according to the name. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 24950f26061f7..bf2de244c8e4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala new file mode 100644 index 0000000000000..6c462fa30461b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -0,0 +1,185 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import scala.language.implicitConversions +import scala.util.matching.Regex + +import org.apache.spark.Logging +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.{TableIdentifier, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ + + +/** + * A parser for foreign DDL commands. 
+ */ +class DDLParser(parseQuery: String => LogicalPlan) + extends AbstractSparkSQLParser with DataTypeParser with Logging { + + def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { + try { + parse(input) + } catch { + case ddlException: DDLException => throw ddlException + case _ if !exceptionOnError => parseQuery(input) + case x: Throwable => throw x + } + } + + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val IF = Keyword("IF") + protected val NOT = Keyword("NOT") + protected val EXISTS = Keyword("EXISTS") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + protected val DESCRIBE = Keyword("DESCRIBE") + protected val EXTENDED = Keyword("EXTENDED") + protected val AS = Keyword("AS") + protected val COMMENT = Keyword("COMMENT") + protected val REFRESH = Keyword("REFRESH") + + protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable + + protected def start: Parser[LogicalPlan] = ddl + + /** + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * AS SELECT ... + */ + protected lazy val createTable: Parser[LogicalPlan] = { + // TODO: Support database.table. + (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ + tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { + case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => + if (temp.isDefined && allowExisting.isDefined) { + throw new DDLException( + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") + } + + val options = opts.getOrElse(Map.empty[String, String]) + if (query.isDefined) { + if (columns.isDefined) { + throw new DDLException( + "a CREATE TABLE AS SELECT statement does not allow column definitions.") + } + // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. 
+ val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + + val queryPlan = parseQuery(query.get) + CreateTableUsingAsSelect(tableName, + provider, + temp.isDefined, + Array.empty[String], + mode, + options, + queryPlan) + } else { + val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) + CreateTableUsing( + tableName, + userSpecifiedSchema, + provider, + temp.isDefined, + options, + allowExisting.isDefined, + managedIfNoPath = false) + } + } + } + + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + + /* + * describe [extended] table avroTable + * This will display all columns of table `avroTable` includes column_name,column_type,comment + */ + protected lazy val describeTable: Parser[LogicalPlan] = + (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { + case e ~ db ~ tbl => + val tblIdentifier = db match { + case Some(dbName) => + Seq(dbName, tbl) + case None => + Seq(tbl) + } + DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) + } + + protected lazy val refreshTable: Parser[LogicalPlan] = + REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { + case maybeDatabaseName ~ tableName => + RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) + } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex $regex", { + case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str + case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str + } + ) + + protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { + case name => name + } + + protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { + case parts => parts.mkString(".") + } + + protected lazy val pair: Parser[(String, String)] = + optionName ~ stringLit ^^ { case k ~ v => (k, v) } + + protected lazy val column: Parser[StructField] = + ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => + val meta = cm match { + case Some(comment) => + new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() + case None => Metadata.empty + } + + StructField(columnName, typ, nullable = true, meta) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala new file mode 100644 index 0000000000000..6e4cc4de7f651 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala @@ -0,0 +1,64 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import java.util.Properties + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCRelation, JDBCPartitioningInfo, DriverRegistry} +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} + + +class DefaultSource extends RelationProvider with DataSourceRegister { + + override def shortName(): String = "jdbc" + + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val driver = parameters.getOrElse("driver", null) + val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val partitionColumn = parameters.getOrElse("partitionColumn", null) + val lowerBound = parameters.getOrElse("lowerBound", null) + val upperBound = parameters.getOrElse("upperBound", null) + val numPartitions = parameters.getOrElse("numPartitions", null) + + if (driver != null) DriverRegistry.register(driver) + + if (partitionColumn != null + && (lowerBound == null || upperBound == null || numPartitions == null)) { + sys.error("Partitioning incompletely specified") + } + + val partitionInfo = if (partitionColumn == null) { + null + } else { + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) + } + val parts = JDBCRelation.columnPartition(partitionInfo) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala index 6ccde7693bd34..3b7dc2e8d0210 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala @@ -17,27 +17,10 @@ package org.apache.spark.sql.execution.datasources -import java.io.IOException -import java.util.{Date, UUID} - -import scala.collection.JavaConversions.asScalaIterator - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.spark._ -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} 
-import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StringType -import org.apache.spark.util.{Utils, SerializableConfiguration} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.sources.InsertableRelation /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala new file mode 100644 index 0000000000000..7770bbd712f04 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -0,0 +1,204 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import java.util.ServiceLoader + +import scala.collection.JavaConversions._ +import scala.language.{existentials, implicitConversions} +import scala.util.{Success, Failure, Try} + +import org.apache.hadoop.fs.Path + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{DataFrame, SaveMode, AnalysisException, SQLContext} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.util.Utils + + +case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) + + +object ResolvedDataSource extends Logging { + + /** A map to maintain backward compatibility in case we move data sources around. */ + private val backwardCompatibilityMap = Map( + "org.apache.spark.sql.jdbc" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.jdbc.DefaultSource" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json.DefaultSource" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet" -> classOf[parquet.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName + ) + + /** Given a provider name, look up the data source class definition. 
*/ + def lookupDataSource(provider0: String): Class[_] = { + val provider = backwardCompatibilityMap.getOrElse(provider0, provider0) + val provider2 = s"$provider.DefaultSource" + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + serviceLoader.iterator().filter(_.shortName().equalsIgnoreCase(provider)).toList match { + /** the provider format did not match any given registered aliases */ + case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => dataSource + case Failure(error) => + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + throw new ClassNotFoundException( + s"Failed to load class for data source: $provider.", error) + } + } + /** there is exactly one registered alias */ + case head :: Nil => head.getClass + /** There are multiple registered aliases for the input */ + case sources => sys.error(s"Multiple sources found for $provider, " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") + } + } + + /** Create a [[ResolvedDataSource]] for reading data in. */ + def apply( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + provider: String, + options: Map[String, String]): ResolvedDataSource = { + val clazz: Class[_] = lookupDataSource(provider) + def className: String = clazz.getCanonicalName + val relation = userSpecifiedSchema match { + case Some(schema: StructType) => clazz.newInstance() match { + case dataSource: SchemaRelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + case dataSource: HadoopFsRelationProvider => + val maybePartitionsSchema = if (partitionColumns.isEmpty) { + None + } else { + Some(partitionColumnsSchema(schema, partitionColumns)) + } + + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val paths = { + val patternPath = new Path(caseInsensitiveOptions("path")) + val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray + } + + val dataSchema = + StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable + + dataSource.createRelation( + sqlContext, + paths, + Some(dataSchema), + maybePartitionsSchema, + caseInsensitiveOptions) + case dataSource: org.apache.spark.sql.sources.RelationProvider => + throw new AnalysisException(s"$className does not allow user-specified schemas.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") + } + + case None => clazz.newInstance() match { + case dataSource: RelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) + case dataSource: HadoopFsRelationProvider => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val paths = { + val patternPath = new Path(caseInsensitiveOptions("path")) + val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray + } + dataSource.createRelation(sqlContext, paths, None, 
None, caseInsensitiveOptions) + case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => + throw new AnalysisException( + s"A schema needs to be specified when using $className.") + case _ => + throw new AnalysisException( + s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") + } + } + new ResolvedDataSource(clazz, relation) + } + + private def partitionColumnsSchema( + schema: StructType, + partitionColumns: Array[String]): StructType = { + StructType(partitionColumns.map { col => + schema.find(_.name == col).getOrElse { + throw new RuntimeException(s"Partition column $col not found in schema $schema") + } + }).asNullable + } + + /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */ + def apply( + sqlContext: SQLContext, + provider: String, + partitionColumns: Array[String], + mode: SaveMode, + options: Map[String, String], + data: DataFrame): ResolvedDataSource = { + if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + throw new AnalysisException("Cannot save interval data type into external storage.") + } + val clazz: Class[_] = lookupDataSource(provider) + val relation = clazz.newInstance() match { + case dataSource: CreatableRelationProvider => + dataSource.createRelation(sqlContext, mode, options, data) + case dataSource: HadoopFsRelationProvider => + // Don't glob path for the write path. The contracts here are: + // 1. Only one output path can be specified on the write path; + // 2. Output path must be a legal HDFS style file system path; + // 3. It's OK that the output path doesn't exist yet; + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val outputPath = { + val path = new Path(caseInsensitiveOptions("path")) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + path.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) + val r = dataSource.createRelation( + sqlContext, + Array(outputPath.toString), + Some(dataSchema.asNullable), + Some(partitionColumnsSchema(data.schema, partitionColumns)), + caseInsensitiveOptions) + + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. 
+ sqlContext.executePlan( + InsertIntoHadoopFsRelation( + r, + data.logicalPlan, + mode)).toRdd + r + case _ => + sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") + } + ResolvedDataSource(clazz, relation) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 8c2f297e42458..ecd304c30cdee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,340 +17,12 @@ package org.apache.spark.sql.execution.datasources -import java.util.ServiceLoader - -import scala.collection.Iterator -import scala.collection.JavaConversions._ -import scala.language.{existentials, implicitConversions} -import scala.util.{Failure, Success, Try} -import scala.util.matching.Regex - -import org.apache.hadoop.fs.Path - -import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} -import org.apache.spark.util.Utils - -/** - * A parser for foreign DDL commands. - */ -private[sql] class DDLParser( - parseQuery: String => LogicalPlan) - extends AbstractSparkSQLParser with DataTypeParser with Logging { - - def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { - try { - parse(input) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => parseQuery(input) - case x: Throwable => throw x - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - protected val REFRESH = Keyword("REFRESH") - - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... 
- */ - protected lazy val createTable: Parser[LogicalPlan] = - // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ - tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - val options = opts.getOrElse(Map.empty[String, String]) - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. - val mode = if (allowExisting.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - val queryPlan = parseQuery(query.get) - CreateTableUsingAsSelect(tableName, - provider, - temp.isDefined, - Array.empty[String], - mode, - options, - queryPlan) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableName, - userSpecifiedSchema, - provider, - temp.isDefined, - options, - allowExisting.isDefined, - managedIfNoPath = false) - } - } - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - /* - * describe [extended] table avroTable - * This will display all columns of table `avroTable` includes column_name,column_type,comment - */ - protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => - val tblIdentifier = db match { - case Some(dbName) => - Seq(dbName, tbl) - case None => - Seq(tbl) - } - DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) - } - - protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { - case maybeDatabaseName ~ tableName => - RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) - } - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex $regex", { - case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str - case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str - } - ) - - protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { - case name => name - } - - protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { - case parts => parts.mkString(".") - } - - protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k, v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() - case None => Metadata.empty - } - - StructField(columnName, typ, nullable = true, meta) - } -} - -private[sql] object ResolvedDataSource extends Logging { - - /** Given a provider name, look up the data source class definition. 
*/ - def lookupDataSource(provider: String): Class[_] = { - val provider2 = s"$provider.DefaultSource" - val loader = Utils.getContextOrSparkClassLoader - val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) - - serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match { - /** the provider format did not match any given registered aliases */ - case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { - case Success(dataSource) => dataSource - case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - throw new ClassNotFoundException( - "The ORC data source must be used with Hive support enabled.", error) - } else { - throw new ClassNotFoundException( - s"Failed to load class for data source: $provider", error) - } - } - /** there is exactly one registered alias */ - case head :: Nil => head.getClass - /** There are multiple registered aliases for the input */ - case sources => sys.error(s"Multiple sources found for $provider, " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name") - } - } - - /** Create a [[ResolvedDataSource]] for reading data in. */ - def apply( - sqlContext: SQLContext, - userSpecifiedSchema: Option[StructType], - partitionColumns: Array[String], - provider: String, - options: Map[String, String]): ResolvedDataSource = { - val clazz: Class[_] = lookupDataSource(provider) - def className: String = clazz.getCanonicalName - val relation = userSpecifiedSchema match { - case Some(schema: StructType) => clazz.newInstance() match { - case dataSource: SchemaRelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case dataSource: HadoopFsRelationProvider => - val maybePartitionsSchema = if (partitionColumns.isEmpty) { - None - } else { - Some(partitionColumnsSchema(schema, partitionColumns)) - } - - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray - } - - val dataSchema = - StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable - - dataSource.createRelation( - sqlContext, - paths, - Some(dataSchema), - maybePartitionsSchema, - caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.RelationProvider => - throw new AnalysisException(s"$className does not allow user-specified schemas.") - case _ => - throw new AnalysisException(s"$className is not a RelationProvider.") - } - - case None => clazz.newInstance() match { - case dataSource: RelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) - case dataSource: HadoopFsRelationProvider => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray - } - dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) - case dataSource: 
org.apache.spark.sql.sources.SchemaRelationProvider => - throw new AnalysisException( - s"A schema needs to be specified when using $className.") - case _ => - throw new AnalysisException( - s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") - } - } - new ResolvedDataSource(clazz, relation) - } - - private def partitionColumnsSchema( - schema: StructType, - partitionColumns: Array[String]): StructType = { - StructType(partitionColumns.map { col => - schema.find(_.name == col).getOrElse { - throw new RuntimeException(s"Partition column $col not found in schema $schema") - } - }).asNullable - } - - /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */ - def apply( - sqlContext: SQLContext, - provider: String, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - data: DataFrame): ResolvedDataSource = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { - throw new AnalysisException("Cannot save interval data type into external storage.") - } - val clazz: Class[_] = lookupDataSource(provider) - val relation = clazz.newInstance() match { - case dataSource: CreatableRelationProvider => - dataSource.createRelation(sqlContext, mode, options, data) - case dataSource: HadoopFsRelationProvider => - // Don't glob path for the write path. The contracts here are: - // 1. Only one output path can be specified on the write path; - // 2. Output path must be a legal HDFS style file system path; - // 3. It's OK that the output path doesn't exist yet; - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val outputPath = { - val path = new Path(caseInsensitiveOptions("path")) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - path.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) - val r = dataSource.createRelation( - sqlContext, - Array(outputPath.toString), - Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns)), - caseInsensitiveOptions) - - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This - // will be adjusted within InsertIntoHadoopFsRelation. - sqlContext.executePlan( - InsertIntoHadoopFsRelation( - r, - data.logicalPlan, - mode)).toRdd - r - case _ => - sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") - } - new ResolvedDataSource(clazz, relation) - } -} - -private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. @@ -358,11 +30,12 @@ private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRel * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. * It is effective only when the table is a Hive table. */ -private[sql] case class DescribeCommand( +case class DescribeCommand( table: LogicalPlan, isExtended: Boolean) extends LogicalPlan with Command { override def children: Seq[LogicalPlan] = Seq.empty + override val output: Seq[Attribute] = Seq( // Column names are based on Hive. 
AttributeReference("col_name", StringType, nullable = false, @@ -370,7 +43,8 @@ private[sql] case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false, new MetadataBuilder().putString("comment", "data type of the column").build())(), AttributeReference("comment", StringType, nullable = false, - new MetadataBuilder().putString("comment", "comment of the column").build())()) + new MetadataBuilder().putString("comment", "comment of the column").build())() + ) } /** @@ -378,7 +52,7 @@ private[sql] case class DescribeCommand( * @param allowExisting If it is true, we will do nothing when the table already exists. * If it is false, an exception will be thrown */ -private[sql] case class CreateTableUsing( +case class CreateTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, @@ -397,7 +71,7 @@ private[sql] case class CreateTableUsing( * can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. */ -private[sql] case class CreateTableUsingAsSelect( +case class CreateTableUsingAsSelect( tableName: String, provider: String, temporary: Boolean, @@ -410,7 +84,7 @@ private[sql] case class CreateTableUsingAsSelect( // override lazy val resolved = databaseName != None && childrenResolved } -private[sql] case class CreateTempTableUsing( +case class CreateTempTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, @@ -425,7 +99,7 @@ private[sql] case class CreateTempTableUsing( } } -private[sql] case class CreateTempTableUsingAsSelect( +case class CreateTempTableUsingAsSelect( tableName: String, provider: String, partitionColumns: Array[String], @@ -443,7 +117,7 @@ private[sql] case class CreateTempTableUsingAsSelect( } } -private[sql] case class RefreshTable(tableIdent: TableIdentifier) +case class RefreshTable(tableIdent: TableIdentifier) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { @@ -472,7 +146,7 @@ private[sql] case class RefreshTable(tableIdent: TableIdentifier) /** * Builds a map in which keys are case insensitive */ -protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] +class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) @@ -490,4 +164,4 @@ protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[St /** * The exception thrown from the DDL parser. */ -protected[sql] class DDLException(message: String) extends Exception(message) +class DDLException(message: String) extends RuntimeException(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala new file mode 100644 index 0000000000000..6773afc794f9c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -0,0 +1,62 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.util.Properties + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister} + +class DefaultSource extends RelationProvider with DataSourceRegister { + + override def shortName(): String = "jdbc" + + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val driver = parameters.getOrElse("driver", null) + val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val partitionColumn = parameters.getOrElse("partitionColumn", null) + val lowerBound = parameters.getOrElse("lowerBound", null) + val upperBound = parameters.getOrElse("upperBound", null) + val numPartitions = parameters.getOrElse("numPartitions", null) + + if (driver != null) DriverRegistry.register(driver) + + if (partitionColumn != null + && (lowerBound == null || upperBound == null || numPartitions == null)) { + sys.error("Partitioning incompletely specified") + } + + val partitionInfo = if (partitionColumn == null) { + null + } else { + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) + } + val parts = JDBCRelation.columnPartition(partitionInfo) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala new file mode 100644 index 0000000000000..7ccd61ed469e9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
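A minimal usage sketch (not part of the patch), assuming an existing sqlContext and a reachable database; the URL, table, and bounds are illustrative. With the "jdbc" short name declared above, the provider can be selected by alias instead of its fully qualified class name, and the partitioning options map onto JDBCPartitioningInfo.

    val df = sqlContext.read
      .format("jdbc")
      .option("url", "jdbc:postgresql://localhost/testdb")   // hypothetical database
      .option("dbtable", "people")                            // hypothetical table
      .option("partitionColumn", "id")
      .option("lowerBound", "1")
      .option("upperBound", "100000")
      .option("numPartitions", "4")
      .load()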
+ */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Driver, DriverManager} + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * java.sql.DriverManager is always loaded by bootstrap classloader, + * so it can't load JDBC drivers accessible by Spark ClassLoader. + * + * To solve the problem, drivers from user-supplied jars are wrapped into thin wrapper. + */ +object DriverRegistry extends Logging { + + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty + + def register(className: String): Unit = { + val cls = Utils.getContextOrSparkClassLoader.loadClass(className) + if (cls.getClassLoader == null) { + logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") + } else if (wrapperMap.get(className).isDefined) { + logTrace(s"Wrapper for $className already exists") + } else { + synchronized { + if (wrapperMap.get(className).isEmpty) { + val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) + DriverManager.registerDriver(wrapper) + wrapperMap(className) = wrapper + logTrace(s"Wrapper for $className registered") + } + } + } + } + + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { + case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala new file mode 100644 index 0000000000000..18263fe227d04 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Connection, Driver, DriverPropertyInfo, SQLFeatureNotSupportedException} +import java.util.Properties + +/** + * A wrapper for a JDBC Driver to work around SPARK-6913. + * + * The problem is in `java.sql.DriverManager` class that can't access drivers loaded by + * Spark ClassLoader. 
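A minimal sketch (not part of the patch) of the intended flow, assuming a hypothetical driver class shipped in a user jar: register wraps the driver in a bootstrap-visible DriverWrapper so that the subsequent DriverManager call can find it.

    DriverRegistry.register("org.postgresql.Driver")          // hypothetical user-supplied driver
    val conn = java.sql.DriverManager.getConnection(
      "jdbc:postgresql://localhost/testdb", new java.util.Properties())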
+ */ +class DriverWrapper(val wrapped: Driver) extends Driver { + override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) + + override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() + + override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { + wrapped.getPropertyInfo(url, info) + } + + override def getMinorVersion: Int = wrapped.getMinorVersion + + def getParentLogger: java.util.logging.Logger = { + throw new SQLFeatureNotSupportedException( + s"${this.getClass.getName}.getParentLogger is not yet implemented.") + } + + override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) + + override def getMajorVersion: Int = wrapped.getMajorVersion +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 3cf70db6b7b09..8eab6a0adccc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -180,9 +181,8 @@ private[sql] object JDBCRDD extends Logging { try { if (driver != null) DriverRegistry.register(driver) } catch { - case e: ClassNotFoundException => { - logWarning(s"Couldn't find class $driver", e); - } + case e: ClassNotFoundException => + logWarning(s"Couldn't find class $driver", e) } DriverManager.getConnection(url, properties) } @@ -344,7 +344,6 @@ private[sql] class JDBCRDD( }).toArray } - /** * Runs the SQL query against the JDBC driver. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala similarity index 71% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 48d97ced9ca0a..f9300dc2cb529 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties @@ -77,45 +77,6 @@ private[sql] object JDBCRelation { } } -private[sql] class DefaultSource extends RelationProvider with DataSourceRegister { - - def format(): String = "jdbc" - - /** Returns a new base relation with the given parameters. 
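A minimal sketch (not part of the patch) of how the connection factory returned by JDBCRDD.getConnector is used; the driver class and URL are illustrative, and the object is internal (private[sql]), so this only shows the flow inside the package.

    val getConnection: () => java.sql.Connection = JDBCRDD.getConnector(
      "org.postgresql.Driver",                        // hypothetical driver class
      "jdbc:postgresql://localhost/testdb",           // hypothetical URL
      new java.util.Properties())
    // Executors invoke the factory lazily, typically once per partition.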
*/ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) - val partitionColumn = parameters.getOrElse("partitionColumn", null) - val lowerBound = parameters.getOrElse("lowerBound", null) - val upperBound = parameters.getOrElse("upperBound", null) - val numPartitions = parameters.getOrElse("numPartitions", null) - - if (driver != null) DriverRegistry.register(driver) - - if (partitionColumn != null - && (lowerBound == null || upperBound == null || numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } - - val partitionInfo = if (partitionColumn == null) { - null - } else { - JDBCPartitioningInfo( - partitionColumn, - lowerBound.toLong, - upperBound.toLong, - numPartitions.toInt) - } - val parts = JDBCRelation.columnPartition(partitionInfo) - val properties = new Properties() // Additional properties that we will pass to getConnection - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(url, table, parts, properties)(sqlContext) - } -} - private[sql] case class JDBCRelation( url: String, table: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala new file mode 100644 index 0000000000000..039c13bf163ca --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Connection, DriverManager, PreparedStatement} +import java.util.Properties + +import scala.util.Try + +import org.apache.spark.Logging +import org.apache.spark.sql.jdbc.JdbcDialects +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +/** + * Util functions for JDBC tables. + */ +object JdbcUtils extends Logging { + + /** + * Establishes a JDBC connection. + */ + def createConnection(url: String, connectionProperties: Properties): Connection = { + DriverManager.getConnection(url, connectionProperties) + } + + /** + * Returns true if the table already exists in the JDBC database. + */ + def tableExists(conn: Connection, table: String): Boolean = { + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all + // SQL database systems, considering "table" could also include the database name. 
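A minimal sketch (not part of the patch) of the existence probe, assuming a hypothetical in-memory H2 URL and table name: any failure of the cheap query below is treated as "table absent".

    val conn = JdbcUtils.createConnection("jdbc:h2:mem:testdb", new java.util.Properties())
    if (!JdbcUtils.tableExists(conn, "people")) {
      conn.prepareStatement("CREATE TABLE people(name VARCHAR(32), age INT)").executeUpdate()
    }
    conn.close()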
+ Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + } + + /** + * Drops a table from the JDBC database. + */ + def dropTable(conn: Connection, table: String): Unit = { + conn.prepareStatement(s"DROP TABLE $table").executeUpdate() + } + + /** + * Returns a PreparedStatement that inserts a row into table via conn. + */ + def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { + val sql = new StringBuilder(s"INSERT INTO $table VALUES (") + var fieldsLeft = rddSchema.fields.length + while (fieldsLeft > 0) { + sql.append("?") + if (fieldsLeft > 1) sql.append(", ") else sql.append(")") + fieldsLeft = fieldsLeft - 1 + } + conn.prepareStatement(sql.toString()) + } + + /** + * Saves a partition of a DataFrame to the JDBC database. This is done in + * a single database transaction in order to avoid repeatedly inserting + * data as much as possible. + * + * It is still theoretically possible for rows in a DataFrame to be + * inserted into the database more than once if a stage somehow fails after + * the commit occurs but before the stage can return successfully. + * + * This is not a closure inside saveTable() because apparently cosmetic + * implementation changes elsewhere might easily render such a closure + * non-Serializable. Instead, we explicitly close over all variables that + * are used. + */ + def savePartition( + getConnection: () => Connection, + table: String, + iterator: Iterator[Row], + rddSchema: StructType, + nullTypes: Array[Int]): Iterator[Byte] = { + val conn = getConnection() + var committed = false + try { + conn.setAutoCommit(false) // Everything in the same db transaction. + val stmt = insertStatement(conn, table, rddSchema) + try { + while (iterator.hasNext) { + val row = iterator.next() + val numFields = rddSchema.fields.length + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + stmt.setNull(i + 1, nullTypes(i)) + } else { + rddSchema.fields(i).dataType match { + case IntegerType => stmt.setInt(i + 1, row.getInt(i)) + case LongType => stmt.setLong(i + 1, row.getLong(i)) + case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) + case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) + case ShortType => stmt.setInt(i + 1, row.getShort(i)) + case ByteType => stmt.setInt(i + 1, row.getByte(i)) + case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) + case StringType => stmt.setString(i + 1, row.getString(i)) + case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) + case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) + case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) + case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) + case _ => throw new IllegalArgumentException( + s"Can't translate non-null value for field $i") + } + } + i = i + 1 + } + stmt.executeUpdate() + } + } finally { + stmt.close() + } + conn.commit() + committed = true + } finally { + if (!committed) { + // The stage must fail. We got here through an exception path, so + // let the exception through unless rollback() or close() want to + // tell the user about another problem. + conn.rollback() + conn.close() + } else { + // The stage must succeed. We cannot propagate any exception close() might throw. + try { + conn.close() + } catch { + case e: Exception => logWarning("Transaction succeeded, but closing failed", e) + } + } + } + Array[Byte]().iterator + } + + /** + * Compute the schema string for this RDD. 
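A minimal sketch (not part of the patch) of what insertStatement prepares for a small schema; the connection URL and table name are illustrative. savePartition then binds each row positionally and commits once for the whole partition.

    import java.util.Properties
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)))
    val conn = JdbcUtils.createConnection("jdbc:h2:mem:testdb", new Properties())
    // Prepares "INSERT INTO people VALUES (?, ?)".
    val stmt = JdbcUtils.insertStatement(conn, "people", schema)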
+ */ + def schemaString(df: DataFrame, url: String): String = { + val sb = new StringBuilder() + val dialect = JdbcDialects.get(url) + df.schema.fields foreach { field => { + val name = field.name + val typ: String = + dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( + field.dataType match { + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case DoubleType => "DOUBLE PRECISION" + case FloatType => "REAL" + case ShortType => "INTEGER" + case ByteType => "BYTE" + case BooleanType => "BIT(1)" + case StringType => "TEXT" + case BinaryType => "BLOB" + case TimestampType => "TIMESTAMP" + case DateType => "DATE" + case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + }) + val nullable = if (field.nullable) "" else "NOT NULL" + sb.append(s", $name $typ $nullable") + }} + if (sb.length < 2) "" else sb.substring(2) + } + + /** + * Saves the RDD to the database in a single transaction. + */ + def saveTable( + df: DataFrame, + url: String, + table: String, + properties: Properties = new Properties()) { + val dialect = JdbcDialects.get(url) + val nullTypes: Array[Int] = df.schema.fields.map { field => + dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( + field.dataType match { + case IntegerType => java.sql.Types.INTEGER + case LongType => java.sql.Types.BIGINT + case DoubleType => java.sql.Types.DOUBLE + case FloatType => java.sql.Types.REAL + case ShortType => java.sql.Types.INTEGER + case ByteType => java.sql.Types.INTEGER + case BooleanType => java.sql.Types.BIT + case StringType => java.sql.Types.CLOB + case BinaryType => java.sql.Types.BLOB + case TimestampType => java.sql.Types.TIMESTAMP + case DateType => java.sql.Types.DATE + case t: DecimalType => java.sql.Types.DECIMAL + case _ => throw new IllegalArgumentException( + s"Can't translate null value for field $field") + }) + } + + val rddSchema = df.schema + val driver: String = DriverRegistry.getDriverClassName(url) + val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + df.foreachPartition { iterator => + savePartition(getConnection, table, iterator, rddSchema, nullTypes) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index ec5668c6b95a1..b6f3410bad690 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -15,13 +15,13 @@ * limitations under the License. 
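A minimal sketch (not part of the patch) of a direct saveTable call, assuming an existing DataFrame df and a reachable database (both illustrative); the driver for the URL must already be resolvable by DriverManager, and the trailing Properties argument defaults to empty.

    JdbcUtils.saveTable(df, "jdbc:postgresql://localhost/testdb", "people_copy")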
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ private[sql] object InferSchema { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 5bb9e62310a50..114c8b211891e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.CharArrayWriter @@ -39,9 +39,10 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "json" +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "json" override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index d734e7e8904bd..37c2b5a296c15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index b8fd3b9cc150e..cd68bd667c5c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream @@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala index fde96852ce68e..005546f37dda0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core.{JsonParser, JsonToken} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 975fec101d9c2..4049795ed3bad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.util.{Map => JMap} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala index 84f1dccfeb788..ed9e0aa65977b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 4fe8a39f20abd..3542dfbae1292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index b12149dcf1c92..a3fc74cf7929b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConversions._ @@ -25,7 +25,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.parquet.schema._ -import org.apache.spark.sql.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLConf} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 1551afd7b7bf2..2c6b914328b60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala index 6ed3580af0729..ccd7ebf319af9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{MapData, ArrayData} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index d57b789f5c1c7..9e2e232f50167 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.Serializable import java.nio.ByteBuffer diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index b6db71b5b8a62..4086a139bed72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.net.URI import java.util.logging.{Level, Logger => JLogger} @@ -51,7 +51,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "parquet" + override def shortName(): String = "parquet" override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala index 9cd0250f9c510..3191cf3d121bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.math.BigInteger import java.nio.{ByteBuffer, ByteOrder} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala index 3854f5bd39fb1..019db34fc666d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.IOException diff --git a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala similarity index 77% rename from sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 3b907e5da7897..1b51a5e5c8a8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.metric +package org.apache.spark.sql.execution.metric import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} @@ -93,22 +93,6 @@ private[sql] class LongSQLMetric private[metric](name: String) } } -/** - * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's - * `+=` and `add`. 
- */ -private[sql] class IntSQLMetric private[metric](name: String) - extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) { - - override def +=(term: Int): Unit = { - localValue.add(term) - } - - override def add(term: Int): Unit = { - localValue.add(term) - } -} - private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] { override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) @@ -121,26 +105,8 @@ private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Lon override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L) } -private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] { - - override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t) - - override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue = - r1.add(r2.value) - - override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero - - override def zero: IntSQLMetricValue = new IntSQLMetricValue(0) -} - private[sql] object SQLMetrics { - def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = { - val acc = new IntSQLMetric(name) - sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) - acc - } - def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { val acc = new LongSQLMetric(name) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 66237f8f1314b..28fa231e722d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -18,12 +18,6 @@ package org.apache.spark.sql /** - * :: DeveloperApi :: - * An execution engine for relational query plans that runs on top Spark and returns RDDs. - * - * Note that the operators in this package are created automatically by a query planner using a - * [[SQLContext]] and are not intended to be used directly by end users of Spark SQL. They are - * documented here in order to make it easier for others to understand the performance - * characteristics of query plans that are generated by Spark SQL. + * The physical execution component of Spark SQL. Note that this is a private package. */ package object execution diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index cb7ca60b2fe48..49646a99d68c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -15,7 +15,7 @@ * limitations under the License. 
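A minimal sketch (not part of the patch) of the surviving Long-based metric, assuming an existing SparkContext sc; SQLMetrics is an internal (private[sql]) API, so this only illustrates how operators inside the package use it after the Int variant is removed.

    val numRows = SQLMetrics.createLongMetric(sc, "number of rows")
    numRows += 1L        // specialized += avoids boxing the value
    numRows.add(10L)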
*/ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 52ddf99e9266a..f0b56c2eb7a53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 2fd4fc658d068..0b9bad987c488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import scala.collection.mutable @@ -26,7 +26,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { @@ -51,17 +51,14 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit private val completedExecutions = mutable.ListBuffer[SQLExecutionUIData]() - @VisibleForTesting def executionIdToData: Map[Long, SQLExecutionUIData] = synchronized { _executionIdToData.toMap } - @VisibleForTesting def jobIdToExecutionId: Map[Long, Long] = synchronized { _jobIdToExecutionId.toMap } - @VisibleForTesting def stageIdToStageMetrics: Map[Long, SQLStageMetrics] = synchronized { _stageIdToStageMetrics.toMap } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala similarity index 90% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 3bba0afaf14eb..0b0867f67eb6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.concurrent.atomic.AtomicInteger @@ -38,12 +38,12 @@ private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/ui/static" + private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" private val nextTabId = new AtomicInteger(0) private def nextTabName: String = { val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL${nextId}" + if (nextId == 0) "SQL" else s"SQL$nextId" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 1ba50b95becc1..ae3d752dde348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} /** * A graph used for storing information of an executionPlan of DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala deleted file mode 100644 index cc918c237192b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.sql.{Connection, DriverManager} -import java.util.Properties - -import scala.util.Try - -/** - * Util functions for JDBC tables. - */ -private[sql] object JdbcUtils { - - /** - * Establishes a JDBC connection. - */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - DriverManager.getConnection(url, connectionProperties) - } - - /** - * Returns true if the table already exists in the JDBC database. - */ - def tableExists(conn: Connection, table: String): Boolean = { - // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. 
- Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess - } - - /** - * Drops a table from the JDBC database. - */ - def dropTable(conn: Connection, table: String): Unit = { - conn.prepareStatement(s"DROP TABLE $table").executeUpdate() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala deleted file mode 100644 index 035e0510080ff..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement, SQLFeatureNotSupportedException} -import java.util.Properties - -import scala.collection.mutable - -import org.apache.spark.Logging -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -package object jdbc { - private[sql] object JDBCWriteDetails extends Logging { - /** - * Returns a PreparedStatement that inserts a row into table via conn. - */ - def insertStatement(conn: Connection, table: String, rddSchema: StructType): - PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 - } - conn.prepareStatement(sql.toString) - } - - /** - * Saves a partition of a DataFrame to the JDBC database. This is done in - * a single database transaction in order to avoid repeatedly inserting - * data as much as possible. - * - * It is still theoretically possible for rows in a DataFrame to be - * inserted into the database more than once if a stage somehow fails after - * the commit occurs but before the stage can return successfully. - * - * This is not a closure inside saveTable() because apparently cosmetic - * implementation changes elsewhere might easily render such a closure - * non-Serializable. Instead, we explicitly close over all variables that - * are used. - */ - def savePartition( - getConnection: () => Connection, - table: String, - iterator: Iterator[Row], - rddSchema: StructType, - nullTypes: Array[Int]): Iterator[Byte] = { - val conn = getConnection() - var committed = false - try { - conn.setAutoCommit(false) // Everything in the same db transaction. 
- val stmt = insertStatement(conn, table, rddSchema) - try { - while (iterator.hasNext) { - val row = iterator.next() - val numFields = rddSchema.fields.length - var i = 0 - while (i < numFields) { - if (row.isNullAt(i)) { - stmt.setNull(i + 1, nullTypes(i)) - } else { - rddSchema.fields(i).dataType match { - case IntegerType => stmt.setInt(i + 1, row.getInt(i)) - case LongType => stmt.setLong(i + 1, row.getLong(i)) - case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) - case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) - case ShortType => stmt.setInt(i + 1, row.getShort(i)) - case ByteType => stmt.setInt(i + 1, row.getByte(i)) - case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i)) - case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) - case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) - case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) - case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) - case _ => throw new IllegalArgumentException( - s"Can't translate non-null value for field $i") - } - } - i = i + 1 - } - stmt.executeUpdate() - } - } finally { - stmt.close() - } - conn.commit() - committed = true - } finally { - if (!committed) { - // The stage must fail. We got here through an exception path, so - // let the exception through unless rollback() or close() want to - // tell the user about another problem. - conn.rollback() - conn.close() - } else { - // The stage must succeed. We cannot propagate any exception close() might throw. - try { - conn.close() - } catch { - case e: Exception => logWarning("Transaction succeeded, but closing failed", e) - } - } - } - Array[Byte]().iterator - } - - /** - * Compute the schema string for this RDD. - */ - def schemaString(df: DataFrame, url: String): String = { - val sb = new StringBuilder() - val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { - val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) - val nullable = if (field.nullable) "" else "NOT NULL" - sb.append(s", $name $typ $nullable") - }} - if (sb.length < 2) "" else sb.substring(2) - } - - /** - * Saves the RDD to the database in a single transaction. 
- */ - def saveTable( - df: DataFrame, - url: String, - table: String, - properties: Properties = new Properties()) { - val dialect = JdbcDialects.get(url) - val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case t: DecimalType => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) - } - - val rddSchema = df.schema - val driver: String = DriverRegistry.getDriverClassName(url) - val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) - df.foreachPartition { iterator => - JDBCWriteDetails.savePartition(getConnection, table, iterator, rddSchema, nullTypes) - } - } - - } - - private [sql] class DriverWrapper(val wrapped: Driver) extends Driver { - override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) - - override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() - - override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { - wrapped.getPropertyInfo(url, info) - } - - override def getMinorVersion: Int = wrapped.getMinorVersion - - def getParentLogger: java.util.logging.Logger = - throw new SQLFeatureNotSupportedException( - s"${this.getClass().getName}.getParentLogger is not yet implemented.") - - override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) - - override def getMajorVersion: Int = wrapped.getMajorVersion - } - - /** - * java.sql.DriverManager is always loaded by bootstrap classloader, - * so it can't load JDBC drivers accessible by Spark ClassLoader. - * - * To solve the problem, drivers from user-supplied jars are wrapped - * into thin wrapper. 
- */ - private [sql] object DriverRegistry extends Logging { - - private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty - - def register(className: String): Unit = { - val cls = Utils.getContextOrSparkClassLoader.loadClass(className) - if (cls.getClassLoader == null) { - logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") - } else if (wrapperMap.get(className).isDefined) { - logTrace(s"Wrapper for $className already exists") - } else { - synchronized { - if (wrapperMap.get(className).isEmpty) { - val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) - DriverManager.registerDriver(wrapper) - wrapperMap(className) = wrapper - logTrace(s"Wrapper for $className registered") - } - } - } - } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName - } - } - -} // package object jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6bcabbab4f77b..2f8417a48d32e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -43,19 +43,24 @@ import org.apache.spark.util.SerializableConfiguration * This allows users to give the data source alias as the format type over the fully qualified * class name. * - * ex: parquet.DefaultSource.format = "parquet". - * * A new instance of this class with be instantiated each time a DDL call is made. + * + * @since 1.5.0 */ @DeveloperApi trait DataSourceRegister { /** * The string that represents the format that this data source provider uses. This is - * overridden by children to provide a nice alias for the data source, - * ex: override def format(): String = "parquet" + * overridden by children to provide a nice alias for the data source. 
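A minimal sketch (not part of the patch) of a hypothetical third-party provider using the renamed hook; the alias is discovered through java.util.ServiceLoader (an entry in META-INF/services/org.apache.spark.sql.sources.DataSourceRegister), which is what lookupDataSource iterates over, so users can write .format("myformat") instead of the fully qualified class name.

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}

    // Hypothetical provider class and alias, for illustration only.
    class MyFormatDefaultSource extends RelationProvider with DataSourceRegister {
      override def shortName(): String = "myformat"
      override def createRelation(
          sqlContext: SQLContext,
          parameters: Map[String, String]): BaseRelation = {
        throw new UnsupportedOperationException("sketch only")
      }
    }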
For example: + * + * {{{ + * override def format(): String = "parquet" + * }}} + * + * @since 1.5.0 */ - def format(): String + def shortName(): String } /** diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java similarity index 93% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java index daec65a5bbe57..70dec1a9d3c92 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -3,7 +3,7 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated @@ -12,6 +12,6 @@ public interface CompatibilityTest { @SuppressWarnings("all") public interface Callback extends CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.parquet.test.avro.CompatibilityTest.PROTOCOL; + public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.execution.datasources.parquet.test.avro.CompatibilityTest.PROTOCOL; } } \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java similarity index 78% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java index 051f1ee903863..a0a406bcd10c1 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -3,7 +3,7 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { @@ -77,18 +77,18 @@ public void setNestedStringColumn(java.lang.String value) { } /** Creates a new Nested RecordBuilder */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder() { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(); } /** Creates a new Nested RecordBuilder by copying an existing Builder */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { + return new 
org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); } /** Creates a new Nested RecordBuilder by copying an existing Nested instance */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested other) { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); } /** @@ -102,11 +102,11 @@ public static class Builder extends org.apache.avro.specific.SpecificRecordBuild /** Creates a new Builder */ private Builder() { - super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); } /** Creates a Builder by copying an existing Builder */ - private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { super(other); if (isValidValue(fields()[0], other.nested_ints_column)) { this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); @@ -119,8 +119,8 @@ private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { } /** Creates a Builder by copying an existing Nested instance */ - private Builder(org.apache.spark.sql.parquet.test.avro.Nested other) { - super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); if (isValidValue(fields()[0], other.nested_ints_column)) { this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); fieldSetFlags()[0] = true; @@ -137,7 +137,7 @@ public java.util.List getNestedIntsColumn() { } /** Sets the value of the 'nested_ints_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { validate(fields()[0], value); this.nested_ints_column = value; fieldSetFlags()[0] = true; @@ -150,7 +150,7 @@ public boolean hasNestedIntsColumn() { } /** Clears the value of the 'nested_ints_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { nested_ints_column = null; fieldSetFlags()[0] = false; return this; @@ -162,7 +162,7 @@ public java.lang.String getNestedStringColumn() { } /** Sets the value of the 'nested_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { validate(fields()[1], value); this.nested_string_column = value; fieldSetFlags()[1] = true; @@ -175,7 +175,7 @@ public boolean hasNestedStringColumn() { } /** Clears the value of the 'nested_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedStringColumn() { + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedStringColumn() { nested_string_column = null; fieldSetFlags()[1] = false; return this; diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java similarity index 83% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java index 354c9d73cca31..6198b00b1e3ca 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -3,7 +3,7 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { @@ -25,7 +25,7 @@ public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBa @Deprecated public java.lang.String maybe_string_column; @Deprecated public java.util.List strings_column; @Deprecated public java.util.Map string_to_int_column; - @Deprecated public java.util.Map> complex_column; + @Deprecated public java.util.Map> complex_column; /** * Default constructor. Note that this does not initialize fields @@ -37,7 +37,7 @@ public ParquetAvroCompat() {} /** * All-args constructor. */ - public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { + public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { this.bool_column = bool_column; this.int_column = int_column; this.long_column = long_column; @@ -101,7 +101,7 @@ public void put(int field$, java.lang.Object value$) { case 13: maybe_string_column = (java.lang.String)value$; break; case 14: strings_column = (java.util.List)value$; break; case 15: string_to_int_column = (java.util.Map)value$; break; - case 16: complex_column = (java.util.Map>)value$; break; + case 16: complex_column = (java.util.Map>)value$; break; default: throw new org.apache.avro.AvroRuntimeException("Bad index"); } } @@ -349,7 +349,7 @@ public void 
setStringToIntColumn(java.util.Map> getComplexColumn() { + public java.util.Map> getComplexColumn() { return complex_column; } @@ -357,23 +357,23 @@ public java.util.Map> value) { + public void setComplexColumn(java.util.Map> value) { this.complex_column = value; } /** Creates a new ParquetAvroCompat RecordBuilder */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(); } /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); } /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); } /** @@ -398,15 +398,15 @@ public static class Builder extends org.apache.avro.specific.SpecificRecordBuild private java.lang.String maybe_string_column; private java.util.List strings_column; private java.util.Map string_to_int_column; - private java.util.Map> complex_column; + private java.util.Map> complex_column; /** Creates a new Builder */ private Builder() { - super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); } /** Creates a Builder by copying an existing Builder */ - private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { super(other); if (isValidValue(fields()[0], other.bool_column)) { this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); @@ -479,8 +479,8 @@ private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder } /** Creates a Builder by copying an existing ParquetAvroCompat instance */ - private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { - super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); if (isValidValue(fields()[0], 
other.bool_column)) { this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); fieldSetFlags()[0] = true; @@ -557,7 +557,7 @@ public java.lang.Boolean getBoolColumn() { } /** Sets the value of the 'bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { validate(fields()[0], value); this.bool_column = value; fieldSetFlags()[0] = true; @@ -570,7 +570,7 @@ public boolean hasBoolColumn() { } /** Clears the value of the 'bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { fieldSetFlags()[0] = false; return this; } @@ -581,7 +581,7 @@ public java.lang.Integer getIntColumn() { } /** Sets the value of the 'int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { validate(fields()[1], value); this.int_column = value; fieldSetFlags()[1] = true; @@ -594,7 +594,7 @@ public boolean hasIntColumn() { } /** Clears the value of the 'int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { fieldSetFlags()[1] = false; return this; } @@ -605,7 +605,7 @@ public java.lang.Long getLongColumn() { } /** Sets the value of the 'long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { validate(fields()[2], value); this.long_column = value; fieldSetFlags()[2] = true; @@ -618,7 +618,7 @@ public boolean hasLongColumn() { } /** Clears the value of the 'long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { fieldSetFlags()[2] = false; return this; } @@ -629,7 +629,7 @@ public java.lang.Float getFloatColumn() { } /** Sets the value of the 'float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { validate(fields()[3], value); this.float_column = value; fieldSetFlags()[3] = true; @@ -642,7 +642,7 @@ public boolean hasFloatColumn() { } /** Clears the value of the 'float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { fieldSetFlags()[3] = false; return this; } @@ -653,7 +653,7 @@ public java.lang.Double getDoubleColumn() { } /** Sets the value of the 'double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double value) { + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double value) { validate(fields()[4], value); this.double_column = value; fieldSetFlags()[4] = true; @@ -666,7 +666,7 @@ public boolean hasDoubleColumn() { } /** Clears the value of the 'double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { fieldSetFlags()[4] = false; return this; } @@ -677,7 +677,7 @@ public java.nio.ByteBuffer getBinaryColumn() { } /** Sets the value of the 'binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { validate(fields()[5], value); this.binary_column = value; fieldSetFlags()[5] = true; @@ -690,7 +690,7 @@ public boolean hasBinaryColumn() { } /** Clears the value of the 'binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { binary_column = null; fieldSetFlags()[5] = false; return this; @@ -702,7 +702,7 @@ public java.lang.String getStringColumn() { } /** Sets the value of the 'string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { validate(fields()[6], value); this.string_column = value; fieldSetFlags()[6] = true; @@ -715,7 +715,7 @@ public boolean hasStringColumn() { } /** Clears the value of the 'string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { string_column = null; fieldSetFlags()[6] = false; return this; @@ -727,7 +727,7 @@ public java.lang.Boolean getMaybeBoolColumn() { } /** Sets the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { validate(fields()[7], value); this.maybe_bool_column = value; fieldSetFlags()[7] = true; @@ -740,7 +740,7 @@ public boolean hasMaybeBoolColumn() { } /** Clears the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { maybe_bool_column = null; fieldSetFlags()[7] = false; return this; @@ -752,7 +752,7 @@ public java.lang.Integer getMaybeIntColumn() { } /** Sets the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { validate(fields()[8], 
value); this.maybe_int_column = value; fieldSetFlags()[8] = true; @@ -765,7 +765,7 @@ public boolean hasMaybeIntColumn() { } /** Clears the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { maybe_int_column = null; fieldSetFlags()[8] = false; return this; @@ -777,7 +777,7 @@ public java.lang.Long getMaybeLongColumn() { } /** Sets the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { validate(fields()[9], value); this.maybe_long_column = value; fieldSetFlags()[9] = true; @@ -790,7 +790,7 @@ public boolean hasMaybeLongColumn() { } /** Clears the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { maybe_long_column = null; fieldSetFlags()[9] = false; return this; @@ -802,7 +802,7 @@ public java.lang.Float getMaybeFloatColumn() { } /** Sets the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { validate(fields()[10], value); this.maybe_float_column = value; fieldSetFlags()[10] = true; @@ -815,7 +815,7 @@ public boolean hasMaybeFloatColumn() { } /** Clears the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { maybe_float_column = null; fieldSetFlags()[10] = false; return this; @@ -827,7 +827,7 @@ public java.lang.Double getMaybeDoubleColumn() { } /** Sets the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { validate(fields()[11], value); this.maybe_double_column = value; fieldSetFlags()[11] = true; @@ -840,7 +840,7 @@ public boolean hasMaybeDoubleColumn() { } /** Clears the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { maybe_double_column = null; fieldSetFlags()[11] = false; return this; @@ -852,7 +852,7 @@ public java.nio.ByteBuffer getMaybeBinaryColumn() { } /** Sets the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { validate(fields()[12], value); 
this.maybe_binary_column = value; fieldSetFlags()[12] = true; @@ -865,7 +865,7 @@ public boolean hasMaybeBinaryColumn() { } /** Clears the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { maybe_binary_column = null; fieldSetFlags()[12] = false; return this; @@ -877,7 +877,7 @@ public java.lang.String getMaybeStringColumn() { } /** Sets the value of the 'maybe_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { validate(fields()[13], value); this.maybe_string_column = value; fieldSetFlags()[13] = true; @@ -890,7 +890,7 @@ public boolean hasMaybeStringColumn() { } /** Clears the value of the 'maybe_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { maybe_string_column = null; fieldSetFlags()[13] = false; return this; @@ -902,7 +902,7 @@ public java.util.List getStringsColumn() { } /** Sets the value of the 'strings_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { validate(fields()[14], value); this.strings_column = value; fieldSetFlags()[14] = true; @@ -915,7 +915,7 @@ public boolean hasStringsColumn() { } /** Clears the value of the 'strings_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { strings_column = null; fieldSetFlags()[14] = false; return this; @@ -927,7 +927,7 @@ public java.util.Map getStringToIntColumn() } /** Sets the value of the 'string_to_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { validate(fields()[15], value); this.string_to_int_column = value; fieldSetFlags()[15] = true; @@ -940,19 +940,19 @@ public boolean hasStringToIntColumn() { } /** Clears the value of the 'string_to_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { string_to_int_column = null; fieldSetFlags()[15] = false; return this; } /** Gets the value of the 'complex_column' field */ - public java.util.Map> getComplexColumn() { + public java.util.Map> getComplexColumn() { return complex_column; } /** Sets the value of the 'complex_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder 
setComplexColumn(java.util.Map> value) { validate(fields()[16], value); this.complex_column = value; fieldSetFlags()[16] = true; @@ -965,7 +965,7 @@ public boolean hasComplexColumn() { } /** Clears the value of the 'complex_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { complex_column = null; fieldSetFlags()[16] = false; return this; @@ -991,7 +991,7 @@ public ParquetAvroCompat build() { record.maybe_string_column = fieldSetFlags()[13] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[13]); record.strings_column = fieldSetFlags()[14] ? this.strings_column : (java.util.List) defaultValue(fields()[14]); record.string_to_int_column = fieldSetFlags()[15] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[15]); - record.complex_column = fieldSetFlags()[16] ? this.complex_column : (java.util.Map>) defaultValue(fields()[16]); + record.complex_column = fieldSetFlags()[16] ? this.complex_column : (java.util.Map>) defaultValue(fields()[16]); return record; } catch (Exception e) { throw new org.apache.avro.AvroRuntimeException(e); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c49f256be5501..adbd95197d7ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -25,8 +25,8 @@ import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ -import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.json.JSONRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 92022ff23d2c3..73d5621897819 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} @@ -28,7 +28,7 @@ import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} -import org.apache.spark.sql.json.InferSchema.compatibleType +import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.types._ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 369df5653060b..6b62c9a003df6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index bfa427349ff6a..4d9c07bb7a570 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.nio.ByteBuffer import java.util.{List => JList, Map => JMap} @@ -25,7 +25,7 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path import org.apache.parquet.avro.AvroParquetWriter -import org.apache.spark.sql.parquet.test.avro.{Nested, ParquetAvroCompat} +import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{Row, SQLContext} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 57478931cd509..68f35b1f3aa83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import scala.collection.JavaConversions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index b6a7c4fbddbdc..7dd9680d8cd65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index b415da5b8c136..ee925afe08508 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -373,7 +373,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { // _temporary should be missing if direct output committer works. 
try { configuration.set("spark.sql.parquet.output.committer.class", - "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + classOf[DirectParquetOutputCommitter].getCanonicalName) sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 2eef10189f11c..73152de244759 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 5c65a8ec57f00..5e6d9c1cd44a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 4a0b3b60f419d..8f06de7ce7c4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 64e94056f209a..3c6e54db4bca7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index 1c532d78790d2..92b1d822172d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{Row, SQLContext} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala similarity index 92% rename from sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d22160f5384f4..953284c98b208 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.metric +package org.apache.spark.sql.execution.metric import java.io.{ByteArrayInputStream, ByteArrayOutputStream} @@ -41,16 +41,6 @@ class SQLMetricsSuite extends SparkFunSuite { } } - test("IntSQLMetric should not box Int") { - val l = SQLMetrics.createIntMetric(TestSQLContext.sparkContext, "Int") - val f = () => { l += 1 } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") - } - } - test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. 
val l = TestSQLContext.sparkContext.accumulator(0L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 69a561e16aa17..41dd1896c15df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 1907e643c85dd..562c279067048 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -51,7 +51,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -75,7 +75,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -92,7 +92,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -107,7 +107,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -122,7 +122,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -139,7 +139,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -158,7 +158,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -175,7 +175,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) - |USING org.apache.spark.sql.json.DefaultSource + |USING 
json |OPTIONS ( | path '${path.toString}' |) AS @@ -188,7 +188,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -199,7 +199,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 1a4d41b02ca68..392da0b0826b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -20,9 +20,37 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType} + +// please note that the META-INF/services had to be modified for the test directory for this to work +class DDLSourceLoadSuite extends DataSourceTest { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("specify full classname with duplicate formats") { + caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("should fail to load ORC without HiveContext") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} + + class FakeSourceOne extends RelationProvider with DataSourceRegister { - def format(): String = "Fluet da Bomb" + def shortName(): String = "Fluet da Bomb" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -35,7 +63,7 @@ class FakeSourceOne extends RelationProvider with DataSourceRegister { class FakeSourceTwo extends RelationProvider with DataSourceRegister { - def format(): String = "Fluet da Bomb" + def shortName(): String = "Fluet da Bomb" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -48,7 +76,7 @@ class FakeSourceTwo extends RelationProvider with DataSourceRegister { class FakeSourceThree extends RelationProvider with DataSourceRegister { - def format(): String = "gathering quorum" + def shortName(): String = "gathering quorum" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -58,28 +86,3 @@ class FakeSourceThree extends RelationProvider with DataSourceRegister { StructType(Seq(StructField("stringType", StringType, nullable = false))) } } -// please note that the META-INF/services had to be modified for the test directory for this to work -class DDLSourceLoadSuite extends DataSourceTest { - - test("data sources with the same name") { - intercept[RuntimeException] { - caseInsensitiveContext.read.format("Fluet da Bomb").load() - } - } - - test("load data source from format alias") { - caseInsensitiveContext.read.format("gathering 
quorum").load().schema == - StructType(Seq(StructField("stringType", StringType, nullable = false))) - } - - test("specify full classname with duplicate formats") { - caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") - .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) - } - - test("Loading Orc") { - intercept[ClassNotFoundException] { - caseInsensitiveContext.read.format("orc").load() - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 3cbf5467b253a..27d1cd92fca1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -22,14 +22,39 @@ import org.apache.spark.sql.execution.datasources.ResolvedDataSource class ResolvedDataSourceSuite extends SparkFunSuite { - test("builtin sources") { - assert(ResolvedDataSource.lookupDataSource("jdbc") === - classOf[org.apache.spark.sql.jdbc.DefaultSource]) + test("jdbc") { + assert( + ResolvedDataSource.lookupDataSource("jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + } - assert(ResolvedDataSource.lookupDataSource("json") === - classOf[org.apache.spark.sql.json.DefaultSource]) + test("json") { + assert( + ResolvedDataSource.lookupDataSource("json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + } - assert(ResolvedDataSource.lookupDataSource("parquet") === - classOf[org.apache.spark.sql.parquet.DefaultSource]) + test("parquet") { + assert( + ResolvedDataSource.lookupDataSource("parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 7198a32df4a02..ac9aaed19d566 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} import org.apache.spark.sql.execution.datasources import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import 
org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 0c344c63fde3f..9f4f8b5789afe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow @@ -49,9 +48,9 @@ import scala.collection.JavaConversions._ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "orc" + override def shortName(): String = "orc" - def createRelation( + override def createRelation( sqlContext: SQLContext, paths: Array[String], dataSchema: Option[StructType], @@ -144,7 +143,6 @@ private[orc] class OrcOutputWriter( } } -@DeveloperApi private[sql] class OrcRelation( override val paths: Array[String], maybeDataSchema: Option[StructType], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index a45c2d957278f..1fa005d5f9a15 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index b73d6665755d0..7f36a483a3965 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index f00d3754c364a..80eb9f122ad90 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -20,7 +20,7 @@ 
package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetCompatibilityTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest import org.apache.spark.sql.{Row, SQLConf, SQLContext} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2fa7ae3fa2e12..79a136ae6f619 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index c4bc60086f6e1..50f02432dacce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index d280543a071d9..cb4cedddbfddd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,12 +23,12 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, SaveMode, parquet} +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + override val dataSourceName: String = "parquet" import sqlContext._ import sqlContext.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 1813cc33226d1..48c37a1fa1022 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -53,7 +53,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = - classOf[org.apache.spark.sql.json.DefaultSource].getCanonicalName + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource].getCanonicalName import sqlContext._ From fe2fb7fb7189d183a4273ad27514af4b6b461f26 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 10 Aug 2015 13:52:18 -0700 Subject: [PATCH 252/340] [SPARK-9620] [SQL] generated UnsafeProjection should support many columns or large exressions Currently, generated UnsafeProjection can reach 64k byte code limit of Java. This patch will split the generated expressions into multiple functions, to avoid the limitation. After this patch, we can work well with table that have up to 64k columns (hit max number of constants limit in Java), it should be enough in practice. cc rxin Author: Davies Liu Closes #8044 from davies/wider_table and squashes the following commits: 9192e6c [Davies Liu] fix generated safe projection d1ef81a [Davies Liu] fix failed tests 737b3d3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table ffcd132 [Davies Liu] address comments 1b95be4 [Davies Liu] put the generated class into sql package 77ed72d [Davies Liu] address comments 4518e17 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table 75ccd01 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table 495e932 [Davies Liu] support wider table with more than 1k columns for generated projections --- .../expressions/codegen/CodeGenerator.scala | 48 ++++++- .../codegen/GenerateMutableProjection.scala | 43 +----- .../codegen/GenerateSafeProjection.scala | 52 ++------ .../codegen/GenerateUnsafeProjection.scala | 122 +++++++++--------- .../codegen/GenerateUnsafeRowJoiner.scala | 2 +- .../codegen/GeneratedProjectionSuite.scala | 82 ++++++++++++ 6 files changed, 207 insertions(+), 142 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7b41c9a3f3b8e..c21f4d626a74e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.language.existentials import com.google.common.cache.{CacheBuilder, CacheLoader} @@ -265,6 +266,45 @@ class CodeGenContext { def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Splits the generated code of expressions into multiple functions, because function has + * 64kb code size limit in JVM + * + * @param row the variable name of row that is used by expressions + */ + def splitExpressions(row: String, expressions: Seq[String]): String = { + val blocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + for (code <- expressions) { + // We can't know 
how many byte code will be generated, so use the number of bytes as limit + if (blockBuilder.length > 64 * 1000) { + blocks.append(blockBuilder.toString()) + blockBuilder.clear() + } + blockBuilder.append(code) + } + blocks.append(blockBuilder.toString()) + + if (blocks.length == 1) { + // inline execution if only one block + blocks.head + } else { + val apply = freshName("apply") + val functions = blocks.zipWithIndex.map { case (body, i) => + val name = s"${apply}_$i" + val code = s""" + |private void $name(InternalRow $row) { + | $body + |} + """.stripMargin + addNewFunction(name, code) + name + } + + functions.map(name => s"$name($row);").mkString("\n") + } + } } /** @@ -289,15 +329,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def declareMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map { case (javaType, variableName, _) => s"private $javaType $variableName;" - }.mkString + }.mkString("\n") } protected def initMutableStates(ctx: CodeGenContext): String = { - ctx.mutableStates.map(_._3).mkString + ctx.mutableStates.map(_._3).mkString("\n") } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString + ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n") } /** @@ -328,6 +368,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) + // Cannot be under package codegen, or fail with java.lang.InstantiationException + evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( classOf[PlatformDependent].getName, classOf[InternalRow].getName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ac58423cd884d..b4d4df8934bd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -40,7 +40,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { + val projectionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) @@ -65,49 +65,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu """ } } - // collect projections into blocks as function has 64kb codesize limit in JVM - val projectionBlocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - for (projection <- projectionCode) { - if (blockBuilder.length > 16 * 1000) { - projectionBlocks.append(blockBuilder.toString()) - blockBuilder.clear() - } - blockBuilder.append(projection) - } - projectionBlocks.append(blockBuilder.toString()) - - val (projectionFuns, projectionCalls) = { - // inline execution if codesize limit was not broken - if (projectionBlocks.length == 1) { - ("", projectionBlocks.head) - } else { - ( - projectionBlocks.zipWithIndex.map { case (body, i) 
=> - s""" - |private void apply$i(InternalRow i) { - | $body - |} - """.stripMargin - }.mkString, - projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") - ) - } - } + val allProjections = ctx.splitExpressions("i", projectionCodes) val code = s""" public Object generate($exprType[] expr) { - return new SpecificProjection(expr); + return new SpecificMutableProjection(expr); } - class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { + class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} { private $exprType[] expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} ${declareAddedFunctions(ctx)} - public SpecificProjection($exprType[] expr) { + public SpecificMutableProjection($exprType[] expr) { expressions = expr; mutableRow = new $genericMutableRowType(${expressions.size}); ${initMutableStates(ctx)} @@ -123,12 +95,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - $projectionFuns - public Object apply(Object _i) { InternalRow i = (InternalRow) _i; - $projectionCalls - + $allProjections return mutableRow; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index ef08ddf041afc..7ad352d7ce3e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.types._ @@ -43,6 +41,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") + // These expressions could be splitted into multiple functions + ctx.addMutableState("Object[]", values, s"this.$values = null;") + val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => @@ -53,12 +54,12 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] $values[$i] = ${converter.primitive}; } """ - }.mkString("\n") - + } + val allFields = ctx.splitExpressions(tmp, fieldWriters) val code = s""" final InternalRow $tmp = $input; - final Object[] $values = new Object[${schema.length}]; - $fieldWriters + this.$values = new Object[${schema.length}]; + $allFields final InternalRow $output = new $rowClass($values); """ @@ -128,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] protected def create(expressions: Seq[Expression]): Projection = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { + val expressionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) @@ -143,36 +144,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } """ } - // collect projections into blocks as function has 64kb codesize limit in JVM - val projectionBlocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - for 
(projection <- projectionCode) { - if (blockBuilder.length > 16 * 1000) { - projectionBlocks.append(blockBuilder.toString()) - blockBuilder.clear() - } - blockBuilder.append(projection) - } - projectionBlocks.append(blockBuilder.toString()) - - val (projectionFuns, projectionCalls) = { - // inline it if we have only one block - if (projectionBlocks.length == 1) { - ("", projectionBlocks.head) - } else { - ( - projectionBlocks.zipWithIndex.map { case (body, i) => - s""" - |private void apply$i(InternalRow i) { - | $body - |} - """.stripMargin - }.mkString, - projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") - ) - } - } - + val allExpressions = ctx.splitExpressions("i", expressionCodes) val code = s""" public Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); @@ -183,6 +155,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private $exprType[] expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificSafeProjection($exprType[] expr) { expressions = expr; @@ -190,12 +163,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - $projectionFuns - public Object apply(Object _i) { InternalRow i = (InternalRow) _i; - $projectionCalls - + $allExpressions return mutableRow; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index d8912df694a10..29f6a7b981752 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent /** * Generates a [[Projection]] that returns an [[UnsafeRow]]. @@ -41,8 +40,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName - private val PlatformDependent = classOf[PlatformDependent].getName - /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case NullType => true @@ -56,19 +53,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s" + $DecimalWriter.getSize(${ev.primitive})" + s"$DecimalWriter.getSize(${ev.primitive})" case StringType => - s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive})" case BinaryType => - s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})" case CalendarIntervalType => - s" + (${ev.isNull} ? 0 : 16)" + s"${ev.isNull} ? 0 : 16" case _: StructType => - s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})" case _: ArrayType => - s" + (${ev.isNull} ? 
0 : $ArrayWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive})" case _: MapType => - s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive})" case _ => "" } @@ -125,64 +122,69 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro */ private def createCodeForStruct( ctx: CodeGenContext, + row: String, inputs: Seq[GeneratedExpressionCode], inputTypes: Seq[DataType]): GeneratedExpressionCode = { + val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) + val output = ctx.freshName("convertedStruct") - ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();") + ctx.addMutableState("UnsafeRow", output, s"this.$output = new UnsafeRow();") val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val numBytes = ctx.freshName("numBytes") + ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];") val cursor = ctx.freshName("cursor") + ctx.addMutableState("int", cursor, s"this.$cursor = 0;") + val tmp = ctx.freshName("tmpBuffer") - val convertedFields = inputTypes.zip(inputs).map { case (dt, input) => - createConvertCode(ctx, input, dt) - } - - val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) - val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) => - genAdditionalSize(dt, ev) - }.mkString("") - - val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) => - val update = genFieldWriter(ctx, dt, ev, output, i, cursor) - if (dt.isInstanceOf[DecimalType]) { - // Can't call setNullAt() for DecimalType + val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) => + val ev = createConvertCode(ctx, input, dt) + val growBuffer = if (!UnsafeRow.isFixedLength(dt)) { + val numBytes = ctx.freshName("numBytes") s""" + int $numBytes = $cursor + (${genAdditionalSize(dt, ev)}); + if ($buffer.length < $numBytes) { + // This will not happen frequently, because the buffer is re-used. 
+ byte[] $tmp = new byte[$numBytes * 2]; + PlatformDependent.copyMemory($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, + $tmp, PlatformDependent.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmp; + } + $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, + ${inputTypes.length}, $numBytes); + """ + } else { + "" + } + val update = dt match { + case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS => + // Can't call setNullAt() for DecimalType + s""" if (${ev.isNull}) { - $cursor += $DecimalWriter.write($output, $i, $cursor, null); + $cursor += $DecimalWriter.write($output, $i, $cursor, null); } else { - $update; + ${genFieldWriter(ctx, dt, ev, output, i, cursor)}; } """ - } else { - s""" + case _ => + s""" if (${ev.isNull}) { $output.setNullAt($i); } else { - $update; + ${genFieldWriter(ctx, dt, ev, output, i, cursor)}; } """ } - }.mkString("\n") + s""" + ${ev.code} + $growBuffer + $update + """ + } val code = s""" - ${convertedFields.map(_.code).mkString("\n")} - - final int $numBytes = $fixedSize $additionalSize; - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } - - $output.pointTo( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, - ${inputTypes.length}, - $numBytes); - - int $cursor = $fixedSize; - - $fieldWriters + $cursor = $fixedSize; + $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); + ${ctx.splitExpressions(row, convertedFields)} """ GeneratedExpressionCode(code, "false", output) } @@ -265,17 +267,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Should we do word align? val elementSize = elementType.defaultSize s""" - $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( + PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}); $cursor += $elementSize; """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" - $PlatformDependent.UNSAFE.putLong( + PlatformDependent.UNSAFE.putLong( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}.toUnscaledLong()); $cursor += 8; """ @@ -284,7 +286,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $cursor += $writer.write( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, $elements[$index]); """ } @@ -318,14 +320,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($checkNull) { // If element is null, write the negative value address into offset region. 
- $PlatformDependent.UNSAFE.putInt( + PlatformDependent.UNSAFE.putInt( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, + PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { - $PlatformDependent.UNSAFE.putInt( + PlatformDependent.UNSAFE.putInt( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, + PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement @@ -334,7 +336,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $output.pointTo( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, + PlatformDependent.BYTE_ARRAY_OFFSET, $numElements, $numBytes); } @@ -400,7 +402,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldIsNull = s"$tmp.isNullAt($i)" GeneratedExpressionCode("", fieldIsNull, getFieldCode) } - val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes) + val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes) val code = s""" ${input.code} UnsafeRow $output = null; @@ -427,7 +429,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { val exprEvals = expressions.map(e => e.gen(ctx)) val exprTypes = expressions.map(_.dataType) - createCodeForStruct(ctx, exprEvals, exprTypes) + createCodeForStruct(ctx, "i", exprEvals, exprTypes) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 30b51dd83fa9a..8aaa5b4300044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -155,7 +155,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32)); """.stripMargin } - }.mkString + }.mkString("\n") // ------------------------ Finally, put everything together --------------------------- // val code = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala new file mode 100644 index 0000000000000..8c7ee8720f7bb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * A test suite for generated projections + */ +class GeneratedProjectionSuite extends SparkFunSuite { + + test("generated projections on wider table") { + val N = 1000 + val wideRow1 = new GenericInternalRow((1 to N).toArray[Any]) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + // test generated UnsafeProjection + val unsafeProj = UnsafeProjection.create(nestedSchema) + val unsafe: UnsafeRow = unsafeProj(nested) + (0 until N).foreach { i => + val s = UTF8String.fromString((i + 1).toString) + assert(i + 1 === unsafe.getInt(i + 2)) + assert(s === unsafe.getUTF8String(i + 2 + N)) + assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i)) + assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i)) + assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated SafeProjection + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(unsafe) + // Can't compare GenericInternalRow with JoinedRow directly + (0 until N).foreach { i => + val r = i + 1 + val s = UTF8String.fromString((i + 1).toString) + assert(r === result.getInt(i + 2)) + assert(s === result.getUTF8String(i + 2 + N)) + assert(r === result.getStruct(0, N * 2).getInt(i)) + assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) + assert(r === result.getStruct(1, N * 2).getInt(i)) + assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs)() + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } +} From c4fd2a242228ee101904770446e3f37d49e39b76 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 10 Aug 2015 13:55:11 -0700 Subject: [PATCH 253/340] [SPARK-9759] [SQL] improve decimal.times() and cast(int, decimalType) This patch optimize two things: 1. passing MathContext to JavaBigDecimal.multiply/divide/reminder to do right rounding, because java.math.BigDecimal.apply(MathContext) is expensive 2. 
Cast integer/short/byte to decimal directly (without double) This two optimizations could speed up the end-to-end time of a aggregation (SUM(short * decimal(5, 2)) 75% (from 19s -> 10.8s) Author: Davies Liu Closes #8052 from davies/optimize_decimal and squashes the following commits: 225efad [Davies Liu] improve decimal.times() and cast(int, decimalType) --- .../spark/sql/catalyst/expressions/Cast.scala | 42 +++++++------------ .../org/apache/spark/sql/types/Decimal.scala | 12 +++--- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 946c5a9c04f14..616b9e0e65b78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -155,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType) case ByteType => buildCast[Byte](_, _ != 0) case DecimalType() => - buildCast[Decimal](_, _ != Decimal.ZERO) + buildCast[Decimal](_, !_.isZero) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -315,13 +315,13 @@ case class Cast(child: Expression, dataType: DataType) case TimestampType => // Note that we lose precision here. buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) - case DecimalType() => + case dt: DecimalType => b => changePrecision(b.asInstanceOf[Decimal].clone(), target) - case LongType => - b => changePrecision(Decimal(b.asInstanceOf[Long]), target) - case x: NumericType => // All other numeric types can be represented precisely as Doubles + case t: IntegralType => + b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target) + case x: FractionalType => b => try { - changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) + changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target) } catch { case _: NumberFormatException => null } @@ -534,10 +534,7 @@ case class Cast(child: Expression, dataType: DataType) (c, evPrim, evNull) => s""" try { - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - new scala.math.BigDecimal( - new java.math.BigDecimal($c.toString()))); + Decimal tmpDecimal = Decimal.apply(new java.math.BigDecimal($c.toString())); ${changePrecision("tmpDecimal", target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; @@ -546,12 +543,7 @@ case class Cast(child: Expression, dataType: DataType) case BooleanType => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = null; - if ($c) { - tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1); - } else { - tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0); - } + Decimal tmpDecimal = $c ? Decimal.apply(1) : Decimal.apply(0); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ case DateType => @@ -561,32 +553,28 @@ case class Cast(child: Expression, dataType: DataType) // Note that we lose precision here. 
(c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); + Decimal tmpDecimal = Decimal.apply( + scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ case DecimalType() => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone(); + Decimal tmpDecimal = $c.clone(); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ - case LongType => + case x: IntegralType => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set($c); + Decimal tmpDecimal = Decimal.apply((long) $c); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ - case x: NumericType => + case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => s""" try { - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - scala.math.BigDecimal.valueOf((double) $c)); + Decimal tmpDecimal = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); ${changePrecision("tmpDecimal", target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 624c3f3d7fa97..d95805c24521c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -139,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { - decimalVal(MATH_CONTEXT) + decimalVal } else { - BigDecimal(longVal, _scale)(MATH_CONTEXT) + BigDecimal(longVal, _scale) } } @@ -280,13 +280,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { } // HiveTypeCoercion will take care of the precision, scale of result - def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + def * (that: Decimal): Decimal = + Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT)) def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, MATH_CONTEXT)) def % (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + if (that.isZero) null + else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT)) def remainder(that: Decimal): Decimal = this % that From 853809e948e7c5092643587a30738115b6591a59 Mon Sep 17 00:00:00 2001 From: Prabeesh K Date: Mon, 10 Aug 2015 16:33:23 -0700 Subject: [PATCH 254/340] [SPARK-5155] [PYSPARK] [STREAMING] Mqtt streaming support in Python This PR is based on #4229, thanks prabeesh. 
Closes #4229 Author: Prabeesh K Author: zsxwing Author: prabs Author: Prabeesh K Closes #7833 from zsxwing/pr4229 and squashes the following commits: 9570bec [zsxwing] Fix the variable name and check null in finally 4a9c79e [zsxwing] Fix pom.xml indentation abf5f18 [zsxwing] Merge branch 'master' into pr4229 935615c [zsxwing] Fix the flaky MQTT tests 47278c5 [zsxwing] Include the project class files 478f844 [zsxwing] Add unpack 5f8a1d4 [zsxwing] Make the maven build generate the test jar for Python MQTT tests 734db99 [zsxwing] Merge branch 'master' into pr4229 126608a [Prabeesh K] address the comments b90b709 [Prabeesh K] Merge pull request #1 from zsxwing/pr4229 d07f454 [zsxwing] Register StreamingListerner before starting StreamingContext; Revert unncessary changes; fix the python unit test a6747cb [Prabeesh K] wait for starting the receiver before publishing data 87fc677 [Prabeesh K] address the comments: 97244ec [zsxwing] Make sbt build the assembly test jar for streaming mqtt 80474d1 [Prabeesh K] fix 1f0cfe9 [Prabeesh K] python style fix e1ee016 [Prabeesh K] scala style fix a5a8f9f [Prabeesh K] added Python test 9767d82 [Prabeesh K] implemented Python-friendly class a11968b [Prabeesh K] fixed python style 795ec27 [Prabeesh K] address comments ee387ae [Prabeesh K] Fix assembly jar location of mqtt-assembly 3f4df12 [Prabeesh K] updated version b34c3c1 [prabs] adress comments 3aa7fff [prabs] Added Python streaming mqtt word count example b7d42ff [prabs] Mqtt streaming support in Python --- dev/run-tests.py | 2 + dev/sparktestsupport/modules.py | 2 + docs/streaming-programming-guide.md | 2 +- .../main/python/streaming/mqtt_wordcount.py | 58 +++++++++ external/mqtt-assembly/pom.xml | 102 +++++++++++++++ external/mqtt/pom.xml | 28 +++++ external/mqtt/src/main/assembly/assembly.xml | 44 +++++++ .../spark/streaming/mqtt/MQTTUtils.scala | 16 +++ .../streaming/mqtt/MQTTStreamSuite.scala | 118 +++--------------- .../spark/streaming/mqtt/MQTTTestUtils.scala | 111 ++++++++++++++++ pom.xml | 1 + project/SparkBuild.scala | 12 +- python/pyspark/streaming/mqtt.py | 72 +++++++++++ python/pyspark/streaming/tests.py | 106 +++++++++++++++- 14 files changed, 565 insertions(+), 109 deletions(-) create mode 100644 examples/src/main/python/streaming/mqtt_wordcount.py create mode 100644 external/mqtt-assembly/pom.xml create mode 100644 external/mqtt/src/main/assembly/assembly.xml create mode 100644 external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala create mode 100644 python/pyspark/streaming/mqtt.py diff --git a/dev/run-tests.py b/dev/run-tests.py index d1852b95bb292..f689425ee40b6 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -303,6 +303,8 @@ def build_spark_sbt(hadoop_version): "assembly/assembly", "streaming-kafka-assembly/assembly", "streaming-flume-assembly/assembly", + "streaming-mqtt-assembly/assembly", + "streaming-mqtt/test:assembly", "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index a9717ff9569c7..d82c0cca37bc6 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -181,6 +181,7 @@ def contains_file(self, filename): dependencies=[streaming], source_file_regexes=[ "external/mqtt", + "external/mqtt-assembly", ], sbt_test_goals=[ "streaming-mqtt/test", @@ -306,6 +307,7 @@ def contains_file(self, filename): streaming, streaming_kafka, streaming_flume_assembly, + streaming_mqtt, streaming_kinesis_asl ], 
source_file_regexes=[ diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index dbfdb619f89e2..c59d936b43c88 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py new file mode 100644 index 0000000000000..617ce5ea6775e --- /dev/null +++ b/examples/src/main/python/streaming/mqtt_wordcount.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" + A sample wordcount with MqttStream stream + Usage: mqtt_wordcount.py + + To run this in your local machine, you need to setup a MQTT broker and publisher first, + Mosquitto is one of the open source MQTT Brokers, see + http://mosquitto.org/ + Eclipse paho project provides number of clients and utilities for working with MQTT, see + http://www.eclipse.org/paho/#getting-started + + and then run the example + `$ bin/spark-submit --jars external/mqtt-assembly/target/scala-*/\ + spark-streaming-mqtt-assembly-*.jar examples/src/main/python/streaming/mqtt_wordcount.py \ + tcp://localhost:1883 foo` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.mqtt import MQTTUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: mqtt_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingMQTTWordCount") + ssc = StreamingContext(sc, 1) + + brokerUrl = sys.argv[1] + topic = sys.argv[2] + + lines = MQTTUtils.createStream(ssc, brokerUrl, topic) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml new file mode 100644 index 0000000000000..9c94473053d96 --- /dev/null +++ b/external/mqtt-assembly/pom.xml @@ -0,0 +1,102 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-mqtt-assembly_2.10 + jar + Spark Project External MQTT Assembly + http://spark.apache.org/ + + + streaming-mqtt-assembly + + + + + org.apache.spark + spark-streaming-mqtt_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-mqtt-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 0e41e5781784b..69b309876a0db 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -78,5 +78,33 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + + + org.apache.maven.plugins + maven-assembly-plugin + + + test-jar-with-dependencies + package + + single + + + + spark-streaming-mqtt-test-${project.version} + ${project.build.directory}/scala-${scala.binary.version}/ + false + + false + + src/main/assembly/assembly.xml + + + + + + diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml new file mode 100644 index 0000000000000..ecab5b360eb3e --- /dev/null +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -0,0 +1,44 @@ + + + test-jar-with-dependencies + + jar + + false + + + + ${project.build.directory}/scala-${scala.binary.version}/test-classes + / + + + + + + true + test + true + + org.apache.hadoop:*:jar + org.apache.zookeeper:*:jar + org.apache.avro:*:jar + + + + + diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala 
b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 1142d0f56ba34..38a1114863d15 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -74,3 +74,19 @@ object MQTTUtils { createStream(jssc.ssc, brokerUrl, topic, storageLevel) } } + +/** + * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's MQTTUtils. + */ +private class MQTTUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ): JavaDStream[String] = { + MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel) + } +} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index c4bf5aa7869bb..a6a9249db8ed7 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,46 +17,30 @@ package org.apache.spark.streaming.mqtt -import java.net.{URI, ServerSocket} -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit - import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.activemq.broker.{TransportConnector, BrokerService} -import org.apache.commons.lang3.RandomUtils -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence - import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually -import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" private val framework = this.getClass.getSimpleName - private val freePort = findFreePort() - private val brokerUri = "//localhost:" + freePort private val topic = "def" - private val persistenceDir = Utils.createTempDir() private var ssc: StreamingContext = _ - private var broker: BrokerService = _ - private var connector: TransportConnector = _ + private var mqttTestUtils: MQTTTestUtils = _ before { ssc = new StreamingContext(master, framework, batchDuration) - setupMQTT() + mqttTestUtils = new MQTTTestUtils + mqttTestUtils.setup() } after { @@ -64,14 +48,17 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter ssc.stop() ssc = null } - Utils.deleteRecursively(persistenceDir) - tearDownMQTT() + if (mqttTestUtils != null) { + mqttTestUtils.teardown() + mqttTestUtils = null + } } test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream = - MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) + val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + 
mqttTestUtils.brokerUri, topic, + StorageLevel.MEMORY_ONLY) + @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { @@ -79,89 +66,14 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter receiveMessage } } - ssc.start() - // wait for the receiver to start before publishing data, or we risk failing - // the test nondeterministically. See SPARK-4631 - waitForReceiverToStart() + ssc.start() - publishData(sendMessage) + // Retry it because we don't know when the receiver will start. eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + mqttTestUtils.publishData(topic, sendMessage) assert(sendMessage.equals(receiveMessage(0))) } ssc.stop() } - - private def setupMQTT() { - broker = new BrokerService() - broker.setDataDirectoryFile(Utils.createTempDir()) - connector = new TransportConnector() - connector.setName("mqtt") - connector.setUri(new URI("mqtt:" + brokerUri)) - broker.addConnector(connector) - broker.start() - } - - private def tearDownMQTT() { - if (broker != null) { - broker.stop() - broker = null - } - if (connector != null) { - connector.stop() - connector = null - } - } - - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - def publishData(data: String): Unit = { - var client: MqttClient = null - try { - val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) - client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) - client.connect() - if (client.isConnected) { - val msgTopic = client.getTopic(topic) - val message = new MqttMessage(data.getBytes("utf-8")) - message.setQos(1) - message.setRetained(true) - - for (i <- 0 to 10) { - try { - msgTopic.publish(message) - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - // wait for Spark streaming to consume something from the message queue - Thread.sleep(50) - } - } - } - } finally { - client.disconnect() - client.close() - client = null - } - } - - /** - * Block until at least one receiver has started or timeout occurs. - */ - private def waitForReceiverToStart() = { - val latch = new CountDownLatch(1) - ssc.addStreamingListener(new StreamingListener { - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - latch.countDown() - } - }) - - assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") - } } diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala new file mode 100644 index 0000000000000..1a371b7008824 --- /dev/null +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.mqtt + +import java.net.{ServerSocket, URI} + +import scala.language.postfixOps + +import com.google.common.base.Charsets.UTF_8 +import org.apache.activemq.broker.{BrokerService, TransportConnector} +import org.apache.commons.lang3.RandomUtils +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} + +/** + * Share codes for Scala and Python unit tests + */ +private class MQTTTestUtils extends Logging { + + private val persistenceDir = Utils.createTempDir() + private val brokerHost = "localhost" + private val brokerPort = findFreePort() + + private var broker: BrokerService = _ + private var connector: TransportConnector = _ + + def brokerUri: String = { + s"$brokerHost:$brokerPort" + } + + def setup(): Unit = { + broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) + connector = new TransportConnector() + connector.setName("mqtt") + connector.setUri(new URI("mqtt://" + brokerUri)) + broker.addConnector(connector) + broker.start() + } + + def teardown(): Unit = { + if (broker != null) { + broker.stop() + broker = null + } + if (connector != null) { + connector.stop() + connector = null + } + Utils.deleteRecursively(persistenceDir) + } + + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + def publishData(topic: String, data: String): Unit = { + var client: MqttClient = null + try { + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence) + client.connect() + if (client.isConnected) { + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes(UTF_8)) + message.setQos(1) + message.setRetained(true) + + for (i <- 0 to 10) { + try { + msgTopic.publish(message) + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) + } + } + } + } finally { + if (client != null) { + client.disconnect() + client.close() + client = null + } + } + } + +} diff --git a/pom.xml b/pom.xml index 2bcc55b040a26..8942836a7da16 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/flume-sink external/flume-assembly external/mqtt + external/mqtt-assembly external/zeromq examples repl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9a33baa7c6ce1..41a85fa9de778 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, 
streamingKafkaAssembly, streamingKinesisAslAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -212,6 +212,9 @@ object SparkBuild extends PomBuild { /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) + /* Enable Assembly for streamingMqtt test */ + enable(inConfig(Test)(Assembly.settings))(streamingMqtt) + /* Package pyspark artifacts in a separate zip file for YARN. */ enable(PySparkAssembly.settings)(assembly) @@ -382,13 +385,16 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-mqtt-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { s"${mName}-${v}-hadoop${hv}.jar" } }, + jarName in (Test, assembly) <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => + s"${mName}-test-${v}.jar" + }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py new file mode 100644 index 0000000000000..f06598971c548 --- /dev/null +++ b/python/pyspark/streaming/mqtt.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import UTF8Deserializer +from pyspark.streaming import DStream + +__all__ = ['MQTTUtils'] + + +class MQTTUtils(object): + + @staticmethod + def createStream(ssc, brokerUrl, topic, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input stream that pulls messages from a Mqtt Broker. 
+ :param ssc: StreamingContext object + :param brokerUrl: Url of remote mqtt publisher + :param topic: topic name to subscribe to + :param storageLevel: RDD storage level. + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + MQTTUtils._printErrorMsg(ssc.sparkContext) + raise e + + return DStream(jstream, ssc, UTF8Deserializer()) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's MQTT libraries not found in class path. Try one of the following. + + 1. Include the MQTT library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-mqtt:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-mqtt-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... +________________________________________________________________________________________________ +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5cd544b2144ef..66ae3345f468f 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -40,6 +40,7 @@ from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition from pyspark.streaming.flume import FlumeUtils +from pyspark.streaming.mqtt import MQTTUtils from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream @@ -893,6 +894,68 @@ def test_flume_polling_multiple_hosts(self): self._testMultipleTimes(self._testFlumePollingMultipleHosts) +class MQTTStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(MQTTStreamTests, self).setUp() + + MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils") + self._MQTTTestUtils = MQTTTestUtilsClz.newInstance() + self._MQTTTestUtils.setup() + + def tearDown(self): + if self._MQTTTestUtils is not None: + self._MQTTTestUtils.teardown() + self._MQTTTestUtils = None + + super(MQTTStreamTests, self).tearDown() + + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _startContext(self, topic): + # Start the StreamingContext and also collect the result + stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic) + result = [] + + def getOutput(_, rdd): + for data in rdd.collect(): + result.append(data) + + stream.foreachRDD(getOutput) + self.ssc.start() + return result + + def test_mqtt_stream(self): + """Test the Python MQTT stream API.""" + sendData = "MQTT demo for spark streaming" + topic = self._randomTopic() + result = self._startContext(topic) + + def retry(): + self._MQTTTestUtils.publishData(topic, sendData) + # Because "publishData" sends duplicate messages, here we should use > 0 + self.assertTrue(len(result) > 0) + 
self.assertEqual(sendData, result[0]) + + # Retry it because we don't know when the receiver will start. + self._retry_or_timeout(retry) + + def _retry_or_timeout(self, test_func): + start_time = time.time() + while True: + try: + test_func() + break + except: + if time.time() - start_time > self.timeout: + raise + time.sleep(0.01) + + class KinesisStreamTests(PySparkStreamingTestCase): def test_kinesis_stream_api(self): @@ -985,7 +1048,42 @@ def search_flume_assembly_jar(): "'build/mvn package' before running this test") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " - "remove all but one") % flume_assembly_dir) + "remove all but one") % flume_assembly_dir) + else: + return jars[0] + + +def search_mqtt_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly") + jars = glob.glob( + os.path.join(mqtt_assembly_dir, "target/scala-*/spark-streaming-mqtt-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT assembly JARs in %s; please " + "remove all but one") % mqtt_assembly_dir) + else: + return jars[0] + + +def search_mqtt_test_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt") + jars = glob.glob( + os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt/test:assembly'") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT test JARs in %s; please " + "remove all but one") % mqtt_test_dir) else: return jars[0] @@ -1012,8 +1110,12 @@ def search_kinesis_asl_assembly_jar(): if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() + mqtt_assembly_jar = search_mqtt_assembly_jar() + mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) + + jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar, + mqtt_assembly_jar, mqtt_test_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars unittest.main() From 3c9802d9400bea802984456683b2736a450ee17e Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Mon, 10 Aug 2015 17:17:22 -0700 Subject: [PATCH 255/340] [SPARK-9801] [STREAMING] Check if file exists before deleting temporary files. Spark streaming deletes the temp file and backup files without checking if they exist or not Author: Hao Zhu Closes #8082 from viadea/master and squashes the following commits: 242d05f [Hao Zhu] [SPARK-9801][Streaming]No need to check the existence of those files fd143f2 [Hao Zhu] [SPARK-9801][Streaming]Check if backupFile exists before deleting backupFile files. 
087daf0 [Hao Zhu] SPARK-9801 --- .../scala/org/apache/spark/streaming/Checkpoint.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 2780d5b6adbcf..6f6b449accc3c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -192,7 +192,9 @@ class CheckpointWriter( + "'") // Write checkpoint to temp file - fs.delete(tempFile, true) // just in case it exists + if (fs.exists(tempFile)) { + fs.delete(tempFile, true) // just in case it exists + } val fos = fs.create(tempFile) Utils.tryWithSafeFinally { fos.write(bytes) @@ -203,7 +205,9 @@ class CheckpointWriter( // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail if (fs.exists(checkpointFile)) { - fs.delete(backupFile, true) // just in case it exists + if (fs.exists(backupFile)) { + fs.delete(backupFile, true) // just in case it exists + } if (!fs.rename(checkpointFile, backupFile)) { logWarning("Could not rename " + checkpointFile + " to " + backupFile) } From 071bbad5db1096a548c886762b611a8484a52753 Mon Sep 17 00:00:00 2001 From: Damian Guy Date: Tue, 11 Aug 2015 12:46:33 +0800 Subject: [PATCH 256/340] [SPARK-9340] [SQL] Fixes converting unannotated Parquet lists This PR is inspired by #8063 authored by dguy. In particular, the test Parquet files added here are all taken from that PR. **Committer who merges this PR should attribute it to "Damian Guy ".** ---- SPARK-6776 and SPARK-6777 followed `parquet-avro` to implement the backwards-compatibility rules defined in the `parquet-format` spec. However, both Spark SQL and `parquet-avro` neglected the following statement in `parquet-format`: > This does not affect repeated fields that are not annotated: A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor annotated by `LIST` or `MAP` should be interpreted as a required list of required elements where the element type is the type of the field. One of the consequences is that Parquet files generated by `parquet-protobuf` containing unannotated repeated fields are not correctly converted to Catalyst arrays. This PR fixes the issue by 1. Handling unannotated repeated fields in `CatalystSchemaConverter`. 2. Converting these special repeated fields to Catalyst arrays in `CatalystRowConverter`. Two special converters, `RepeatedPrimitiveConverter` and `RepeatedGroupConverter`, are added. They delegate the actual conversion work to a child `elementConverter` and accumulate elements in an `ArrayBuffer`. Two extra methods, `start()` and `end()`, are added to `ParentContainerUpdater` so that they can be used to initialize new `ArrayBuffer`s for unannotated repeated fields and to propagate converted array values upstream.
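For illustration only (the message and field names below are hypothetical, not taken from the test files in this PR): a `parquet-protobuf` writer may emit an unannotated repeated field such as

    message spark_schema {
      repeated int32 num;
    }

Under the rule quoted above, `CatalystSchemaConverter` should now map `num` to a required array of required elements instead of failing with `AnalysisException("REPEATED not supported outside LIST or MAP")`:

    // Expected Catalyst schema for the sketch above (field name assumed):
    StructType(Seq(
      StructField("num", ArrayType(IntegerType, containsNull = false), nullable = false)))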
Author: Cheng Lian Closes #8070 from liancheng/spark-9340/unannotated-parquet-list and squashes the following commits: ace6df7 [Cheng Lian] Moves ParquetProtobufCompatibilitySuite f1c7bfd [Cheng Lian] Updates .rat-excludes 420ad2b [Cheng Lian] Fixes converting unannotated Parquet lists --- .rat-excludes | 1 + .../parquet/CatalystRowConverter.scala | 151 ++++++++++++++---- .../parquet/CatalystSchemaConverter.scala | 7 +- .../resources/nested-array-struct.parquet | Bin 0 -> 775 bytes .../test/resources/old-repeated-int.parquet | Bin 0 -> 389 bytes .../resources/old-repeated-message.parquet | Bin 0 -> 600 bytes .../src/test/resources/old-repeated.parquet | Bin 0 -> 432 bytes .../parquet-thrift-compat.snappy.parquet | Bin .../resources/proto-repeated-string.parquet | Bin 0 -> 411 bytes .../resources/proto-repeated-struct.parquet | Bin 0 -> 608 bytes .../proto-struct-with-array-many.parquet | Bin 0 -> 802 bytes .../resources/proto-struct-with-array.parquet | Bin 0 -> 1576 bytes .../ParquetProtobufCompatibilitySuite.scala | 91 +++++++++++ .../parquet/ParquetSchemaSuite.scala | 30 ++++ 14 files changed, 247 insertions(+), 33 deletions(-) create mode 100644 sql/core/src/test/resources/nested-array-struct.parquet create mode 100644 sql/core/src/test/resources/old-repeated-int.parquet create mode 100644 sql/core/src/test/resources/old-repeated-message.parquet create mode 100644 sql/core/src/test/resources/old-repeated.parquet mode change 100755 => 100644 sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet create mode 100644 sql/core/src/test/resources/proto-repeated-string.parquet create mode 100644 sql/core/src/test/resources/proto-repeated-struct.parquet create mode 100644 sql/core/src/test/resources/proto-struct-with-array-many.parquet create mode 100644 sql/core/src/test/resources/proto-struct-with-array.parquet create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala diff --git a/.rat-excludes b/.rat-excludes index 72771465846b8..9165872b9fb27 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -94,3 +94,4 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister +.*parquet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 3542dfbae1292..ab5a6ddd41cfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -21,11 +21,11 @@ import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder import scala.collection.JavaConversions._ -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} +import org.apache.parquet.schema.OriginalType.LIST import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} @@ -42,6 +42,12 @@ import org.apache.spark.unsafe.types.UTF8String * values to an [[ArrayBuffer]]. 
*/ private[parquet] trait ParentContainerUpdater { + /** Called before a record field is being converted */ + def start(): Unit = () + + /** Called after a record field is being converted */ + def end(): Unit = () + def set(value: Any): Unit = () def setBoolean(value: Boolean): Unit = set(value) def setByte(value: Byte): Unit = set(value) @@ -55,6 +61,32 @@ private[parquet] trait ParentContainerUpdater { /** A no-op updater used for root converter (who doesn't have a parent). */ private[parquet] object NoopUpdater extends ParentContainerUpdater +private[parquet] trait HasParentContainerUpdater { + def updater: ParentContainerUpdater +} + +/** + * A convenient converter class for Parquet group types with an [[HasParentContainerUpdater]]. + */ +private[parquet] abstract class CatalystGroupConverter(val updater: ParentContainerUpdater) + extends GroupConverter with HasParentContainerUpdater + +/** + * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types + * are handled by this converter. Parquet primitive types are only a subset of those of Spark + * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. + */ +private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUpdater) + extends PrimitiveConverter with HasParentContainerUpdater { + + override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) + override def addInt(value: Int): Unit = updater.setInt(value) + override def addLong(value: Long): Unit = updater.setLong(value) + override def addFloat(value: Float): Unit = updater.setFloat(value) + override def addDouble(value: Double): Unit = updater.setDouble(value) + override def addBinary(value: Binary): Unit = updater.set(value.getBytes) +} + /** * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s. * Since any Parquet record is also a struct, this converter can also be used as root converter. @@ -70,7 +102,7 @@ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: StructType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { /** * Updater used together with field converters within a [[CatalystRowConverter]]. It propagates @@ -89,13 +121,11 @@ private[parquet] class CatalystRowConverter( /** * Represents the converted row object once an entire Parquet record is converted. - * - * @todo Uses [[UnsafeRow]] for better performance. */ val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) // Converters for each field. 
- private val fieldConverters: Array[Converter] = { + private val fieldConverters: Array[Converter with HasParentContainerUpdater] = { parquetType.getFields.zip(catalystType).zipWithIndex.map { case ((parquetFieldType, catalystField), ordinal) => // Converted field value should be set to the `ordinal`-th cell of `currentRow` @@ -105,11 +135,19 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) - override def end(): Unit = updater.set(currentRow) + override def end(): Unit = { + var i = 0 + while (i < currentRow.numFields) { + fieldConverters(i).updater.end() + i += 1 + } + updater.set(currentRow) + } override def start(): Unit = { var i = 0 while (i < currentRow.numFields) { + fieldConverters(i).updater.start() currentRow.setNullAt(i) i += 1 } @@ -122,20 +160,20 @@ private[parquet] class CatalystRowConverter( private def newConverter( parquetType: Type, catalystType: DataType, - updater: ParentContainerUpdater): Converter = { + updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = { catalystType match { case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => new CatalystPrimitiveConverter(updater) case ByteType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setByte(value.asInstanceOf[ByteType#InternalType]) } case ShortType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setShort(value.asInstanceOf[ShortType#InternalType]) } @@ -148,7 +186,7 @@ private[parquet] class CatalystRowConverter( case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { assert( @@ -164,13 +202,23 @@ private[parquet] class CatalystRowConverter( } case DateType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { // DateType is not specialized in `SpecificMutableRow`, have to box it here. updater.set(value.asInstanceOf[DateType#InternalType]) } } + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + case t: ArrayType if parquetType.getOriginalType != LIST => + if (parquetType.isPrimitive) { + new RepeatedPrimitiveConverter(parquetType, t.elementType, updater) + } else { + new RepeatedGroupConverter(parquetType, t.elementType, updater) + } + case t: ArrayType => new CatalystArrayConverter(parquetType.asGroupType(), t, updater) @@ -195,27 +243,11 @@ private[parquet] class CatalystRowConverter( } } - /** - * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types - * are handled by this converter. Parquet primitive types are only a subset of those of Spark - * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. 
- */ - private final class CatalystPrimitiveConverter(updater: ParentContainerUpdater) - extends PrimitiveConverter { - - override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) - override def addInt(value: Int): Unit = updater.setInt(value) - override def addLong(value: Long): Unit = updater.setLong(value) - override def addFloat(value: Float): Unit = updater.setFloat(value) - override def addDouble(value: Double): Unit = updater.setDouble(value) - override def addBinary(value: Binary): Unit = updater.set(value.getBytes) - } - /** * Parquet converter for strings. A dictionary is used to minimize string decoding cost. */ private final class CatalystStringConverter(updater: ParentContainerUpdater) - extends PrimitiveConverter { + extends CatalystPrimitiveConverter(updater) { private var expandedDictionary: Array[UTF8String] = null @@ -242,7 +274,7 @@ private[parquet] class CatalystRowConverter( private final class CatalystDecimalConverter( decimalType: DecimalType, updater: ParentContainerUpdater) - extends PrimitiveConverter { + extends CatalystPrimitiveConverter(updater) { // Converts decimals stored as INT32 override def addInt(value: Int): Unit = { @@ -306,7 +338,7 @@ private[parquet] class CatalystRowConverter( parquetSchema: GroupType, catalystSchema: ArrayType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { private var currentArray: ArrayBuffer[Any] = _ @@ -383,7 +415,7 @@ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: MapType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { private var currentKeys: ArrayBuffer[Any] = _ private var currentValues: ArrayBuffer[Any] = _ @@ -446,4 +478,61 @@ private[parquet] class CatalystRowConverter( } } } + + private trait RepeatedConverter { + private var currentArray: ArrayBuffer[Any] = _ + + protected def newArrayUpdater(updater: ParentContainerUpdater) = new ParentContainerUpdater { + override def start(): Unit = currentArray = ArrayBuffer.empty[Any] + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + override def set(value: Any): Unit = currentArray += value + } + } + + /** + * A primitive converter for converting unannotated repeated primitive values to required arrays + * of required primitives values. 
+ */ + private final class RepeatedPrimitiveConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends PrimitiveConverter with RepeatedConverter with HasParentContainerUpdater { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private val elementConverter: PrimitiveConverter = + newConverter(parquetType, catalystType, updater).asPrimitiveConverter() + + override def addBoolean(value: Boolean): Unit = elementConverter.addBoolean(value) + override def addInt(value: Int): Unit = elementConverter.addInt(value) + override def addLong(value: Long): Unit = elementConverter.addLong(value) + override def addFloat(value: Float): Unit = elementConverter.addFloat(value) + override def addDouble(value: Double): Unit = elementConverter.addDouble(value) + override def addBinary(value: Binary): Unit = elementConverter.addBinary(value) + + override def setDictionary(dict: Dictionary): Unit = elementConverter.setDictionary(dict) + override def hasDictionarySupport: Boolean = elementConverter.hasDictionarySupport + override def addValueFromDictionary(id: Int): Unit = elementConverter.addValueFromDictionary(id) + } + + /** + * A group converter for converting unannotated repeated group values to required arrays of + * required struct values. + */ + private final class RepeatedGroupConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends GroupConverter with HasParentContainerUpdater with RepeatedConverter { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private val elementConverter: GroupConverter = + newConverter(parquetType, catalystType, updater).asGroupConverter() + + override def getConverter(field: Int): Converter = elementConverter.getConverter(field) + override def end(): Unit = elementConverter.end() + override def start(): Unit = elementConverter.start() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index a3fc74cf7929b..275646e8181ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -100,8 +100,11 @@ private[parquet] class CatalystSchemaConverter( StructField(field.getName, convertField(field), nullable = false) case REPEATED => - throw new AnalysisException( - s"REPEATED not supported outside LIST or MAP. Type: $field") + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. 
+ val arrayType = ArrayType(convertField(field), containsNull = false) + StructField(field.getName, arrayType, nullable = false) } } diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/nested-array-struct.parquet new file mode 100644 index 0000000000000000000000000000000000000000..41a43fa35d39685e56ba4849a16cba4bb1aa86ae GIT binary patch literal 775 zcmaKr-%G+!6vvO#=8u;i;*JSE3{j~lAyWtuV%0Q3*U(Y)B(q&>u(@?NBZ;2+Px=dc z?6I>s%N6u+cQ4=bIp1@3cBjdsBLbvCDhGte15a`Q8~~)V;d2WY3V@LYX{-@GMj#!6 z`;fvdgDblto20oxMhs#hdKzt*4*3w}iuUE6PW?b*Zs1NAv%1WfvAnT@2NhLn_L#fy zHFWX={q7*H z93@^0*O&w#e53@lJ`hFEV2=wL)V**Pb(8vc%<=-4iJz&t;n22J{%<(t!px$!DZLaV zDaOCYR1UR;Go`F89pTwFrqpgr1NlrDOs+J&f2GO;)PtpmW%OH3neOEccXHoWyO`6Bl5({+ea14&qL7DtETw`(h_426$BxCYApN S1!5siKXe$p;U(Ab7x)6M)yB5~ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/old-repeated-int.parquet new file mode 100644 index 0000000000000000000000000000000000000000..520922f73ebb75950c4e65dec6689af817cb7a33 GIT binary patch literal 389 zcmZWmO-sW-5FMkeB_3tN1_Ca@_7p=~Z@EPbSg5juTs)Ocvz0>9#NEw7ivQhdDcI1< z%;U}1booNUUY~PehCwzvumZho_zD!@TVtKoQHivd-PzgS{O4oWwfk)ZsDnB!ByvMUB7gt@Rk50{GMw}6DWn-w!#F0CGZwP` zh81XVct9peJU!6=O6yx`ZywS;EGVOwOOIsCr3p)d)y(vgv`4bce)5D7UiKlH0l6Ga0+m-EijTt zM!X)Rd#pSj~EkCao0il-K=62)db~64ZC7r8d1AwIhzFkoG;e&AbpZW#pJ&{5H literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/old-repeated.parquet b/sql/core/src/test/resources/old-repeated.parquet new file mode 100644 index 0000000000000000000000000000000000000000..213f1a90291b30a8a3161b51c38f008f3ae9f6e5 GIT binary patch literal 432 zcmZWm!D@p*5ZxNF!5+)X3PMGioUA16Ei?y9g$B|h;-#ms>Ld--Xm{5`DgF13A<&42 znSH!BJNtMWhsm50I-@h68VC$(I7}ZALYRJm-NGUo*2p;a%Z@xEJgH{;FE=Sj6^mNc zS-TAqXn-pyRtNP8Qt};84d*60yAuBru{7JUo$1)Y6%&KlJ(Uv6u!JS1zOP)cw zaM$5ewB9699EEB0jJ*18aDDn7N1N4K`fzXlnuJ~V?c^nwk}Yeo3wXox4+#3Y!pMU2 V+-`?%2{TWZ?kYh(G4~k%>JK8=aDe~- literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet b/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet old mode 100755 new mode 100644 diff --git a/sql/core/src/test/resources/proto-repeated-string.parquet b/sql/core/src/test/resources/proto-repeated-string.parquet new file mode 100644 index 0000000000000000000000000000000000000000..8a7eea601d0164b1177ab90efb7116c6e4c900da GIT binary patch literal 411 zcmY*W!D_-l5S^$E62wc{umKMtR4+{f-b!vM4Q)Y6&|G?w#H^aKansF;gi?C#*Yq1Z zv6e;{_C4Mk<_)t^FrN}2UmBK6hDddy19SkO`+9soFOY8;=b|A8A$itAvJoQdBBnKK zKy>)#|`4$W^3YtqL)WN5pTmWh1ZGv$>{f|s#sCG%1VN%LJ&FyD4siH@<(8PDu@ z!?sWEU$)ao`yyr1x2MQ?k}~ewv*0eAE$3kr261?gx~fYY8oxy0auLs;o*#@41L)=X h7Au}q6}>(e6`sLs-{PvZ8BpWYeN#xd)c_*=mLHLcZu|fM literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/proto-repeated-struct.parquet new file mode 100644 index 0000000000000000000000000000000000000000..c29eee35c350e84efa2818febb2528309a8ac3ea GIT binary patch literal 608 zcma))-D<)x6vvP811a8(bSZdI%9Js*tjca=2puciK%r=F1_P}cr%>B2jf^q&4ts<> z%r1PaC9^8^%A4e$li&GFTzg<)z+K#J;DQh(TmnDm&bFDCfsEak0$H6=|yp$CW-$_F@l={DK5j1GEpn8)DX!>A+15G`Fpg} zMZREE-l#~cYPa=r6<4%c3AD?tzx2bP7Sypiu9pGoi(^0p+XD*$Y;s4$HpQOVL*dnD3MsZWzL<}ml5ZY`6p``6p3uzOR6cOQizQ3hI@;TGm z8hy+Rm>Hl!&aiC8oEI4i7`u<#dC`r5%q<5r!9eDg1IFy*c2RU=AalzBO)!wT<$y7e zS0C?=T^uJ)6ePiDIW^oM?BO`}o-pLWHOS&h@*H8x zD56?ZuNu`Fl-0Tj)YHv=x(@J>C=j)+#mj%Z_4kgdp}D_<3b zc(o7;z363$6CCr9}+-E 
j<-Z;KUL2!lIhl~NDb+dIHUN;6ire!D{E~S&5gg5S?rsql%Icnq5|qgD{N<#x?oCY2!n{ZAEKv64iDOJsH{F!~)4uB-v0( z@BJM;_owufA5^-l4@Z{ekVAVh>zOxi-n`LDMyq>_0q^0x8b74l?^SjJrp%q&06h8-y4iMdSJ@MDH4c~HjX3j($=&sN1 zW|q&!OYxG3d&~^8@dlzhDa$1b0`rz(6tkBD*J153G=T1;gzF$B0g1WSKnPOy6M$~KQ|W5 z5A#DO!$$Zca>TK`;67WBib&?m76{4rqTmNgwH)UCc)*uPhjcg;fc!=Tfl{N?GyS_6 z3+tZPeSOS=k#BjS>(f75Q`2EhwX*hMsK_@Kv&ZT;SydBky3mDR6_J}cL*_TtV}7>H zA+wumr}b9v46coS`}(TY;qmaR$9wg^82X@n)jvIvzps*~J`|Fl Date: Mon, 10 Aug 2015 22:04:41 -0700 Subject: [PATCH 257/340] [SPARK-9729] [SPARK-9363] [SQL] Use sort merge join for left and right outer join This patch adds a new `SortMergeOuterJoin` operator that performs left and right outer joins using sort merge join. It also refactors `SortMergeJoin` in order to improve performance and code clarity. Along the way, I also performed a couple pieces of minor cleanup and optimization: - Rename the `HashJoin` physical planner rule to `EquiJoinSelection`, since it's also used for non-hash joins. - Rewrite the comment at the top of `HashJoin` to better explain the precedence for choosing join operators. - Update `JoinSuite` to use `SqlTestUtils.withConf` for changing SQLConf settings. This patch incorporates several ideas from adrian-wang's patch, #5717. Closes #5717. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/7904) Author: Josh Rosen Author: Daoyuan Wang Closes #7904 from JoshRosen/outer-join-smj and squashes 1 commits. --- .../sql/catalyst/expressions/JoinedRow.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/RowIterator.scala | 93 +++++ .../spark/sql/execution/SparkStrategies.scala | 45 ++- .../joins/BroadcastNestedLoopJoin.scala | 5 +- .../sql/execution/joins/SortMergeJoin.scala | 331 +++++++++++++----- .../execution/joins/SortMergeOuterJoin.scala | 251 +++++++++++++ .../org/apache/spark/sql/JoinSuite.scala | 132 ++++--- .../sql/execution/joins/InnerJoinSuite.scala | 180 ++++++++++ .../sql/execution/joins/OuterJoinSuite.scala | 310 +++++++++++----- .../sql/execution/joins/SemiJoinSuite.scala | 125 ++++--- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- 13 files changed, 1165 insertions(+), 319 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index b76757c93523d..d3560df0792eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -37,20 +37,20 @@ class JoinedRow extends InternalRow { } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { + def apply(r1: InternalRow, r2: InternalRow): JoinedRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { + def withLeft(newLeft: InternalRow): JoinedRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ - def withRight(newRight: InternalRow): InternalRow = { + def withRight(newRight: InternalRow): JoinedRow = { row2 = newRight this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f73bb0488c984..4bf00b3399e7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -873,7 +873,7 @@ class SQLContext(@transient val sparkContext: SparkContext) HashAggregation :: Aggregation :: LeftSemiJoin :: - HashJoin :: + EquiJoinSelection :: InMemoryScans :: BasicOperators :: CartesianProduct :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala new file mode 100644 index 0000000000000..7462dbc4eba3a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.NoSuchElementException + +import org.apache.spark.sql.catalyst.InternalRow + +/** + * An internal iterator interface which presents a more restrictive API than + * [[scala.collection.Iterator]]. + * + * One major departure from the Scala iterator API is the fusing of the `hasNext()` and `next()` + * calls: Scala's iterator allows users to call `hasNext()` without immediately advancing the + * iterator to consume the next row, whereas RowIterator combines these calls into a single + * [[advanceNext()]] method. + */ +private[sql] abstract class RowIterator { + /** + * Advance this iterator by a single row. Returns `false` if this iterator has no more rows + * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling + * [[getRow]]. + */ + def advanceNext(): Boolean + + /** + * Retrieve the row from this iterator. This method is idempotent. It is illegal to call this + * method after [[advanceNext()]] has returned `false`. + */ + def getRow: InternalRow + + /** + * Convert this RowIterator into a [[scala.collection.Iterator]]. 
+ */ + def toScala: Iterator[InternalRow] = new RowIteratorToScala(this) +} + +object RowIterator { + def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = { + scalaIter match { + case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter + case _ => new RowIteratorFromScala(scalaIter) + } + } +} + +private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] { + private [this] var hasNextWasCalled: Boolean = false + private [this] var _hasNext: Boolean = false + override def hasNext: Boolean = { + // Idempotency: + if (!hasNextWasCalled) { + _hasNext = rowIter.advanceNext() + hasNextWasCalled = true + } + _hasNext + } + override def next(): InternalRow = { + if (!hasNext) throw new NoSuchElementException + hasNextWasCalled = false + rowIter.getRow + } +} + +private final class RowIteratorFromScala(scalaIter: Iterator[InternalRow]) extends RowIterator { + private[this] var _next: InternalRow = null + override def advanceNext(): Boolean = { + if (scalaIter.hasNext) { + _next = scalaIter.next() + true + } else { + _next = null + false + } + } + override def getRow: InternalRow = _next + override def toScala: Iterator[InternalRow] = scalaIter +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4b9b5acea4de..1fc870d44b578 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -63,19 +63,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be - * evaluated by matching hash keys. + * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates + * can be evaluated by matching join keys. * - * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an - * estimated physical size smaller than the user-settable threshold - * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the - * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be - * ''broadcasted'' to all of the executors involved in the join, as a - * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. + * Join implementations are chosen with the following precedence: + * + * - Broadcast: if one side of the join has an estimated physical size that is smaller than the + * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold + * or if that side has an explicit broadcast hint (e.g. the user applied the + * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side + * of the join will be broadcasted and the other side will be streamed, with no shuffling + * performed. If both sides of the join are eligible to be broadcasted then the + * - Sort merge: if the matching join keys are sortable and + * [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), then sort merge join + * will be used. + * - Hash: will be chosen if neither of the above optimizations apply to this join. 
*/ - object HashJoin extends Strategy with PredicateHelper { + object EquiJoinSelection extends Strategy with PredicateHelper { private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], @@ -90,14 +94,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + + // --- Inner joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) - // If the sort merge join option is set, we want to use sort merge join prior to hashjoin - // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => val mergeJoin = @@ -115,6 +120,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + // --- Outer joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys( LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => joins.BroadcastHashOuterJoin( @@ -125,10 +132,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.BroadcastHashOuterJoin( leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => joins.ShuffledHashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + // --- Cases where this strategy does not apply --------------------------------------------- + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 23aebf4b068b4..017a44b9ca863 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -65,8 +65,9 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output + case x => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 4ae23c186cf7b..6d656ea2849a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException +import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} /** * :: DeveloperApi :: @@ -38,8 +37,6 @@ case class SortMergeJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { - override protected[sql] val trackNumOfRowsEnabled = true - override def output: Seq[Attribute] = left.output ++ right.output override def outputPartitioning: Partitioning = @@ -56,117 +53,265 @@ case class SortMergeJoin( @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + protected[this] def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(schema)) + } + + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) } protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new Iterator[InternalRow] { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + new RowIterator { // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) - // Mutable per row objects. 
+ private[this] var currentLeftRow: InternalRow = _ + private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ + private[this] var currentMatchIdx: Int = -1 + private[this] val smjScanner = new SortMergeJoinScanner( + leftKeyGenerator, + rightKeyGenerator, + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter) + ) private[this] val joinRow = new JoinedRow - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 - private[this] var stop: Boolean = false - private[this] var matchKey: InternalRow = _ - - // initialize iterator - initialize() - - override final def hasNext: Boolean = nextMatchingPair() - - override final def next(): InternalRow = { - if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null - } - } - joinedRow + private[this] val resultProjection: (InternalRow) => InternalRow = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) } else { - // no more result - throw new NoSuchElementException + identity[InternalRow] } } - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) - } else { - leftElement = null + override def advanceNext(): Boolean = { + if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + } else { + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 + } } - } - - private def fetchRight() = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) + if (currentLeftRow != null) { + joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) + currentMatchIdx += 1 + true } else { - rightElement = null + false } } - private def initialize() = { - fetchLeft() - fetchRight() + override def getRow: InternalRow = resultProjection(joinRow) + }.toScala + } + } +} + +/** + * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. + * + * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] + * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false` + * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return + * the matching row from the streamed input and may call [[getBufferedMatches]] to return the + * sequence of matching rows from the buffered input (in the case of an outer join, this will return + * an empty sequence if there are no matches from the buffered input). For efficiency, both of these + * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()` + * methods. + * + * @param streamedKeyGenerator a projection that produces join keys from the streamed input. + * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. 
+ * @param keyOrdering an ordering which can be used to compare join keys. + * @param streamedIter an input whose rows will be streamed. + * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that + * have the same join key. + */ +private[joins] class SortMergeJoinScanner( + streamedKeyGenerator: Projection, + bufferedKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + streamedIter: RowIterator, + bufferedIter: RowIterator) { + private[this] var streamedRow: InternalRow = _ + private[this] var streamedRowKey: InternalRow = _ + private[this] var bufferedRow: InternalRow = _ + // Note: this is guaranteed to never have any null columns: + private[this] var bufferedRowKey: InternalRow = _ + /** + * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty + */ + private[this] var matchJoinKey: InternalRow = _ + /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ + private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + + // Initialization (note: do _not_ want to advance streamed here). + advancedBufferedToRowWithNullFreeJoinKey() + + // --- Public methods --------------------------------------------------------------------------- + + def getStreamedRow: InternalRow = streamedRow + + def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches + + /** + * Advances both input iterators, stopping when we have found rows with matching join keys. + * @return true if matching rows have been found and false otherwise. If this returns true, then + * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join + * results. + */ + final def findNextInnerJoinRows(): Boolean = { + while (advancedStreamed() && streamedRowKey.anyNull) { + // Advance the streamed side of the join until we find the next row whose join key contains + // no nulls or we hit the end of the streamed iterator. + } + if (streamedRow == null) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // The new streamed row has the same join key as the previous row, so return the same matches. + true + } else if (bufferedRow == null) { + // The streamed row's join key does not match the current batch of buffered rows and there are + // no more rows to read from the buffered iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // Advance both the streamed and buffered iterators to find the next pair of matching rows. + var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + do { + if (streamedRowKey.anyNull) { + advancedStreamed() + } else { + assert(!bufferedRowKey.anyNull) + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey() + else if (comp < 0) advancedStreamed() } + } while (streamedRow != null && bufferedRow != null && comp != 0) + if (streamedRow == null || bufferedRow == null) { + // We have either hit the end of one of the iterators, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // The streamed row's join key matches the current buffered row's join, so walk through the + // buffered iterator to buffer the rest of the matching rows. 
+ assert(comp == 0) + bufferMatchingRows() + true + } + } + } - /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * - * @return true if the search is successful, and false if the right iterator runs out of - * tuples. - */ - private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) - // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() - } - } - rightMatches = new CompactBuffer[InternalRow]() - if (stop) { - stop = false - // iterate the right side to buffer all rows that matches - // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 - } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey - } - } + /** + * Advances the streamed input iterator and buffers all rows from the buffered input that + * have matching keys. + * @return true if the streamed iterator returned a row, false otherwise. If this returns true, + * then [getStreamedRow and [[getBufferedMatches]] can be called to produce the outer + * join results. + */ + final def findNextOuterJoinRows(): Boolean = { + if (!advancedStreamed()) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // Matches the current group, so do nothing. + } else { + // The streamed row does not match the current group. + matchJoinKey = null + bufferedMatches.clear() + if (bufferedRow != null && !streamedRowKey.anyNull) { + // The buffered iterator could still contain matching rows, so we'll need to walk through + // it until we either find matches or pass where they would be found. + var comp = 1 + do { + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey()) + if (comp == 0) { + // We have found matches, so buffer them (this updates matchJoinKey) + bufferMatchingRows() + } else { + // We have overshot the position where the row would be found, hence no matches. } - rightMatches != null && rightMatches.size > 0 } } + // If there is a streamed input then we always return true + true } } + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the streamed iterator and compute the new row's join key. + * @return true if the streamed iterator returned a row and false otherwise. + */ + private def advancedStreamed(): Boolean = { + if (streamedIter.advanceNext()) { + streamedRow = streamedIter.getRow + streamedRowKey = streamedKeyGenerator(streamedRow) + true + } else { + streamedRow = null + streamedRowKey = null + false + } + } + + /** + * Advance the buffered iterator until we find a row with join key that does not contain nulls. + * @return true if the buffered iterator returned a row and false otherwise. 
+ */ + private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = { + var foundRow: Boolean = false + while (!foundRow && bufferedIter.advanceNext()) { + bufferedRow = bufferedIter.getRow + bufferedRowKey = bufferedKeyGenerator(bufferedRow) + foundRow = !bufferedRowKey.anyNull + } + if (!foundRow) { + bufferedRow = null + bufferedRowKey = null + false + } else { + true + } + } + + /** + * Called when the streamed and buffered join keys match in order to buffer the matching rows. + */ + private def bufferMatchingRows(): Unit = { + assert(streamedRowKey != null) + assert(!streamedRowKey.anyNull) + assert(bufferedRowKey != null) + assert(!bufferedRowKey.anyNull) + assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + // This join key may have been produced by a mutable projection, so we need to make a copy: + matchJoinKey = streamedRowKey.copy() + bufferedMatches.clear() + do { + bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them + advancedBufferedToRowWithNullFreeJoinKey() + } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala new file mode 100644 index 0000000000000..5326966b07a66 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an sort merge outer join of two child relations. + * + * Note: this does not support full outer join yet; see SPARK-9730 for progress on this. 
+ */ +@DeveloperApi +case class SortMergeOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + } + + override def outputPartitioning: Partitioning = joinType match { + // For left and right outer joins, the output is partitioned by the streamed input's join keys. + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + + override def outputOrdering: Seq[SortOrder] = joinType match { + // For left and right outer joins, the output is ordered by the streamed input's join keys. + case LeftOuter => requiredOrders(leftKeys) + case RightOuter => requiredOrders(rightKeys) + case x => throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. + keys.map(SortOrder(_, Ascending)) + } + + private def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(schema)) + } + + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + + private def createLeftKeyGenerator(): Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(leftKeys, left.output) + } else { + newProjection(leftKeys, left.output) + } + } + + private def createRightKeyGenerator(): Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } + } + + override def doExecute(): RDD[InternalRow] = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // An ordering that can be used to compare keys from both sides. 
+ val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + val boundCondition: (InternalRow) => Boolean = { + condition.map { cond => + newPredicate(cond, left.output ++ right.output) + }.getOrElse { + (r: InternalRow) => true + } + } + val resultProj: InternalRow => InternalRow = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + + joinType match { + case LeftOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createLeftKeyGenerator(), + bufferedKeyGenerator = createRightKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(leftIter), + bufferedIter = RowIterator.fromScala(rightIter) + ) + val rightNullRow = new GenericInternalRow(right.output.length) + new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala + + case RightOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createRightKeyGenerator(), + bufferedKeyGenerator = createLeftKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(rightIter), + bufferedIter = RowIterator.fromScala(leftIter) + ) + val leftNullRow = new GenericInternalRow(left.output.length) + new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala + + case x => + throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + } + } +} + + +private class LeftOuterIterator( + smjScanner: SortMergeJoinScanner, + rightNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var rightIdx: Int = 0 + assert(smjScanner.getBufferedMatches.length == 0) + + private def advanceLeft(): Boolean = { + rightIdx = 0 + if (smjScanner.findNextOuterJoinRows()) { + joinedRow.withLeft(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching right rows, so return nulls for the right row + joinedRow.withRight(rightNullRow) + } else { + // Find the next row from the right input that satisfies the bound condition + if (!advanceRightUntilBoundConditionSatisfied()) { + joinedRow.withRight(rightNullRow) + } + } + true + } else { + // Left input has been exhausted + false + } + } + + private def advanceRightUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { + foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx))) + rightIdx += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + advanceRightUntilBoundConditionSatisfied() || advanceLeft() + } + + override def getRow: InternalRow = resultProj(joinedRow) +} + +private class RightOuterIterator( + smjScanner: SortMergeJoinScanner, + leftNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var leftIdx: Int = 0 + assert(smjScanner.getBufferedMatches.length == 0) + + private def advanceRight(): Boolean = { + leftIdx = 0 + if (smjScanner.findNextOuterJoinRows()) { + joinedRow.withRight(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching left rows, so return nulls for the left row + joinedRow.withLeft(leftNullRow) + } else { + // Find the next row from the left input that
satisfies the bound condition + if (!advanceLeftUntilBoundConditionSatisfied()) { + joinedRow.withLeft(leftNullRow) + } + } + true + } else { + // Right input has been exhausted + false + } + } + + private def advanceLeftUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) { + foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx))) + leftIdx += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + advanceLeftUntilBoundConditionSatisfied() || advanceRight() + } + + override def getRow: InternalRow = resultProj(joinedRow) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5bef1d8966031..ae07eaf91c872 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -22,13 +22,14 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.test.SQLTestUtils -class JoinSuite extends QueryTest with BeforeAndAfterEach { +class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Ensures tables are loaded. TestData + override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ import ctx.logicalPlanToSparkQuery @@ -37,7 +38,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -55,6 +56,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j case j: SortMergeJoin => j + case j: SortMergeOuterJoin => j } assert(operators.size === 1) @@ -66,7 +68,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -83,11 +84,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left
JOIN testData2 ON (key * a != key + a)", @@ -97,82 +98,75 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq( - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("SortMergeJoin shouldn't work on unsortable columns") { - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + for (sortMergeJoinEnabled <- Seq(true, false)) { + withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { + Seq( + ("SELECT * FROM testData join testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, 
joinClass) => assertJoin(query, joinClass) } + } + } } - ctx.sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[BroadcastHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } - ctx.sql("UNCACHE TABLE testData") } @@ -180,7 +174,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -457,25 +451,24 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val tmp = ctx.conf.autoBroadcastJoinThreshold - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") - - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) ctx.sql("UNCACHE TABLE testData") } @@ -488,6 +481,5 @@ class JoinSuite 
extends QueryTest with BeforeAndAfterEach { Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala new file mode 100644 index 0000000000000..ddff7cebcc17d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution._ + +class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { + + private def testInnerJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + + def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeBroadcastHashJoin(left, right, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastHashJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key 
-> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeBroadcastHashJoin(left, right, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using ShuffledHashJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeShuffledHashJoin(left, right, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using ShuffledHashJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeShuffledHashJoin(left, right, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using SortMergeJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeSortMergeJoin(left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + { + val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"), + Row(5, "E"), + Row(6, "F"), + Row(null, "G") + )), new StructType().add("N", IntegerType).add("L", StringType)) + + val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(null, "e") + )), new StructType().add("n", IntegerType).add("l", StringType)) + + testInnerJoin( + "inner join, one match per row", + upperCaseData, + lowerCaseData, + (upperCaseData.col("N") === lowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + } + + private val testData2 = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + { + val left = testData2.where("a = 1") + val right = testData2.where("a = 1") + testInnerJoin( + "inner join, multiple matches", + left, + right, + (left.col("a") === right.col("a")).expr, + Seq( + (1, 1, 1, 1), + (1, 1, 1, 2), + (1, 2, 1, 1), + (1, 2, 1, 2) + ) + ) + } + + { + val left = testData2.where("a = 1") + val right = testData2.where("a = 2") + testInnerJoin( + "inner join, no matches", + left, + right, + (left.col("a") === right.col("a")).expr, + Seq.empty + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 2c27da596bc4f..e16f5e39aa2f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -1,89 +1,221 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} - -class OuterJoinSuite extends SparkPlanTest { - - val left = Seq( - (1, 2.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") - - val right = Seq( - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") - - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - test("shuffled hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } - - test("broadcast hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest} + +class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { + + private def testOuterJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + joinType: JoinType, + condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + test(s"$testName using ShuffledHashOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using SortMergeOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = false) + } + } + } + } + + test(s"$testName using BroadcastNestedLoopJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(2, 100.0), + Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, 1.0), + Row(3, 3.0), + Row(5, 1.0), + Row(6, 6.0), + Row(null, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(0, 0.0), + Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, -1.0), + Row(2, -1.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(5, 3.0), + Row(7, 7.0), + Row(null, null) + )), new StructType().add("c", 
IntegerType).add("d", DoubleType)) + + val condition = { + And( + (left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // --- Basic outer joins ------------------------------------------------------------------------ + + testOuterJoin( + "basic left outer join", + left, + right, + LeftOuter, + condition, + Seq( + (null, null, null, null), + (1, 2.0, null, null), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null) + ) + ) + + testOuterJoin( + "basic right outer join", + left, + right, + RightOuter, + condition, + Seq( + (null, null, null, null), + (null, null, 0, 0.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (5, 1.0, 5, 3.0), + (null, null, 7, 7.0) + ) + ) + + testOuterJoin( + "basic full outer join", + left, + right, + FullOuter, + condition, + Seq( + (1, 2.0, null, null), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null), + (null, null, 0, 0.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (null, null, 7, 7.0), + (null, null, null, null), + (null, null, null, null) + ) + ) + + // --- Both inputs empty ------------------------------------------------------------------------ + + testOuterJoin( + "left outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + LeftOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "right outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + RightOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "full outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + FullOuter, + condition, + Seq.empty + ) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 927e85a7db3dc..4503ed251fcb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,58 +17,91 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +class SemiJoinSuite extends SparkPlanTest with SQLTestUtils { -class SemiJoinSuite extends SparkPlanTest{ - val left = Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") + private def testLeftSemiJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + 
condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + test(s"$testName using LeftSemiJoinHash") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } - val right = Seq( - (2, 3.0), - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") + test(s"$testName using BroadcastLeftSemiJoinHash") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - test("left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + test(s"$testName using LeftSemiJoinBNL") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } - test("left semi join BNL") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, condition), - Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) - } + val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) - test("broadcast left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + val condition = { + And( + (left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) } + + testLeftSemiJoin( + "basic test", + left, + right, + condition, + Seq( + (2, 1.0), + (2, 1.0) + ) + ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 4c11acdab9ec0..1066695589778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils trait SQLTestUtils { this: SparkFunSuite => - def sqlContext: SQLContext + protected def sqlContext: SQLContext protected def configuration = 
sqlContext.sparkContext.hadoopConfiguration diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 567d7fa12ff14..f17177a771c3b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -531,7 +531,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { HashAggregation, Aggregation, LeftSemiJoin, - HashJoin, + EquiJoinSelection, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin From 0f90d6055e5bea9ceb1d454db84f4aa1d59b284d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 10 Aug 2015 23:41:53 -0700 Subject: [PATCH 258/340] [SPARK-9640] [STREAMING] [TEST] Do not run Python Kinesis tests when the Kinesis assembly JAR has not been generated Author: Tathagata Das Closes #7961 from tdas/SPARK-9640 and squashes the following commits: 974ce19 [Tathagata Das] Undo changes related to SPARK-9727 004ae26 [Tathagata Das] style fixes 9bbb97d [Tathagata Das] Minor style fies e6a677e [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9640 ca90719 [Tathagata Das] Removed extra line ba9cfc7 [Tathagata Das] Improved kinesis test selection logic 88d59bd [Tathagata Das] updated test modules 871fcc8 [Tathagata Das] Fixed SparkBuild 94be631 [Tathagata Das] Fixed style b858196 [Tathagata Das] Fixed conditions and few other things based on PR comments. e292e64 [Tathagata Das] Added filters for Kinesis python tests --- python/pyspark/streaming/tests.py | 56 ++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 66ae3345f468f..f0ed415f97120 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -971,8 +971,10 @@ def test_kinesis_stream_api(self): "awsAccessKey", "awsSecretKey") def test_kinesis_stream(self): - if os.environ.get('ENABLE_KINESIS_TESTS') != '1': - print("Skip test_kinesis_stream") + if not are_kinesis_tests_enabled: + sys.stderr.write( + "Skipped test_kinesis_stream (enable by setting environment variable %s=1" + % kinesis_test_environ_var) return import random @@ -1013,6 +1015,7 @@ def get_output(_, rdd): traceback.print_exc() raise finally: + self.ssc.stop(False) kinesisTestUtils.deleteStream() kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) @@ -1027,7 +1030,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " - "'build/mvn package' before running this test") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " "remove all but one") % kafka_assembly_dir) @@ -1045,7 +1048,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. 
" % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " "remove all but one") % flume_assembly_dir) @@ -1095,11 +1098,7 @@ def search_kinesis_asl_assembly_jar(): os.path.join(kinesis_asl_assembly_dir, "target/scala-*/spark-streaming-kinesis-asl-assembly-*.jar")) if not jars: - raise Exception( - ("Failed to find Spark Streaming Kinesis ASL assembly jar in %s. " % - kinesis_asl_assembly_dir) + "You need to build Spark with " - "'build/sbt -Pkinesis-asl assembly/assembly streaming-kinesis-asl-assembly/assembly' " - "or 'build/mvn -Pkinesis-asl package' before running this test") + return None elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs in %s; please " "remove all but one") % kinesis_asl_assembly_dir) @@ -1107,6 +1106,10 @@ def search_kinesis_asl_assembly_jar(): return jars[0] +# Must be same as the variable and condition defined in KinesisTestUtils.scala +kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" +are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' + if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() @@ -1114,8 +1117,37 @@ def search_kinesis_asl_assembly_jar(): mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar, - mqtt_assembly_jar, mqtt_test_jar) + if kinesis_asl_assembly_jar is None: + kinesis_jar_present = False + jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar) + else: + kinesis_jar_present = True + jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars - unittest.main() + testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, + CheckpointTests, KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests] + + if kinesis_jar_present is True: + testcases.append(KinesisStreamTests) + elif are_kinesis_tests_enabled is False: + sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " + "not compiled with -Pkinesis-asl profile. To run these tests, " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " + "streaming-kinesis-asl-assembly/assembly' or " + "'build/mvn -Pkinesis-asl package' before running this test.") + else: + raise Exception( + ("Failed to find Spark Streaming Kinesis assembly jar in %s. 
" + % kinesis_asl_assembly_dir) + + "You need to build Spark with 'build/sbt -Pkinesis-asl " + "assembly/assembly streaming-kinesis-asl-assembly/assembly'" + "or 'build/mvn -Pkinesis-asl package' before running this test.") + + sys.stderr.write("Running tests: %s \n" % (str(testcases))) + for testcase in testcases: + sys.stderr.write("[Running %s]\n" % (testcase)) + tests = unittest.TestLoader().loadTestsFromTestCase(testcase) + unittest.TextTestRunner(verbosity=2).run(tests) From 55752d88321925da815823f968128832de6fdbbb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Aug 2015 01:08:30 -0700 Subject: [PATCH 259/340] [SPARK-9810] [BUILD] Remove individual commit messages from the squash commit message For more information, please see the JIRA ticket and the associated dev list discussion. https://issues.apache.org/jira/browse/SPARK-9810 http://apache-spark-developers-list.1001551.n3.nabble.com/discuss-Removing-individual-commit-messages-from-the-squash-commit-message-td13295.html Author: Reynold Xin Closes #8091 from rxin/SPARK-9810. --- dev/merge_spark_pr.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index ad4b76695c9ff..b9bdec3d70864 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -159,11 +159,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): merge_message_flags += ["-m", message] # The string "Closes #%s" string is required for GitHub to correctly close the PR - merge_message_flags += [ - "-m", - "Closes #%s from %s and squashes the following commits:" % (pr_num, pr_repo_desc)] - for c in commits: - merge_message_flags += ["-m", c] + merge_message_flags += ["-m", "Closes #%s from %s." % (pr_num, pr_repo_desc)] run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags) From 600031ebe27473d8fffe6ea436c2149223b82896 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Aug 2015 02:41:03 -0700 Subject: [PATCH 260/340] [SPARK-9727] [STREAMING] [BUILD] Updated streaming kinesis SBT project name to be more consistent Author: Tathagata Das Closes #8092 from tdas/SPARK-9727 and squashes the following commits: b1b01fd [Tathagata Das] Updated streaming kinesis project name --- dev/sparktestsupport/modules.py | 4 ++-- extras/kinesis-asl/pom.xml | 2 +- project/SparkBuild.scala | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d82c0cca37bc6..346452f3174e4 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -134,7 +134,7 @@ def contains_file(self, filename): # files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't # fail other PRs. 
streaming_kinesis_asl = Module( - name="kinesis-asl", + name="streaming-kinesis-asl", dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", @@ -147,7 +147,7 @@ def contains_file(self, filename): "ENABLE_KINESIS_TESTS": "1" }, sbt_test_goals=[ - "kinesis-asl/test", + "streaming-kinesis-asl/test", ] ) diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c242e7a57b9ab..521b53e230c4a 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -31,7 +31,7 @@ Spark Kinesis Integration - kinesis-asl + streaming-kinesis-asl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 41a85fa9de778..cad7067ade8c1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -42,8 +42,8 @@ object BuildCommons { "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, - sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", - "kinesis-asl").map(ProjectRef(buildLocation, _)) + streamingKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", + "streaming-kinesis-asl").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") From d378396f86f625f006738d87fe5dbc2ff8fd913d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Aug 2015 08:41:06 -0700 Subject: [PATCH 261/340] [SPARK-9815] Rename PlatformDependent.UNSAFE -> Platform. PlatformDependent.UNSAFE is way too verbose. Author: Reynold Xin Closes #8094 from rxin/SPARK-9815 and squashes the following commits: 229b603 [Reynold Xin] [SPARK-9815] Rename PlatformDependent.UNSAFE -> Platform. 
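As a representative sketch of the call-site change, drawn from the UnsafeShuffleExternalSorter hunk below (the identifiers come from that hunk; the same mechanical rename applies to getInt, copyMemory, BYTE_ARRAY_OFFSET, throwException and the other helpers touched in this patch):

    // Before: static memory helpers reached through the nested PlatformDependent.UNSAFE handle
    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
    PlatformDependent.copyMemory(
      recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);

    // After: the same helpers are exposed directly as static methods on Platform
    Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
    Platform.copyMemory(
      recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);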
--- .../serializer/DummySerializerInstance.java | 6 +- .../unsafe/UnsafeShuffleExternalSorter.java | 22 +-- .../shuffle/unsafe/UnsafeShuffleWriter.java | 4 +- .../spark/unsafe/map/BytesToBytesMap.java | 20 +-- .../unsafe/sort/PrefixComparators.java | 5 +- .../unsafe/sort/UnsafeExternalSorter.java | 22 +-- .../unsafe/sort/UnsafeInMemorySorter.java | 4 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 4 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 6 +- .../UnsafeShuffleInMemorySorterSuite.java | 20 +-- .../map/AbstractBytesToBytesMapSuite.java | 94 +++++----- .../sort/UnsafeExternalSorterSuite.java | 20 +-- .../sort/UnsafeInMemorySorterSuite.java | 20 +-- .../catalyst/expressions/UnsafeArrayData.java | 51 ++---- .../catalyst/expressions/UnsafeReaders.java | 8 +- .../sql/catalyst/expressions/UnsafeRow.java | 108 +++++------ .../expressions/UnsafeRowWriters.java | 41 ++--- .../catalyst/expressions/UnsafeWriters.java | 43 ++--- .../execution/UnsafeExternalRowSorter.java | 4 +- .../expressions/codegen/CodeGenerator.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 32 ++-- .../codegen/GenerateUnsafeRowJoiner.scala | 16 +- .../expressions/stringOperations.scala | 4 +- .../GenerateUnsafeRowJoinerBitsetSuite.scala | 4 +- .../UnsafeFixedWidthAggregationMap.java | 4 +- .../sql/execution/UnsafeKVExternalSorter.java | 4 +- .../sql/execution/UnsafeRowSerializer.scala | 6 +- .../sql/execution/joins/HashedRelation.scala | 13 +- .../org/apache/spark/sql/UnsafeRowSuite.scala | 4 +- .../{PlatformDependent.java => Platform.java} | 170 ++++++++---------- .../spark/unsafe/array/ByteArrayMethods.java | 14 +- .../apache/spark/unsafe/array/LongArray.java | 6 +- .../spark/unsafe/bitset/BitSetMethods.java | 19 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 4 +- .../spark/unsafe/memory/MemoryBlock.java | 4 +- .../unsafe/memory/UnsafeMemoryAllocator.java | 6 +- .../apache/spark/unsafe/types/ByteArray.java | 10 +- .../apache/spark/unsafe/types/UTF8String.java | 30 ++-- .../unsafe/hash/Murmur3_x86_32Suite.java | 14 +- 39 files changed, 371 insertions(+), 499 deletions(-) rename unsafe/src/main/java/org/apache/spark/unsafe/{PlatformDependent.java => Platform.java} (55%) diff --git a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java index 0399abc63c235..0e58bb4f7101c 100644 --- a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -25,7 +25,7 @@ import scala.reflect.ClassTag; import org.apache.spark.annotation.Private; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. 
@@ -49,7 +49,7 @@ public void flush() { try { s.flush(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } @@ -64,7 +64,7 @@ public void close() { try { s.close(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } }; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 925b60a145886..3d1ef0c48adc5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -37,7 +37,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -211,16 +211,12 @@ private void writeSortedFile(boolean isLastFile) throws IOException { final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); final Object recordPage = taskMemoryManager.getPage(recordPointer); final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); - int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); - PlatformDependent.copyMemory( - recordPage, - recordReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); writer.write(writeBuffer, 0, toTransfer); recordReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -447,14 +443,10 @@ public void insertRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); dataPagePosition += 4; - PlatformDependent.copyMemory( - recordBaseObject, - recordBaseOffset, - dataPageBaseObject, - dataPagePosition, - lengthInBytes); + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, partitionId); } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 02084f9122e00..2389c28b28395 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -53,7 +53,7 @@ import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.TaskMemoryManager; @Private @@ -244,7 +244,7 @@ void 
insertRecordIntoSorter(Product2 record) throws IOException { assert (serializedRecordSize > 0); sorter.insertRecord( - serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 7f79cd13aab43..85b46ec8bfae3 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -270,10 +270,10 @@ public boolean hasNext() { @Override public Location next() { - int totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); + int totalLength = Platform.getInt(pageBaseObject, offsetInPage); if (totalLength == END_OF_PAGE_MARKER) { advanceToNextPage(); - totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); + totalLength = Platform.getInt(pageBaseObject, offsetInPage); } loc.with(currentPage, offsetInPage); offsetInPage += 4 + totalLength; @@ -402,9 +402,9 @@ private void updateAddressesAndSizes(long fullKeyAddress) { private void updateAddressesAndSizes(final Object page, final long offsetInPage) { long position = offsetInPage; - final int totalLength = PlatformDependent.UNSAFE.getInt(page, position); + final int totalLength = Platform.getInt(page, position); position += 4; - keyLength = PlatformDependent.UNSAFE.getInt(page, position); + keyLength = Platform.getInt(page, position); position += 4; valueLength = totalLength - keyLength - 4; @@ -572,7 +572,7 @@ public boolean putNewKey( // There wasn't enough space in the current page, so write an end-of-page marker: final Object pageBaseObject = currentDataPage.getBaseObject(); final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; - PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); if (memoryGranted != pageSizeBytes) { @@ -608,21 +608,21 @@ public boolean putNewKey( final long valueDataOffsetInPage = insertCursor; insertCursor += valueLengthBytes; // word used to store the value size - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset, + Platform.putInt(dataPageBaseObject, recordOffset, keyLengthBytes + valueLengthBytes + 4); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); + Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); // Copy the key - PlatformDependent.copyMemory( + Platform.copyMemory( keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes); // Copy the value - PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, + Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, valueDataOffsetInPage, valueLengthBytes); // --- Update bookeeping data structures ----------------------------------------------------- if (useOverflowPage) { // Store the end-of-page marker at the end of the data page - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); + Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); } else { pageCursor += requiredSize; } diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 5e002ae1b7568..71b76d5ddfaa7 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -20,10 +20,9 @@ import com.google.common.primitives.UnsignedLongs; import org.apache.spark.annotation.Private; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.util.Utils; -import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; @Private public class PrefixComparators { @@ -73,7 +72,7 @@ public static long computePrefix(byte[] bytes) { final int minLen = Math.min(bytes.length, 8); long p = 0; for (int i = 0; i < minLen; ++i) { - p |= (128L + PlatformDependent.UNSAFE.getByte(bytes, BYTE_ARRAY_OFFSET + i)) + p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i)) << (56 - 8 * i); } return p; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 5ebbf9b068fd6..9601aafe55464 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -35,7 +35,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; @@ -427,14 +427,10 @@ public void insertRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); dataPagePosition += 4; - PlatformDependent.copyMemory( - recordBaseObject, - recordBaseOffset, - dataPageBaseObject, - dataPagePosition, - lengthInBytes); + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); } @@ -493,18 +489,16 @@ public void insertKVRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); dataPagePosition += 4; - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen); + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen); dataPagePosition += 4; - PlatformDependent.copyMemory( - keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); + Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); dataPagePosition += keyLen; - PlatformDependent.copyMemory( - valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen); + Platform.copyMemory(valueBaseObj, valueOffset, 
dataPageBaseObject, dataPagePosition, valueLen); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 1e4b8a116e11a..f7787e1019c2b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,7 +19,7 @@ import java.util.Comparator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -164,7 +164,7 @@ public void loadNext() { final long recordPointer = sortBuffer[position]; baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length - recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4); + recordLength = Platform.getInt(baseObject, baseOffset - 4); keyPrefix = sortBuffer[position + 1]; position += 2; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index ca1ccedc93c8e..4989b05d63e23 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description @@ -42,7 +42,7 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator { private byte[] arr = new byte[1024 * 1024]; private Object baseObject = arr; - private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; + private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; public UnsafeSorterSpillReader( BlockManager blockManager, diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 44cf6c756d7c3..e59a84ff8d118 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -28,7 +28,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Spills a list of sorted records to disk. 
Spill files have the following format: @@ -117,11 +117,11 @@ public void write( long recordReadPosition = baseOffset; while (dataRemaining > 0) { final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, recordReadPosition, writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), toTransfer); writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); recordReadPosition += toTransfer; diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java index 8fa72597db24d..40fefe2c9d140 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java @@ -24,7 +24,7 @@ import org.junit.Test; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -34,11 +34,7 @@ public class UnsafeShuffleInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); return new String(strBytes); } @@ -74,14 +70,10 @@ public void testBasicSorting() throws Exception { for (String str : dataToSort) { final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); } @@ -98,7 +90,7 @@ public void testBasicSorting() throws Exception { Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, partitionId >= prevPartitionId); final long recordAddress = iter.packedRecordPointer.getRecordPointer(); - final int recordLength = PlatformDependent.UNSAFE.getInt( + final int recordLength = Platform.getInt( memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); final String str = getStringFromDataPage( memoryManager.getPage(recordAddress), diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index e56a3f0b6d12c..1a79c20c35246 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -32,9 +32,7 @@ import 
org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.*; -import org.apache.spark.unsafe.PlatformDependent; -import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; -import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET; +import org.apache.spark.unsafe.Platform; public abstract class AbstractBytesToBytesMapSuite { @@ -80,13 +78,8 @@ public void tearDown() { private static byte[] getByteArray(MemoryLocation loc, int size) { final byte[] arr = new byte[size]; - PlatformDependent.copyMemory( - loc.getBaseObject(), - loc.getBaseOffset(), - arr, - BYTE_ARRAY_OFFSET, - size - ); + Platform.copyMemory( + loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size); return arr; } @@ -108,7 +101,7 @@ private static boolean arrayEquals( long actualLengthBytes) { return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( expected, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, actualAddr.getBaseObject(), actualAddr.getBaseOffset(), expected.length @@ -124,7 +117,7 @@ public void emptyMap() { final int keyLengthInWords = 10; final int keyLengthInBytes = keyLengthInWords * 8; final byte[] key = getRandomByteArray(keyLengthInWords); - Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); Assert.assertFalse(map.iterator().hasNext()); } finally { map.free(); @@ -141,14 +134,14 @@ public void setAndRetrieveAKey() { final byte[] valueData = getRandomByteArray(recordLengthWords); try { final BytesToBytesMap.Location loc = - map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( keyData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes )); // After storing the key and value, the other location methods should return results that @@ -159,7 +152,8 @@ public void setAndRetrieveAKey() { Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); // After calling lookup() the location should still point to the correct data. 
- Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); + Assert.assertTrue( + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); @@ -168,10 +162,10 @@ public void setAndRetrieveAKey() { try { Assert.assertTrue(loc.putNewKey( keyData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes )); Assert.fail("Should not be able to set a new value for a key"); @@ -191,25 +185,25 @@ private void iteratorTestBase(boolean destructive) throws Exception { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; final BytesToBytesMap.Location loc = - map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8); Assert.assertFalse(loc.isDefined()); // Ensure that we store some zero-length keys if (i % 5 == 0) { Assert.assertTrue(loc.putNewKey( null, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 0, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8 )); } else { Assert.assertTrue(loc.putNewKey( value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8 )); } @@ -228,14 +222,13 @@ private void iteratorTestBase(boolean destructive) throws Exception { Assert.assertTrue(loc.isDefined()); final MemoryLocation keyAddress = loc.getKeyAddress(); final MemoryLocation valueAddress = loc.getValueAddress(); - final long value = PlatformDependent.UNSAFE.getLong( + final long value = Platform.getLong( valueAddress.getBaseObject(), valueAddress.getBaseOffset()); final long keyLength = loc.getKeyLength(); if (keyLength == 0) { Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); } else { - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset()); Assert.assertEquals(value, key); } valuesSeen.set((int) value); @@ -284,16 +277,16 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes final BytesToBytesMap.Location loc = map.lookup( key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH, value, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH )); } @@ -308,18 +301,18 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { Assert.assertTrue(loc.isDefined()); Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); - PlatformDependent.copyMemory( + Platform.copyMemory( loc.getKeyAddress().getBaseObject(), loc.getKeyAddress().getBaseOffset(), key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); - PlatformDependent.copyMemory( + Platform.copyMemory( loc.getValueAddress().getBaseObject(), loc.getValueAddress().getBaseOffset(), value, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH ); for (long j : key) { @@ -354,16 +347,16 @@ 
public void randomizedStressTest() { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length, value, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, value.length )); // After calling putNewKey, the following should be true, even before calling @@ -379,7 +372,8 @@ public void randomizedStressTest() { for (Map.Entry entry : expected.entrySet()) { final byte[] key = entry.getKey().array(); final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); @@ -405,16 +399,16 @@ public void randomizedTestWithRecordsLargerThanPageSize() { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length, value, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, value.length )); // After calling putNewKey, the following should be true, even before calling @@ -429,7 +423,8 @@ public void randomizedTestWithRecordsLargerThanPageSize() { for (Map.Entry entry : expected.entrySet()) { final byte[] key = entry.getKey().array(); final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); @@ -447,12 +442,10 @@ public void failureToAllocateFirstPage() { try { final long[] emptyArray = new long[0]; final BytesToBytesMap.Location loc = - map.lookup(emptyArray, PlatformDependent.LONG_ARRAY_OFFSET, 0); + map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0); Assert.assertFalse(loc.isDefined()); Assert.assertFalse(loc.putNewKey( - emptyArray, LONG_ARRAY_OFFSET, 0, - emptyArray, LONG_ARRAY_OFFSET, 0 - )); + emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0)); } finally { map.free(); } @@ -468,8 +461,9 @@ public void failureToGrow() { int i; for (i = 0; i < 1024; i++) { final long[] arr = new long[]{i}; - final BytesToBytesMap.Location loc = map.lookup(arr, PlatformDependent.LONG_ARRAY_OFFSET, 8); - success = loc.putNewKey(arr, LONG_ARRAY_OFFSET, 8, arr, LONG_ARRAY_OFFSET, 8); + final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); + success = + loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); if (!success) { break; } @@ -541,12 +535,12 @@ public void testPeakMemoryUsed() { try { for (long i = 0; i < numRecordsPerPage * 10; i++) { final long[] value = new long[]{i}; - map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8).putNewKey( + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey( value, - 
PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8); newPeakMemory = map.getPeakMemoryUsedBytes(); if (i % numRecordsPerPage == 0) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 83049b8a21fcf..445a37b83e98a 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -49,7 +49,7 @@ import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -166,14 +166,14 @@ private void assertSpillFilesWereCleanedUp() { private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { final int[] arr = new int[]{ value }; - sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); + sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value); } private static void insertRecord( UnsafeExternalSorter sorter, int[] record, long prefix) throws IOException { - sorter.insertRecord(record, PlatformDependent.INT_ARRAY_OFFSET, record.length * 4, prefix); + sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix); } private UnsafeExternalSorter newSorter() throws IOException { @@ -205,7 +205,7 @@ public void testSortingOnlyByPrefix() throws Exception { iter.loadNext(); assertEquals(i, iter.getKeyPrefix()); assertEquals(4, iter.getRecordLength()); - assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); } sorter.cleanupResources(); @@ -253,7 +253,7 @@ public void spillingOccursInResponseToMemoryPressure() throws Exception { iter.loadNext(); assertEquals(i, iter.getKeyPrefix()); assertEquals(4, iter.getRecordLength()); - assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); i++; } sorter.cleanupResources(); @@ -265,7 +265,7 @@ public void testFillingPage() throws Exception { final UnsafeExternalSorter sorter = newSorter(); byte[] record = new byte[16]; while (sorter.getNumberOfAllocatedPages() < 2) { - sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0); + sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0); } sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); @@ -292,25 +292,25 @@ public void sortingRecordsThatExceedPageSize() throws Exception { iter.loadNext(); assertEquals(123, iter.getKeyPrefix()); assertEquals(smallRecord.length * 4, iter.getRecordLength()); - assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Small record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(123, iter.getKeyPrefix()); assertEquals(smallRecord.length * 4, iter.getRecordLength()); - 
assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Large record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(456, iter.getKeyPrefix()); assertEquals(largeRecord.length * 4, iter.getRecordLength()); - assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Large record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(456, iter.getKeyPrefix()); assertEquals(largeRecord.length * 4, iter.getRecordLength()); - assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); assertFalse(iter.hasNext()); sorter.cleanupResources(); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 909500930539c..778e813df6b54 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -26,7 +26,7 @@ import static org.mockito.Mockito.mock; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -36,11 +36,7 @@ public class UnsafeInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { final byte[] strBytes = new byte[length]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, length); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length); return new String(strBytes); } @@ -76,14 +72,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { long position = dataPage.getBaseOffset(); for (String str : dataToSort) { final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; } // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so @@ -113,7 +105,7 @@ public int compare(long prefix1, long prefix2) { position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { // position now points to the start of a record (which holds its length). 
- final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); + final int recordLength = Platform.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); final String str = getStringFromDataPage(baseObject, position + 4, recordLength); final int partitionId = hashPartitioner.getPartition(str); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 0374846d71674..501dff090313c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; @@ -59,7 +59,7 @@ public class UnsafeArrayData extends ArrayData { private int sizeInBytes; private int getElementOffset(int ordinal) { - return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + ordinal * 4L); + return Platform.getInt(baseObject, baseOffset + ordinal * 4L); } private int getElementSize(int offset, int ordinal) { @@ -157,7 +157,7 @@ public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return false; - return PlatformDependent.UNSAFE.getBoolean(baseObject, baseOffset + offset); + return Platform.getBoolean(baseObject, baseOffset + offset); } @Override @@ -165,7 +165,7 @@ public byte getByte(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + offset); + return Platform.getByte(baseObject, baseOffset + offset); } @Override @@ -173,7 +173,7 @@ public short getShort(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getShort(baseObject, baseOffset + offset); + return Platform.getShort(baseObject, baseOffset + offset); } @Override @@ -181,7 +181,7 @@ public int getInt(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + return Platform.getInt(baseObject, baseOffset + offset); } @Override @@ -189,7 +189,7 @@ public long getLong(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + return Platform.getLong(baseObject, baseOffset + offset); } @Override @@ -197,7 +197,7 @@ public float getFloat(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getFloat(baseObject, baseOffset + offset); + return Platform.getFloat(baseObject, baseOffset + offset); } @Override @@ -205,7 +205,7 @@ public double getDouble(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return 
PlatformDependent.UNSAFE.getDouble(baseObject, baseOffset + offset); + return Platform.getDouble(baseObject, baseOffset + offset); } @Override @@ -215,7 +215,7 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { if (offset < 0) return null; if (precision <= Decimal.MAX_LONG_DIGITS()) { - final long value = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + final long value = Platform.getLong(baseObject, baseOffset + offset); return Decimal.apply(value, precision, scale); } else { final byte[] bytes = getBinary(ordinal); @@ -241,12 +241,7 @@ public byte[] getBinary(int ordinal) { if (offset < 0) return null; final int size = getElementSize(offset, ordinal); final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + offset, - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - size); + Platform.copyMemory(baseObject, baseOffset + offset, bytes, Platform.BYTE_ARRAY_OFFSET, size); return bytes; } @@ -255,9 +250,8 @@ public CalendarInterval getInterval(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return null; - final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); - final long microseconds = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + final int months = (int) Platform.getLong(baseObject, baseOffset + offset); + final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); } @@ -307,27 +301,16 @@ public boolean equals(Object other) { } public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - baseObject, - baseOffset, - target, - targetOffset, - sizeInBytes - ); + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } @Override public UnsafeArrayData copy() { UnsafeArrayData arrayCopy = new UnsafeArrayData(); final byte[] arrayDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - arrayDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, - sizeInBytes - ); - arrayCopy.pointTo(arrayDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); + Platform.copyMemory( + baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); return arrayCopy; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java index b521b703389d3..7b03185a30e3c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; public class UnsafeReaders { public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) { // Read the number of elements from first 4 bytes. - final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); + final int numElements = Platform.getInt(baseObject, baseOffset); final UnsafeArrayData array = new UnsafeArrayData(); // Skip the first 4 bytes. 
array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4); @@ -32,9 +32,9 @@ public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) { // Read the number of elements from first 4 bytes. - final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); + final int numElements = Platform.getInt(baseObject, baseOffset); // Read the numBytes of key array in second 4 bytes. - final int keyArraySize = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + 4); + final int keyArraySize = Platform.getInt(baseObject, baseOffset + 4); final int valueArraySize = numBytes - 8 - keyArraySize; final UnsafeArrayData keyArray = new UnsafeArrayData(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index e829acb6285f1..7fd94772090df 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -27,7 +27,7 @@ import java.util.Set; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -169,7 +169,7 @@ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeI * @param sizeInBytes the number of bytes valid in the byte array */ public void pointTo(byte[] buf, int numFields, int sizeInBytes) { - pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } @Override @@ -179,7 +179,7 @@ public void setNullAt(int i) { // To preserve row equality, zero out the value when setting the column to null. // Since this row does does not currently support updates to variable-length values, we don't // have to worry about zeroing out that data. 
- PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); + Platform.putLong(baseObject, getFieldOffset(i), 0); } @Override @@ -191,14 +191,14 @@ public void update(int ordinal, Object value) { public void setInt(int ordinal, int value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); + Platform.putInt(baseObject, getFieldOffset(ordinal), value); } @Override public void setLong(int ordinal, long value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); + Platform.putLong(baseObject, getFieldOffset(ordinal), value); } @Override @@ -208,28 +208,28 @@ public void setDouble(int ordinal, double value) { if (Double.isNaN(value)) { value = Double.NaN; } - PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @Override public void setBoolean(int ordinal, boolean value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); + Platform.putBoolean(baseObject, getFieldOffset(ordinal), value); } @Override public void setShort(int ordinal, short value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); + Platform.putShort(baseObject, getFieldOffset(ordinal), value); } @Override public void setByte(int ordinal, byte value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); + Platform.putByte(baseObject, getFieldOffset(ordinal), value); } @Override @@ -239,7 +239,7 @@ public void setFloat(int ordinal, float value) { if (Float.isNaN(value)) { value = Float.NaN; } - PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + Platform.putFloat(baseObject, getFieldOffset(ordinal), value); } /** @@ -263,23 +263,23 @@ public void setDecimal(int ordinal, Decimal value, int precision) { long cursor = getLong(ordinal) >>> 32; assert cursor > 0 : "invalid cursor " + cursor; // zero-out the bytes - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor, 0L); - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor + 8, 0L); + Platform.putLong(baseObject, baseOffset + cursor, 0L); + Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); if (value == null) { setNullAt(ordinal); // keep the offset for future update - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); + Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); } else { final BigInteger integer = value.toJavaBigDecimal().unscaledValue(); - final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, - PlatformDependent.BIG_INTEGER_MAG_OFFSET); + final int[] mag = (int[]) Platform.getObjectVolatile(integer, + Platform.BIG_INTEGER_MAG_OFFSET); assert(mag.length <= 4); // Write the bytes to the variable length portion. 
- PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, - baseObject, baseOffset + cursor, mag.length * 4); + Platform.copyMemory( + mag, Platform.INT_ARRAY_OFFSET, baseObject, baseOffset + cursor, mag.length * 4); setLong(ordinal, (cursor << 32) | ((long) (((integer.signum() + 1) << 8) + mag.length))); } } @@ -336,43 +336,43 @@ public boolean isNullAt(int ordinal) { @Override public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(ordinal)); + return Platform.getBoolean(baseObject, getFieldOffset(ordinal)); } @Override public byte getByte(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(ordinal)); + return Platform.getByte(baseObject, getFieldOffset(ordinal)); } @Override public short getShort(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(ordinal)); + return Platform.getShort(baseObject, getFieldOffset(ordinal)); } @Override public int getInt(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(ordinal)); + return Platform.getInt(baseObject, getFieldOffset(ordinal)); } @Override public long getLong(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(ordinal)); + return Platform.getLong(baseObject, getFieldOffset(ordinal)); } @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); + return Platform.getFloat(baseObject, getFieldOffset(ordinal)); } @Override public double getDouble(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); + return Platform.getDouble(baseObject, getFieldOffset(ordinal)); } private static byte[] EMPTY = new byte[0]; @@ -391,13 +391,13 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { assert signum >=0 && signum <= 2 : "invalid signum " + signum; int size = (int) (offsetAndSize & 0xff); int[] mag = new int[size]; - PlatformDependent.copyMemory(baseObject, baseOffset + offset, - mag, PlatformDependent.INT_ARRAY_OFFSET, size * 4); + Platform.copyMemory( + baseObject, baseOffset + offset, mag, Platform.INT_ARRAY_OFFSET, size * 4); // create a BigInteger using signum and mag BigInteger v = new BigInteger(0, EMPTY); // create the initial object - PlatformDependent.UNSAFE.putInt(v, PlatformDependent.BIG_INTEGER_SIGNUM_OFFSET, signum - 1); - PlatformDependent.UNSAFE.putObjectVolatile(v, PlatformDependent.BIG_INTEGER_MAG_OFFSET, mag); + Platform.putInt(v, Platform.BIG_INTEGER_SIGNUM_OFFSET, signum - 1); + Platform.putObjectVolatile(v, Platform.BIG_INTEGER_MAG_OFFSET, mag); return Decimal.apply(new BigDecimal(v, scale), precision, scale); } } @@ -420,11 +420,11 @@ public byte[] getBinary(int ordinal) { final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, baseOffset + offset, bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, size ); return bytes; @@ -438,9 +438,8 @@ public CalendarInterval getInterval(int ordinal) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int months = (int) 
PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); - final long microseconds = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + final int months = (int) Platform.getLong(baseObject, baseOffset + offset); + final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); } } @@ -491,14 +490,14 @@ public MapData getMap(int ordinal) { public UnsafeRow copy() { UnsafeRow rowCopy = new UnsafeRow(); final byte[] rowDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, baseOffset, rowDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, sizeInBytes ); - rowCopy.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); return rowCopy; } @@ -518,18 +517,13 @@ public static UnsafeRow createFromByteArray(int numBytes, int numFields) { */ public void copyFrom(UnsafeRow row) { // copyFrom is only available for UnsafeRow created from byte array. - assert (baseObject instanceof byte[]) && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET; + assert (baseObject instanceof byte[]) && baseOffset == Platform.BYTE_ARRAY_OFFSET; if (row.sizeInBytes > this.sizeInBytes) { // resize the underlying byte[] if it's not large enough. this.baseObject = new byte[row.sizeInBytes]; } - PlatformDependent.copyMemory( - row.baseObject, - row.baseOffset, - this.baseObject, - this.baseOffset, - row.sizeInBytes - ); + Platform.copyMemory( + row.baseObject, row.baseOffset, this.baseObject, this.baseOffset, row.sizeInBytes); // update the sizeInBytes. this.sizeInBytes = row.sizeInBytes; } @@ -544,19 +538,15 @@ public void copyFrom(UnsafeRow row) { */ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException { if (baseObject instanceof byte[]) { - int offsetInByteArray = (int) (PlatformDependent.BYTE_ARRAY_OFFSET - baseOffset); + int offsetInByteArray = (int) (Platform.BYTE_ARRAY_OFFSET - baseOffset); out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes); } else { int dataRemaining = sizeInBytes; long rowReadPosition = baseOffset; while (dataRemaining > 0) { int toTransfer = Math.min(writeBuffer.length, dataRemaining); - PlatformDependent.copyMemory( - baseObject, - rowReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + baseObject, rowReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); out.write(writeBuffer, 0, toTransfer); rowReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -584,13 +574,12 @@ public boolean equals(Object other) { * Returns the underlying bytes for this UnsafeRow. 
*/ public byte[] getBytes() { - if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET + if (baseObject instanceof byte[] && baseOffset == Platform.BYTE_ARRAY_OFFSET && (((byte[]) baseObject).length == sizeInBytes)) { return (byte[]) baseObject; } else { byte[] bytes = new byte[sizeInBytes]; - PlatformDependent.copyMemory(baseObject, baseOffset, bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes); + Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return bytes; } } @@ -600,8 +589,7 @@ public byte[] getBytes() { public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { - build.append(java.lang.Long.toHexString( - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i))); + build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i))); build.append(','); } build.append(']'); @@ -619,12 +607,6 @@ public boolean anyNull() { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - baseObject, - baseOffset, - target, - targetOffset, - sizeInBytes - ); + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 28e7ec0a0f120..005351f0883e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; import org.apache.spark.unsafe.types.CalendarInterval; @@ -58,27 +58,27 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input final Object base = target.getBaseObject(); final long offset = target.getBaseOffset() + cursor; // zero-out the bytes - PlatformDependent.UNSAFE.putLong(base, offset, 0L); - PlatformDependent.UNSAFE.putLong(base, offset + 8, 0L); + Platform.putLong(base, offset, 0L); + Platform.putLong(base, offset + 8, 0L); if (input == null) { target.setNullAt(ordinal); // keep the offset and length for update int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8; - PlatformDependent.UNSAFE.putLong(base, target.getBaseOffset() + fieldOffset, + Platform.putLong(base, target.getBaseOffset() + fieldOffset, ((long) cursor) << 32); return SIZE; } final BigInteger integer = input.toJavaBigDecimal().unscaledValue(); int signum = integer.signum() + 1; - final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, - PlatformDependent.BIG_INTEGER_MAG_OFFSET); + final int[] mag = (int[]) Platform.getObjectVolatile( + integer, Platform.BIG_INTEGER_MAG_OFFSET); assert(mag.length <= 4); // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, - base, target.getBaseOffset() + cursor, mag.length * 4); + Platform.copyMemory( + mag, Platform.INT_ARRAY_OFFSET, base, target.getBaseOffset() + cursor, mag.length * 4); // Set the fixed length portion. 
target.setLong(ordinal, (((long) cursor) << 32) | ((long) ((signum << 8) + mag.length))); @@ -99,8 +99,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String in // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -125,8 +124,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -167,8 +165,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow i // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -191,8 +188,8 @@ public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInter final long offset = target.getBaseOffset() + cursor; // Write the months and microseconds fields of Interval to the variable length portion. - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months); - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds); + Platform.putLong(target.getBaseObject(), offset, input.months); + Platform.putLong(target.getBaseObject(), offset + 8, input.microseconds); // Set the fixed length portion. target.setLong(ordinal, ((long) cursor) << 32); @@ -212,12 +209,11 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayDa final long offset = target.getBaseOffset() + cursor; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); + Platform.putInt(target.getBaseObject(), offset, input.numElements()); // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -247,14 +243,13 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); + Platform.putInt(target.getBaseObject(), offset, input.numElements()); // write the numBytes of key array into second 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset + 4, keysNumBytes); + Platform.putInt(target.getBaseObject(), offset + 4, keysNumBytes); // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes of key array to the variable length portion. 
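The hunks above and below apply a single mechanical substitution: every direct use of PlatformDependent.UNSAFE.getX/putX, PlatformDependent.copyMemory, PlatformDependent.throwException, and the PlatformDependent.*_ARRAY_OFFSET constants is rewritten against the static helpers on org.apache.spark.unsafe.Platform, with long call sites re-wrapped onto fewer lines where that helps. A minimal sketch of the resulting call shape, assuming a plain byte[] page as in the test hunks above; the class name PlatformRenameSketch is illustrative only, is not part of the patch, and needs the spark-unsafe module on the classpath:

import org.apache.spark.unsafe.Platform;

public final class PlatformRenameSketch {
  public static void main(String[] args) {
    final byte[] page = new byte[32];

    // Was: PlatformDependent.UNSAFE.putInt(page, PlatformDependent.BYTE_ARRAY_OFFSET, 42);
    Platform.putInt(page, Platform.BYTE_ARRAY_OFFSET, 42);

    // Was: PlatformDependent.UNSAFE.getInt(page, PlatformDependent.BYTE_ARRAY_OFFSET);
    final int value = Platform.getInt(page, Platform.BYTE_ARRAY_OFFSET);

    // copyMemory keeps the (srcBase, srcOffset, dstBase, dstOffset, length) argument order;
    // only the owning class changes and the explicit UNSAFE indirection goes away.
    final byte[] copy = new byte[32];
    Platform.copyMemory(page, Platform.BYTE_ARRAY_OFFSET, copy, Platform.BYTE_ARRAY_OFFSET, page.length);

    System.out.println(value + " / " + Platform.getInt(copy, Platform.BYTE_ARRAY_OFFSET));
  }
}

The array-offset constants (BYTE_ARRAY_OFFSET, INT_ARRAY_OFFSET, LONG_ARRAY_OFFSET) carry over unchanged, which is why the bulk of the diff is a one-for-one identifier replacement plus re-indentation.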
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java index 0e8e405d055de..cd83695fca033 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -36,17 +35,11 @@ public static void writeToMemory( // zero-out the padding bytes // if ((numBytes & 0x07) > 0) { -// PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); +// Platform.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); // } // Write the UnsafeData to the target memory. - PlatformDependent.copyMemory( - inputObject, - inputOffset, - targetObject, - targetOffset, - numBytes - ); + Platform.copyMemory(inputObject, inputOffset, targetObject, targetOffset, numBytes); } public static int getRoundedSize(int size) { @@ -68,16 +61,11 @@ public static int write(Object targetObject, long targetOffset, Decimal input) { assert(numBytes <= 16); // zero-out the bytes - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, 0L); - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, 0L); + Platform.putLong(targetObject, targetOffset, 0L); + Platform.putLong(targetObject, targetOffset + 8, 0L); // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - targetObject, - targetOffset, - numBytes); - + Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes); return 16; } } @@ -111,8 +99,7 @@ public static int write(Object targetObject, long targetOffset, byte[] input) { final int numBytes = input.length; // Write the bytes to the variable length portion. - writeToMemory(input, PlatformDependent.BYTE_ARRAY_OFFSET, - targetObject, targetOffset, numBytes); + writeToMemory(input, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes); return getRoundedSize(numBytes); } @@ -144,11 +131,9 @@ public static int getSize(UnsafeRow input) { } public static int write(Object targetObject, long targetOffset, CalendarInterval input) { - // Write the months and microseconds fields of Interval to the variable length portion. - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, input.months); - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, input.microseconds); - + Platform.putLong(targetObject, targetOffset, input.months); + Platform.putLong(targetObject, targetOffset + 8, input.microseconds); return 16; } } @@ -165,11 +150,11 @@ public static int write(Object targetObject, long targetOffset, UnsafeArrayData final int numBytes = input.getSizeInBytes(); // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); + Platform.putInt(targetObject, targetOffset, input.numElements()); // Write the bytes to the variable length portion. 
- writeToMemory(input.getBaseObject(), input.getBaseOffset(), - targetObject, targetOffset + 4, numBytes); + writeToMemory( + input.getBaseObject(), input.getBaseOffset(), targetObject, targetOffset + 4, numBytes); return getRoundedSize(numBytes + 4); } @@ -190,9 +175,9 @@ public static int write(Object targetObject, long targetOffset, UnsafeMapData in final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); + Platform.putInt(targetObject, targetOffset, input.numElements()); // write the numBytes of key array into second 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset + 4, keysNumBytes); + Platform.putInt(targetObject, targetOffset + 4, keysNumBytes); // Write the bytes of key array to the variable length portion. writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(), diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index a5ae2b9736527..1d27182912c8a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; @@ -157,7 +157,7 @@ public UnsafeRow next() { cleanupResources(); // Scala iterators don't declare any checked exceptions, so we need to use this hack // to re-throw the exception: - PlatformDependent.throwException(e); + Platform.throwException(e); } throw new RuntimeException("Exception should have been re-thrown in next()"); }; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c21f4d626a74e..bf96248feaef7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -28,7 +28,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ @@ -371,7 +371,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( - classOf[PlatformDependent].getName, + classOf[Platform].getName, classOf[InternalRow].getName, classOf[UnsafeRow].getName, classOf[UTF8String].getName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 29f6a7b981752..b2fb913850794 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -145,12 +145,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro if ($buffer.length < $numBytes) { // This will not happen frequently, because the buffer is re-used. byte[] $tmp = new byte[$numBytes * 2]; - PlatformDependent.copyMemory($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, - $tmp, PlatformDependent.BYTE_ARRAY_OFFSET, $buffer.length); + Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, + $tmp, Platform.BYTE_ARRAY_OFFSET, $buffer.length); $buffer = $tmp; } - $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, - ${inputTypes.length}, $numBytes); + $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes); """ } else { "" @@ -183,7 +182,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = s""" $cursor = $fixedSize; - $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); + $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); ${ctx.splitExpressions(row, convertedFields)} """ GeneratedExpressionCode(code, "false", output) @@ -267,17 +266,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Should we do word align? val elementSize = elementType.defaultSize s""" - PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( + Platform.put${ctx.primitiveTypeName(elementType)}( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}); $cursor += $elementSize; """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" - PlatformDependent.UNSAFE.putLong( + Platform.putLong( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}.toUnscaledLong()); $cursor += 8; """ @@ -286,7 +285,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $cursor += $writer.write( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, $elements[$index]); """ } @@ -320,23 +319,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($checkNull) { // If element is null, write the negative value address into offset region. 
- PlatformDependent.UNSAFE.putInt( - $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, - -$cursor); + Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { - PlatformDependent.UNSAFE.putInt( - $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, - $cursor); - + Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement } } $output.pointTo( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, $numElements, $numBytes); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 8aaa5b4300044..da91ff29537b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform abstract class UnsafeRowJoiner { @@ -52,9 +52,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { - val offset = PlatformDependent.BYTE_ARRAY_OFFSET - val getLong = "PlatformDependent.UNSAFE.getLong" - val putLong = "PlatformDependent.UNSAFE.putLong" + val offset = Platform.BYTE_ARRAY_OFFSET + val getLong = "Platform.getLong" + val putLong = "Platform.putLong" val bitset1Words = (schema1.size + 63) / 64 val bitset2Words = (schema2.size + 63) / 64 @@ -96,7 +96,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U var cursor = offset + outputBitsetWords * 8 val copyFixedLengthRow1 = s""" |// Copy fixed length data for row1 - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj1, offset1 + ${bitset1Words * 8}, | buf, $cursor, | ${schema1.size * 8}); @@ -106,7 +106,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy fixed length portion from row 2 ----------------------- // val copyFixedLengthRow2 = s""" |// Copy fixed length data for row2 - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj2, offset2 + ${bitset2Words * 8}, | buf, $cursor, | ${schema2.size * 8}); @@ -118,7 +118,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val copyVariableLengthRow1 = s""" |// Copy variable length data for row1 |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1; - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj1, offset1 + ${(bitset1Words + schema1.size) * 8}, | buf, $cursor, | numBytesVariableRow1); @@ -129,7 +129,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val copyVariableLengthRow2 = s""" |// Copy variable length data for row2 |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2; - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj2, offset2 + ${(bitset2Words + schema2.size) * 8}, | buf, $cursor + numBytesVariableRow1, | numBytesVariableRow2); diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 76666bd6b3d27..134f1aa2af9a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -1013,7 +1013,7 @@ case class Decode(bin: Expression, charset: Expression) try { ${ev.primitive} = UTF8String.fromString(new String($bytes, $charset.toString())); } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.PlatformDependent.throwException(e); + org.apache.spark.unsafe.Platform.throwException(e); } """) } @@ -1043,7 +1043,7 @@ case class Encode(value: Expression, charset: Expression) try { ${ev.primitive} = $string.toString().getBytes($charset.toString()); } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.PlatformDependent.throwException(e); + org.apache.spark.unsafe.Platform.throwException(e); }""") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index aff1bee99faad..796d60032e1a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform /** * A test suite for the bitset portion of the row concatenation. @@ -96,7 +96,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { // This way we can test the joiner when the input UnsafeRows are not the entire arrays. 
val offset = numFields * 8 val buf = new Array[Byte](sizeInBytes + offset) - row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) + row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) row } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 00218f213054b..5cce41d5a7569 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -138,7 +138,7 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRo unsafeGroupingKeyRow.getBaseOffset(), unsafeGroupingKeyRow.getSizeInBytes(), emptyAggregationBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, emptyAggregationBuffer.length ); if (!putSucceeded) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 69d6784713a24..7db6b7ff50f22 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -225,7 +225,7 @@ public boolean next() throws IOException { int recordLen = underlying.getRecordLength(); // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) - int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); + int keyLen = Platform.getInt(baseObj, recordOffset); int valueLen = recordLen - keyLen - 4; key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 6c7e5cacc99e7..3860c4bba9a99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform /** * Serializer for serializing [[UnsafeRow]]s during shuffle. 
Since UnsafeRows are already stored as @@ -116,7 +116,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream val _rowTuple = rowTuple @@ -150,7 +150,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 953abf409f220..63d35d0f02622 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.Utils @@ -218,8 +218,8 @@ private[joins] final class UnsafeHashedRelation( var offset = loc.getValueAddress.getBaseOffset val last = loc.getValueAddress.getBaseOffset + loc.getValueLength while (offset < last) { - val numFields = PlatformDependent.UNSAFE.getInt(base, offset) - val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4) + val numFields = Platform.getInt(base, offset) + val sizeInBytes = Platform.getInt(base, offset + 4) offset += 8 val row = new UnsafeRow @@ -314,10 +314,11 @@ private[joins] final class UnsafeHashedRelation( in.readFully(valuesBuffer, 0, valuesSize) // put it into binary map - val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize) + val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize) assert(!loc.isDefined, "Duplicated key found!") - val putSuceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, - valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize) + val putSuceeded = loc.putNewKey( + keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize, + valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize) if (!putSuceeded) { throw new IOException("Could not allocate memory to grow BytesToBytesMap") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 89bad1bfdab0a..219435dff5bc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import 
org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -51,7 +51,7 @@ class UnsafeRowSuite extends SparkFunSuite { val bytesFromOffheapRow: Array[Byte] = { val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) try { - PlatformDependent.copyMemory( + Platform.copyMemory( arrayBackedUnsafeRow.getBaseObject, arrayBackedUnsafeRow.getBaseOffset, offheapRowPage.getBaseObject, diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java similarity index 55% rename from unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java rename to unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index b2de2a2590f05..18343efdc3437 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -22,103 +22,111 @@ import sun.misc.Unsafe; -public final class PlatformDependent { +public final class Platform { - /** - * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of - * this package. This also lets us avoid accidental use of deprecated methods. - */ - public static final class UNSAFE { - - private UNSAFE() { } + private static final Unsafe _UNSAFE; - public static int getInt(Object object, long offset) { - return _UNSAFE.getInt(object, offset); - } + public static final int BYTE_ARRAY_OFFSET; - public static void putInt(Object object, long offset, int value) { - _UNSAFE.putInt(object, offset, value); - } + public static final int INT_ARRAY_OFFSET; - public static boolean getBoolean(Object object, long offset) { - return _UNSAFE.getBoolean(object, offset); - } + public static final int LONG_ARRAY_OFFSET; - public static void putBoolean(Object object, long offset, boolean value) { - _UNSAFE.putBoolean(object, offset, value); - } + public static final int DOUBLE_ARRAY_OFFSET; - public static byte getByte(Object object, long offset) { - return _UNSAFE.getByte(object, offset); - } + // Support for resetting final fields while deserializing + public static final long BIG_INTEGER_SIGNUM_OFFSET; + public static final long BIG_INTEGER_MAG_OFFSET; - public static void putByte(Object object, long offset, byte value) { - _UNSAFE.putByte(object, offset, value); - } + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } - public static short getShort(Object object, long offset) { - return _UNSAFE.getShort(object, offset); - } + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } - public static void putShort(Object object, long offset, short value) { - _UNSAFE.putShort(object, offset, value); - } + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } - public static long getLong(Object object, long offset) { - return _UNSAFE.getLong(object, offset); - } + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } - public static void putLong(Object object, long offset, long value) { - _UNSAFE.putLong(object, offset, value); - } + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } - public static float getFloat(Object 
object, long offset) { - return _UNSAFE.getFloat(object, offset); - } + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } - public static void putFloat(Object object, long offset, float value) { - _UNSAFE.putFloat(object, offset, value); - } + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } - public static double getDouble(Object object, long offset) { - return _UNSAFE.getDouble(object, offset); - } + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } - public static void putDouble(Object object, long offset, double value) { - _UNSAFE.putDouble(object, offset, value); - } + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } - public static Object getObjectVolatile(Object object, long offset) { - return _UNSAFE.getObjectVolatile(object, offset); - } + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } - public static void putObjectVolatile(Object object, long offset, Object value) { - _UNSAFE.putObjectVolatile(object, offset, value); - } + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } - public static long allocateMemory(long size) { - return _UNSAFE.allocateMemory(size); - } + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } - public static void freeMemory(long address) { - _UNSAFE.freeMemory(address); - } + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); } - private static final Unsafe _UNSAFE; + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } - public static final int BYTE_ARRAY_OFFSET; + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } - public static final int INT_ARRAY_OFFSET; + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } - public static final int LONG_ARRAY_OFFSET; + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } - public static final int DOUBLE_ARRAY_OFFSET; + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } - // Support for resetting final fields while deserializing - public static final long BIG_INTEGER_SIGNUM_OFFSET; - public static final long BIG_INTEGER_MAG_OFFSET; + /** + * Raises an exception bypassing compiler checks for checked exceptions. 
+ */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } /** * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to @@ -162,26 +170,4 @@ public static void freeMemory(long address) { BIG_INTEGER_MAG_OFFSET = 0; } } - - static public void copyMemory( - Object src, - long srcOffset, - Object dst, - long dstOffset, - long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; - } - } - - /** - * Raises an exception bypassing compiler checks for checked exceptions. - */ - public static void throwException(Throwable t) { - _UNSAFE.throwException(t); - } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 70b81ce015ddc..cf42877bf9fd4 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import static org.apache.spark.unsafe.PlatformDependent.*; +import org.apache.spark.unsafe.Platform; public class ByteArrayMethods { @@ -45,20 +45,18 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, - long leftOffset, - Object rightBase, - long rightOffset, - final long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; while (i <= length - 8) { - if (UNSAFE.getLong(leftBase, leftOffset + i) != UNSAFE.getLong(rightBase, rightOffset + i)) { + if (Platform.getLong(leftBase, leftOffset + i) != + Platform.getLong(rightBase, rightOffset + i)) { return false; } i += 8; } while (i < length) { - if (UNSAFE.getByte(leftBase, leftOffset + i) != UNSAFE.getByte(rightBase, rightOffset + i)) { + if (Platform.getByte(leftBase, leftOffset + i) != + Platform.getByte(rightBase, rightOffset + i)) { return false; } i += 1; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 18d1f0d2d7eb2..74105050e4191 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -64,7 +64,7 @@ public long size() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value); + Platform.putLong(baseObj, baseOffset + index * WIDTH, value); } /** @@ -73,6 +73,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH); + return Platform.getLong(baseObj, baseOffset + index * WIDTH); } } diff --git 
a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java index 27462c7fa5e62..7857bf66a72ad 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.bitset; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Methods for working with fixed-size uncompressed bitsets. @@ -41,8 +41,8 @@ public static void set(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); - PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask); + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word | mask); } /** @@ -52,8 +52,8 @@ public static void unset(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); - PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask); + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word & ~mask); } /** @@ -63,7 +63,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + final long word = Platform.getLong(baseObject, wordOffset); return (word & mask) != 0; } @@ -73,7 +73,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) { long addr = baseOffset; for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) { - if (PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) { + if (Platform.getLong(baseObject, addr) != 0) { return true; } } @@ -109,8 +109,7 @@ public static int nextSetBit( // Try to find the next set bit in the current word final int subIndex = fromIndex & 0x3f; - long word = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; + long word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; if (word != 0) { return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word); } @@ -118,7 +117,7 @@ public static int nextSetBit( // Find the next set bit in the rest of the words wi += 1; while (wi < bitsetSizeInWords) { - word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE); + word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE); if (word != 0) { return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 61f483ced3217..4276f25c2165b 100644 --- 
a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -53,7 +53,7 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; int h1 = seed; for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i); + int halfWord = Platform.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index 91be46ba21ff8..dd75820834370 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -19,7 +19,7 @@ import javax.annotation.Nullable; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. @@ -50,6 +50,6 @@ public long size() { * Creates a memory block pointing to the memory used by the long array. */ public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8); + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 62f4459696c28..cda7826c8c99b 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.memory; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. 
@@ -29,7 +29,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { if (size % 8 != 0) { throw new IllegalArgumentException("Size " + size + " was not a multiple of 8"); } - long address = PlatformDependent.UNSAFE.allocateMemory(size); + long address = Platform.allocateMemory(size); return new MemoryBlock(null, address, size); } @@ -37,6 +37,6 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.obj == null) : "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - PlatformDependent.UNSAFE.freeMemory(memory.offset); + Platform.freeMemory(memory.offset); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 69b0e206cef18..c08c9c73d2396 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; public class ByteArray { @@ -27,12 +27,6 @@ public class ByteArray { * hold all the bytes in this string. */ public static void writeToMemory(byte[] src, Object target, long targetOffset) { - PlatformDependent.copyMemory( - src, - PlatformDependent.BYTE_ARRAY_OFFSET, - target, - targetOffset, - src.length - ); + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET, target, targetOffset, src.length); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index d1014426c0f49..667c00900f2c5 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -24,10 +24,10 @@ import java.util.Arrays; import java.util.Map; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import static org.apache.spark.unsafe.PlatformDependent.*; +import static org.apache.spark.unsafe.Platform.*; /** @@ -133,13 +133,7 @@ protected UTF8String(Object base, long offset, int numBytes) { * bytes in this string. 
*/ public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - base, - offset, - target, - targetOffset, - numBytes - ); + Platform.copyMemory(base, offset, target, targetOffset, numBytes); } /** @@ -183,12 +177,12 @@ public long getPrefix() { long mask = 0; if (isLittleEndian) { if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + p = (long) Platform.getInt(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -197,12 +191,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) PlatformDependent.UNSAFE.getInt(base, offset)) << 32; + p = ((long) Platform.getInt(base, offset)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -293,7 +287,7 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. */ private byte getByte(int i) { - return UNSAFE.getByte(base, offset + i); + return Platform.getByte(base, offset + i); } private boolean matchAt(final UTF8String s, int pos) { @@ -769,7 +763,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { int len = inputs[i].numBytes; copyMemory( inputs[i].base, inputs[i].offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -778,7 +772,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { if (j < numInputs) { copyMemory( separator.base, separator.offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 3b9175835229c..2f8cb132ac8b4 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -22,7 +22,7 @@ import java.util.Set; import junit.framework.Assert; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.junit.Test; /** @@ -83,11 +83,11 @@ public void randomizedStressTestBytes() { rand.nextBytes(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. 
@@ -106,11 +106,11 @@ public void randomizedStressTestPaddedStrings() { System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. From dfe347d2cae3eb05d7539aaf72db3d309e711213 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 11 Aug 2015 08:52:15 -0700 Subject: [PATCH 262/340] [SPARK-9785] [SQL] HashPartitioning compatibility should consider expression ordering HashPartitioning compatibility is currently defined w.r.t the _set_ of expressions, but the ordering of those expressions matters when computing hash codes; this could lead to incorrect answers if we mistakenly avoided a shuffle based on the assumption that HashPartitionings with the same expressions in different orders will produce equivalent row hashcodes. The first commit adds a regression test which illustrates this problem. The fix for this is simple: make `HashPartitioning.compatibleWith` and `HashPartitioning.guarantees` sensitive to the expression ordering (i.e. do not perform set comparison). Author: Josh Rosen Closes #8074 from JoshRosen/hashpartitioning-compatiblewith-fixes and squashes the following commits: b61412f [Josh Rosen] Demonstrate that I haven't cheated in my fix 0b4d7d9 [Josh Rosen] Update so that clusteringSet is only used in satisfies(). 
dc9c9d7 [Josh Rosen] Add failing regression test for SPARK-9785 --- .../plans/physical/partitioning.scala | 15 ++--- .../sql/catalyst/PartitioningSuite.scala | 55 +++++++++++++++++++ 2 files changed, 60 insertions(+), 10 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 5a89a90b735a6..5ac3f1f5b0cac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -216,26 +216,23 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - lazy val clusteringSet = expressions.toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + expressions.toSet.subsetOf(requiredClustering.toSet) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this == o case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this == o case _ => false } + } /** @@ -257,15 +254,13 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = ordering.map(_.child).toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala new file mode 100644 index 0000000000000..5b802ccc637dd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} + +class PartitioningSuite extends SparkFunSuite { + test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { + val expressions = Seq(Literal(2), Literal(3)) + // Consider two HashPartitionings that have the same _set_ of hash expressions but which are + // created with different orderings of those expressions: + val partitioningA = HashPartitioning(expressions, 100) + val partitioningB = HashPartitioning(expressions.reverse, 100) + // These partitionings are not considered equal: + assert(partitioningA != partitioningB) + // However, they both satisfy the same clustered distribution: + val distribution = ClusteredDistribution(expressions) + assert(partitioningA.satisfies(distribution)) + assert(partitioningB.satisfies(distribution)) + // These partitionings compute different hashcodes for the same input row: + def computeHashCode(partitioning: HashPartitioning): Int = { + val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) + hashExprProj.apply(InternalRow.empty).hashCode() + } + assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) + // Thus, these partitionings are incompatible: + assert(!partitioningA.compatibleWith(partitioningB)) + assert(!partitioningB.compatibleWith(partitioningA)) + assert(!partitioningA.guarantees(partitioningB)) + assert(!partitioningB.guarantees(partitioningA)) + + // Just to be sure that we haven't cheated by having these methods always return false, + // check that identical partitionings are still compatible with and guarantee each other: + assert(partitioningA === partitioningA) + assert(partitioningA.guarantees(partitioningA)) + assert(partitioningA.compatibleWith(partitioningA)) + } +} From bce72797f3499f14455722600b0d0898d4fd87c9 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 11 Aug 2015 10:42:17 -0700 Subject: [PATCH 263/340] Fix comment error API is updated but its doc comment is not updated. Author: Jeff Zhang Closes #8097 from zjffdu/dev. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9ced44131b0d9..6aafb4c5644d7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -866,7 +866,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * Do - * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`, * * then `rdd` contains * {{{ From 8cad854ef6a2066de5adffcca6b79a205ccfd5f3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 11 Aug 2015 11:01:59 -0700 Subject: [PATCH 264/340] [SPARK-8345] [ML] Add an SQL node as a feature transformer Implements the transforms which are defined by SQL statement. Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' where '__THIS__' represents the underlying table of the input dataset. 
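For illustration, a minimal usage sketch in the style of the new test suite (the input DataFrame `df` and its columns v1/v2 are assumed for this example and are not part of the patch):

    import org.apache.spark.ml.feature.SQLTransformer

    // df is assumed to be an existing DataFrame with numeric columns v1 and v2.
    val sqlTrans = new SQLTransformer()
      .setStatement("SELECT *, (v1 + v2) AS v3 FROM __THIS__")
    // '__THIS__' is replaced with a temporary table registered for the input dataset.
    val transformed = sqlTrans.transform(df)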
Author: Yanbo Liang Closes #7465 from yanboliang/spark-8345 and squashes the following commits: b403fcb [Yanbo Liang] address comments 0d4bb15 [Yanbo Liang] a better transformSchema() implementation 51eb9e7 [Yanbo Liang] Add an SQL node as a feature transformer --- .../spark/ml/feature/SQLTransformer.scala | 72 +++++++++++++++++++ .../ml/feature/SQLTransformerSuite.scala | 44 ++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala new file mode 100644 index 0000000000000..95e4305638730 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * Implements the transforms which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + * where '__THIS__' represents the underlying table of the input dataset. + */ +@Experimental +class SQLTransformer (override val uid: String) extends Transformer { + + def this() = this(Identifiable.randomUID("sql")) + + /** + * SQL statement parameter. The statement is provided in string form. 
+ * @group param + */ + final val statement: Param[String] = new Param[String](this, "statement", "SQL statement") + + /** @group setParam */ + def setStatement(value: String): this.type = set(statement, value) + + /** @group getParam */ + def getStatement: String = $(statement) + + private val tableIdentifier: String = "__THIS__" + + override def transform(dataset: DataFrame): DataFrame = { + val tableName = Identifiable.randomUID(uid) + dataset.registerTempTable(tableName) + val realStatement = $(statement).replace(tableIdentifier, tableName) + val outputDF = dataset.sqlContext.sql(realStatement) + outputDF + } + + override def transformSchema(schema: StructType): StructType = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + val dummyRDD = sc.parallelize(Seq(Row.empty)) + val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) + dummyDF.registerTempTable(tableIdentifier) + val outputSchema = sqlContext.sql($(statement)).schema + outputSchema + } + + override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala new file mode 100644 index 0000000000000..d19052881ae45 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new SQLTransformer()) + } + + test("transform numeric data") { + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + val result = sqlTrans.transform(original) + val resultSchema = sqlTrans.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) + .toDF("id", "v1", "v2", "v3", "v4") + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } +} From dbd778d84d094ca142bc08c351478595b280bc2a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 11 Aug 2015 11:33:36 -0700 Subject: [PATCH 265/340] [SPARK-8764] [ML] string indexer should take option to handle unseen values As a precursor to adding a public constructor add an option to handle unseen values by skipping rather than throwing an exception (default remains throwing an exception), Author: Holden Karau Closes #7266 from holdenk/SPARK-8764-string-indexer-should-take-option-to-handle-unseen-values and squashes the following commits: 38a4de9 [Holden Karau] fix long line 045bf22 [Holden Karau] Add a second b entry so b gets 0 for sure 81dd312 [Holden Karau] Update the docs for handleInvalid param to be more descriptive 7f37f6e [Holden Karau] remove extra space (scala style) 414e249 [Holden Karau] And switch to using handleInvalid instead of skipInvalid 1e53f9b [Holden Karau] update the param (codegen side) 7a22215 [Holden Karau] fix typo 100a39b [Holden Karau] Merge in master aa5b093 [Holden Karau] Since we filter we should never go down this code path if getSkipInvalid is true 75ffa69 [Holden Karau] Remove extra newline d69ef5e [Holden Karau] Add a test b5734be [Holden Karau] Add support for unseen labels afecd4e [Holden Karau] Add a param to skip invalid entries. --- .../spark/ml/feature/StringIndexer.scala | 26 ++++++++++++--- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +++ .../spark/ml/param/shared/sharedParams.scala | 15 +++++++++ .../spark/ml/feature/StringIndexerSuite.scala | 32 +++++++++++++++++++ 4 files changed, 73 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index ebfa972532358..e4485eb038409 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -33,7 +33,8 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol + with HasHandleInvalid { /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -65,13 +66,16 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def this() = this(Identifiable.randomUID("strIdx")) + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - // TODO: handle unseen labels override def fit(dataset: DataFrame): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) @@ -111,6 +115,10 @@ class StringIndexerModel private[ml] ( map } + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -128,14 +136,24 @@ class StringIndexerModel private[ml] ( if (labelToIndex.contains(label)) { labelToIndex(label) } else { - // TODO: handle unseen labels throw new SparkException(s"Unseen label: $label.") } } + val outputColName = $(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), + // If we are skipping invalid records, filter them out. + val filteredDataset = (getHandleInvalid) match { + case "skip" => { + val filterer = udf { label: String => + labelToIndex.contains(label) + } + dataset.where(filterer(dataset($(inputCol)))) + } + case _ => dataset + } + filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a97c8059b8d45..da4c076830391 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -59,6 +59,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + + "will filter out rows with bad values), or error (which will throw an errror). More " + + "options may be added later.", + isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " before fitting the model.", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index f332630c32f1b..23e2b6cc43996 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -247,6 +247,21 @@ private[ml] trait HasFitIntercept extends Params { final def getFitIntercept: Boolean = $(fitIntercept) } +/** + * Trait for shared param handleInvalid. 
+ */ +private[ml] trait HasHandleInvalid extends Params { + + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. + * @group param + */ + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error"))) + + /** @group getParam */ + final def getHandleInvalid: String = $(handleInvalid) +} + /** * Trait for shared param standardization (default: true). */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index d0295a0fe2fc1..b111036087e6a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -62,6 +63,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } + test("StringIndexerUnseen") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) + val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + // Verify we throw by default with unseen values + intercept[SparkException] { + indexer.transform(df2).collect() + } + val indexerSkipInvalid = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .setHandleInvalid("skip") + .fit(df) + // Verify that we skip the c record + val transformed = indexerSkipInvalid.transform(df2) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("b", "a")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expected = Set((0, 1.0), (1, 0.0)) + assert(output === expected) + } + test("StringIndexer with a numeric input column") { val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) val df = sqlContext.createDataFrame(data).toDF("id", "label") From 5b8bb1b213b8738f563fcd00747604410fbb3087 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Aug 2015 12:02:28 -0700 Subject: [PATCH 266/340] [SPARK-9572] [STREAMING] [PYSPARK] Added StreamingContext.getActiveOrCreate() in Python Author: Tathagata Das Closes #8080 from tdas/SPARK-9572 and squashes the following commits: 64a231d [Tathagata Das] Fix based on comments 741a0d0 [Tathagata Das] Fixed style f4f094c [Tathagata Das] Tweaked test 9afcdbe [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9572 e21488d [Tathagata Das] Minor update 1a371d9 [Tathagata Das] Addressed comments. 
60479da [Tathagata Das] Fixed indent 9c2da9c [Tathagata Das] Fixed bugs b5bd32c [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9572 b55b348 [Tathagata Das] Removed prints 5781728 [Tathagata Das] Fix style issues b711214 [Tathagata Das] Reverted run-tests.py 643b59d [Tathagata Das] Revert unnecessary change 150e58c [Tathagata Das] Added StreamingContext.getActiveOrCreate() in Python --- python/pyspark/streaming/context.py | 57 +++++++++++- python/pyspark/streaming/tests.py | 133 +++++++++++++++++++++++++--- python/run-tests.py | 2 +- 3 files changed, 177 insertions(+), 15 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index ac5ba69e8dbbb..e3ba70e4e5e88 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -86,6 +86,9 @@ class StreamingContext(object): """ _transformerSerializer = None + # Reference to a currently active StreamingContext + _activeContext = None + def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. @@ -142,10 +145,10 @@ def getOrCreate(cls, checkpointPath, setupFunc): Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be recreated from the checkpoint data. If the data does not exist, then the provided setupFunc - will be used to create a JavaStreamingContext. + will be used to create a new context. - @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program - @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + @param checkpointPath: Checkpoint directory used in an earlier streaming program + @param setupFunc: Function to create a new context and setup DStreams """ # TODO: support checkpoint in HDFS if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): @@ -170,6 +173,52 @@ def getOrCreate(cls, checkpointPath, setupFunc): cls._transformerSerializer.ctx = sc return StreamingContext(sc, None, jssc) + @classmethod + def getActive(cls): + """ + Return either the currently active StreamingContext (i.e., if there is a context started + but not stopped) or None. + """ + activePythonContext = cls._activeContext + if activePythonContext is not None: + # Verify that the current running Java StreamingContext is active and is the same one + # backing the supposedly active Python context + activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode() + activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive() + + if activeJvmContextOption.isEmpty(): + cls._activeContext = None + elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId: + cls._activeContext = None + raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext " + "backing the action Python StreamingContext. This is unexpected.") + return cls._activeContext + + @classmethod + def getActiveOrCreate(cls, checkpointPath, setupFunc): + """ + Either return the active StreamingContext (i.e. currently started but not stopped), + or recreate a StreamingContext from checkpoint data or create a new StreamingContext + using the provided setupFunc function. If the checkpointPath is None or does not contain + valid checkpoint data, then setupFunc will be called to create a new context and setup + DStreams. + + @param checkpointPath: Checkpoint directory used in an earlier streaming program. 
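As a hedged illustration of the lookup order this method implements (return the currently active context, else recover from checkpoint data, else call setupFunc), a caller might write something like the following; the checkpoint path, batch interval and the sc variable are illustrative, not part of the patch:

    from pyspark.streaming import StreamingContext

    def setup():
        ssc = StreamingContext(sc, 1)   # sc: an existing SparkContext (assumed)
        # ... define DStreams and call ssc.checkpoint(...) here ...
        return ssc

    ssc = StreamingContext.getActiveOrCreate("/tmp/streaming-checkpoint", setup)
    ssc.start()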
Can be + None if the intention is to always create a new context when there + is no active context. + @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + """ + + if setupFunc is None: + raise Exception("setupFunc cannot be None") + activeContext = cls.getActive() + if activeContext is not None: + return activeContext + elif checkpointPath is not None: + return cls.getOrCreate(checkpointPath, setupFunc) + else: + return setupFunc() + @property def sparkContext(self): """ @@ -182,6 +231,7 @@ def start(self): Start the execution of the streams. """ self._jssc.start() + StreamingContext._activeContext = self def awaitTermination(self, timeout=None): """ @@ -212,6 +262,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False): of all received data to be completed """ self._jssc.stop(stopSparkContext, stopGraceFully) + StreamingContext._activeContext = None if stopSparkContext: self._sc.stop() diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index f0ed415f97120..6108c845c1efe 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -24,6 +24,7 @@ import tempfile import random import struct +import shutil from functools import reduce if sys.version_info[:2] <= (2, 6): @@ -59,12 +60,21 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): cls.sc.stop() + # Clean up in the JVM just in case there has been some issues in Python API + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): - self.ssc.stop(False) + if self.ssc is not None: + self.ssc.stop(False) + # Clean up in the JVM just in case there has been some issues in Python API + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop(False) def wait_for(self, result, n): start_time = time.time() @@ -442,6 +452,7 @@ def test_reduce_by_invalid_window(self): class StreamingContextTests(PySparkStreamingTestCase): duration = 0.1 + setupCalled = False def _add_input_stream(self): inputs = [range(1, x) for x in range(101)] @@ -515,10 +526,85 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_get_active(self): + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that getActive() returns the active context + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + + # Verify that getActive() returns None + self.ssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + def test_get_active_or_create(self): + # Test StreamingContext.getActiveOrCreate() without checkpoint data + # See CheckpointTests for tests with checkpoint data + self.ssc = None + self.assertEqual(StreamingContext.getActive(), None) + + def setupFunc(): + ssc = StreamingContext(self.sc, self.duration) + ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.setupCalled = True + return ssc + + # Verify 
that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that getActiveOrCreate() retuns active context and does not call the setupFunc + self.ssc.start() + self.setupCalled = False + self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setupFunc after active context is stopped + self.ssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + class CheckpointTests(unittest.TestCase): - def test_get_or_create(self): + setupCalled = False + + @staticmethod + def tearDownClass(): + # Clean up in the JVM just in case there has been some issues in Python API + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop() + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + + def tearDown(self): + if self.ssc is not None: + self.ssc.stop(True) + + def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" @@ -533,11 +619,12 @@ def setup(): wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") wc.checkpoint(.5) + self.setupCalled = True return ssc cpd = tempfile.mkdtemp("test_streaming_cps") - ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + self.ssc = StreamingContext.getOrCreate(cpd, setup) + self.ssc.start() def check_output(n): while not os.listdir(outputd): @@ -552,7 +639,7 @@ def check_output(n): # not finished time.sleep(0.01) continue - ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) d = ordd.values().map(int).collect() if not d: time.sleep(0.01) @@ -568,13 +655,37 @@ def check_output(n): check_output(1) check_output(2) - ssc.stop(True, True) + # Verify the getOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) time.sleep(1) - ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() check_output(3) - ssc.stop(True, True) + + # Verify the getActiveOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) + time.sleep(1) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() + check_output(4) + + # Verify that getActiveOrCreate() returns active context + self.setupCalled = False + self.assertEquals(StreamingContext.getActiveOrCreate(cpd, setup), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setup() in absence of checkpoint 
files + self.ssc.stop(True, True) + shutil.rmtree(cpd) # delete checkpoint directory + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(cpd, setup) + self.assertTrue(self.setupCalled) + self.ssc.stop(True, True) class KafkaStreamTests(PySparkStreamingTestCase): @@ -1134,7 +1245,7 @@ def search_kinesis_asl_assembly_jar(): testcases.append(KinesisStreamTests) elif are_kinesis_tests_enabled is False: sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " - "not compiled with -Pkinesis-asl profile. To run these tests, " + "not compiled into a JAR. To run these tests, " "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " "streaming-kinesis-asl-assembly/assembly' or " "'build/mvn -Pkinesis-asl package' before running this test.") @@ -1150,4 +1261,4 @@ def search_kinesis_asl_assembly_jar(): for testcase in testcases: sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) - unittest.TextTestRunner(verbosity=2).run(tests) + unittest.TextTestRunner(verbosity=3).run(tests) diff --git a/python/run-tests.py b/python/run-tests.py index cc560779373b3..fd56c7ab6e0e2 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -158,7 +158,7 @@ def main(): else: log_level = logging.INFO logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") - LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE) + LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE) if os.path.exists(LOG_FILE): os.remove(LOG_FILE) python_execs = opts.python_executables.split(',') From 5831294a7a8fa2524133c5d718cbc8187d2b0620 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 11 Aug 2015 12:39:13 -0700 Subject: [PATCH 267/340] [SPARK-9646] [SQL] Add metrics for all join and aggregate operators This PR added metrics for all join and aggregate operators. However, I found the metrics may be confusing in the following two case: 1. The iterator is not totally consumed and the metric values will be less. 2. Recreating the iterators will make metric values look bigger than the size of the input source, such as `CartesianProduct`. 
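The mechanics are the same for every operator touched below: declare the counters in the operator's metrics map with SQLMetrics.createLongMetric, look them up with longMetric inside doExecute(), and bump them as rows are consumed and produced. A minimal sketch of that shape, assuming a made-up pass-through unary operator (the operator itself is illustrative, not part of this patch):

    package org.apache.spark.sql.execution

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.execution.metric.SQLMetrics

    case class CountedPassThrough(child: SparkPlan) extends UnaryNode {
      override def output: Seq[Attribute] = child.output

      override private[sql] lazy val metrics = Map(
        "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
        "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

      protected override def doExecute(): RDD[InternalRow] = {
        val numInputRows = longMetric("numInputRows")
        val numOutputRows = longMetric("numOutputRows")
        child.execute().mapPartitions { iter =>
          iter.map { row =>
            numInputRows += 1    // counted as the row is pulled from the child
            numOutputRows += 1   // and again as it is handed downstream
            row
          }
        }
      }
    }

Because the increments only happen as downstream operators pull the iterators, the two caveats above fall out directly: an iterator that is never fully consumed under-counts, and an iterator that is recreated (as in CartesianProduct) counts the same source rows more than once.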
Author: zsxwing Closes #8060 from zsxwing/sql-metrics and squashes the following commits: 40f3fc1 [zsxwing] Mark LongSQLMetric private[metric] to avoid using incorrectly and leak memory b1b9071 [zsxwing] Merge branch 'master' into sql-metrics 4bef25a [zsxwing] Add metrics for SortMergeOuterJoin 95ccfc6 [zsxwing] Merge branch 'master' into sql-metrics 67cb4dd [zsxwing] Add metrics for Project and TungstenProject; remove metrics from PhysicalRDD and LocalTableScan 0eb47d4 [zsxwing] Merge branch 'master' into sql-metrics dd9d932 [zsxwing] Avoid creating new Iterators 589ea26 [zsxwing] Add metrics for all join and aggregate operators --- .../spark/sql/execution/Aggregate.scala | 11 + .../spark/sql/execution/ExistingRDD.scala | 2 - .../spark/sql/execution/LocalTableScan.scala | 2 - .../spark/sql/execution/SparkPlan.scala | 25 +- .../aggregate/SortBasedAggregate.scala | 12 +- .../SortBasedAggregationIterator.scala | 18 +- .../aggregate/TungstenAggregate.scala | 12 +- .../TungstenAggregationIterator.scala | 11 +- .../spark/sql/execution/basicOperators.scala | 36 +- .../execution/joins/BroadcastHashJoin.scala | 30 +- .../joins/BroadcastHashOuterJoin.scala | 40 +- .../joins/BroadcastLeftSemiJoinHash.scala | 24 +- .../joins/BroadcastNestedLoopJoin.scala | 27 +- .../execution/joins/CartesianProduct.scala | 25 +- .../spark/sql/execution/joins/HashJoin.scala | 7 +- .../sql/execution/joins/HashOuterJoin.scala | 30 +- .../sql/execution/joins/HashSemiJoin.scala | 23 +- .../sql/execution/joins/HashedRelation.scala | 8 +- .../sql/execution/joins/LeftSemiJoinBNL.scala | 19 +- .../execution/joins/LeftSemiJoinHash.scala | 18 +- .../execution/joins/ShuffledHashJoin.scala | 16 +- .../joins/ShuffledHashOuterJoin.scala | 29 +- .../sql/execution/joins/SortMergeJoin.scala | 21 +- .../execution/joins/SortMergeOuterJoin.scala | 38 +- .../sql/execution/metric/SQLMetrics.scala | 6 + .../execution/joins/HashedRelationSuite.scala | 14 +- .../execution/metric/SQLMetricsSuite.scala | 450 +++++++++++++++++- 27 files changed, 847 insertions(+), 107 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index e8c6a0f8f801d..f3b6a3a5f4a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -45,6 +46,10 @@ case class Aggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: List[Distribution] = { if (partial) { UnspecifiedDistribution :: Nil @@ -121,12 +126,15 @@ case class Aggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") if (groupingExpressions.isEmpty) { child.execute().mapPartitions { iter => val buffer = newAggregateBuffer() var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() + 
numInputRows += 1 var i = 0 while (i < buffer.length) { buffer(i).update(currentRow) @@ -142,6 +150,7 @@ case class Aggregate( i += 1 } + numOutputRows += 1 Iterator(resultProjection(aggregateResults)) } } else { @@ -152,6 +161,7 @@ case class Aggregate( var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() + numInputRows += 1 val currentGroup = groupingProjection(currentRow) var currentBuffer = hashTable.get(currentGroup) if (currentBuffer == null) { @@ -180,6 +190,7 @@ case class Aggregate( val currentEntry = hashTableIter.next() val currentGroup = currentEntry.getKey val currentBuffer = currentEntry.getValue + numOutputRows += 1 var i = 0 while (i < currentBuffer.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index cae7ca5cbdc88..abb60cf12e3a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -99,8 +99,6 @@ private[sql] case class PhysicalRDD( rdd: RDD[InternalRow], extraInformation: String) extends LeafNode { - override protected[sql] val trackNumOfRowsEnabled = true - protected override def doExecute(): RDD[InternalRow] = rdd override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 858dd85fd1fa6..34e926e4582be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -30,8 +30,6 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { - override protected[sql] val trackNumOfRowsEnabled = true - private lazy val rdd = sqlContext.sparkContext.parallelize(rows) protected override def doExecute(): RDD[InternalRow] = rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 9ba5cf2d2b39e..72f5450510a10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -80,23 +80,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ super.makeCopy(newArgs) } - /** - * Whether track the number of rows output by this SparkPlan - */ - protected[sql] def trackNumOfRowsEnabled: Boolean = false - - private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] = - if (trackNumOfRowsEnabled) { - Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - } - else { - Map.empty - } - /** * Return all metrics containing metrics of this SparkPlan. */ - private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics + private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty /** * Return a LongSQLMetric according to the name. 
@@ -150,15 +137,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() - if (trackNumOfRowsEnabled) { - val numRows = longMetric("numRows") - doExecute().map { row => - numRows += 1 - row - } - } else { - doExecute() - } + doExecute() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ad428ad663f30..ab26f9c58aa2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType case class SortBasedAggregate( @@ -38,6 +39,10 @@ case class SortBasedAggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputsUnsafeRows: Boolean = false override def canProcessUnsafeRows: Boolean = false @@ -63,6 +68,8 @@ case class SortBasedAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitions { iter => // Because the constructor of an aggregation iterator will read at least the first row, // we need to get the value of iter.hasNext first. @@ -84,10 +91,13 @@ case class SortBasedAggregate( newProjection _, child.output, iter, - outputsUnsafeRows) + outputsUnsafeRows, + numInputRows, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { // There is no input and there is no grouping expressions. // We need to output a single row as the output. 
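// (Hedged aside, not part of the patch: with no grouping expressions an aggregate
// must still return exactly one row even when its input is empty -- for example
// "SELECT count(*) FROM t" yields a single row holding 0 for an empty t -- which is
// why numOutputRows is bumped exactly once on the next line.)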
+ numOutputRows += 1 Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) } else { outputIter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 67ebafde25ad3..73d50e07cf0b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.unsafe.KVIterator /** @@ -37,7 +38,9 @@ class SortBasedAggregationIterator( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) + outputsUnsafeRows: Boolean, + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric) extends AggregationIterator( groupingKeyAttributes, valueAttributes, @@ -103,6 +106,7 @@ class SortBasedAggregationIterator( // Get the grouping key. val groupingKey = inputKVIterator.getKey val currentRow = inputKVIterator.getValue + numInputRows += 1 // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { @@ -137,7 +141,7 @@ class SortBasedAggregationIterator( val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) // Initialize buffer values for the next group. initializeBuffer(sortBasedAggregationBuffer) - + numOutputRows += 1 outputRow } else { // no more result @@ -151,7 +155,7 @@ class SortBasedAggregationIterator( nextGroupingKey = inputKVIterator.getKey().copy() firstRowInNextGroup = inputKVIterator.getValue().copy() - + numInputRows += 1 sortedInputHasNewGroup = true } else { // This inputIter is empty. 
@@ -181,7 +185,9 @@ object SortBasedAggregationIterator { newProjection: (Seq[Expression], Seq[Attribute]) => Projection, inputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { + outputsUnsafeRows: Boolean, + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric): SortBasedAggregationIterator = { val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) { AggregationIterator.unsafeKVIterator( groupingExprs, @@ -202,7 +208,9 @@ object SortBasedAggregationIterator { initialInputBufferOffset, resultExpressions, newMutableProjection, - outputsUnsafeRows) + outputsUnsafeRows, + numInputRows, + numOutputRows) } // scalastyle:on } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 1694794a53d9f..6b5935a7ce296 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -35,6 +36,10 @@ case class TungstenAggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true @@ -61,6 +66,8 @@ case class TungstenAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitions { iter => val hasInput = iter.hasNext if (!hasInput && groupingExpressions.nonEmpty) { @@ -78,9 +85,12 @@ case class TungstenAggregate( newMutableProjection, child.output, iter, - testFallbackStartsAt) + testFallbackStartsAt, + numInputRows, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { + numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) } else { aggregationIterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 32160906c3bc8..1f383dd04482f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} +import 
org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.types.StructType /** @@ -83,7 +84,9 @@ class TungstenAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], - testFallbackStartsAt: Option[Int]) + testFallbackStartsAt: Option[Int], + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { /////////////////////////////////////////////////////////////////////////// @@ -352,6 +355,7 @@ class TungstenAggregationIterator( private def processInputs(): Unit = { while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { @@ -371,6 +375,7 @@ class TungstenAggregationIterator( var i = 0 while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) val buffer: UnsafeRow = if (i < fallbackStartsAt) { hashMap.getAggregationBufferFromUnsafeRow(groupingKey) @@ -439,6 +444,7 @@ class TungstenAggregationIterator( // Process the rest of input rows. while (inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) buffer.copyFrom(initialAggregationBuffer) processRow(buffer, newInput) @@ -462,6 +468,7 @@ class TungstenAggregationIterator( // Insert the rest of input rows. while (inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) bufferExtractor(newInput) externalSorter.insertKV(groupingKey, buffer) @@ -657,7 +664,7 @@ class TungstenAggregationIterator( TaskContext.get().internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) } - + numOutputRows += 1 res } else { // no more result diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index bf2de244c8e4a..247c900baae9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -41,11 +41,20 @@ import org.apache.spark.{HashPartitioner, SparkEnv} case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override private[sql] lazy val metrics = Map( + "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - val reusableProjection = buildProjection() - iter.map(reusableProjection) + protected override def doExecute(): RDD[InternalRow] = { + val numRows = longMetric("numRows") + child.execute().mapPartitions { iter => + val reusableProjection = buildProjection() + iter.map { row => + numRows += 1 + reusableProjection(row) + } + } } override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -57,19 +66,28 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends */ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + override 
private[sql] lazy val metrics = Map( + "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - this.transformAllExpressions { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + protected override def doExecute(): RDD[InternalRow] = { + val numRows = longMetric("numRows") + child.execute().mapPartitions { iter => + this.transformAllExpressions { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + } + val project = UnsafeProjection.create(projectList, child.output) + iter.map { row => + numRows += 1 + project(row) + } } - val project = UnsafeProjection.create(projectList, child.output) - iter.map(project) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index f7a68e4f5d445..2e108cb814516 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils import org.apache.spark.{InternalAccumulator, TaskContext} @@ -45,7 +46,10 @@ case class BroadcastHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override protected[sql] val trackNumOfRowsEnabled = true + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) val timeout: Duration = { val timeoutValue = sqlContext.conf.broadcastTimeout @@ -65,6 +69,11 @@ case class BroadcastHashJoin( // for the same query. @transient private lazy val broadcastFuture = { + val numBuildRows = buildSide match { + case BuildLeft => longMetric("numLeftRows") + case BuildRight => longMetric("numRightRows") + } + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. 
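// (Hedged aside, not part of the patch: the convention in this and the other
// broadcast joins is to count build-side rows once, inside the map that copies
// them before collect(); helpers such as HashedRelation and buildKeyHashSet are
// then handed SQLMetrics.nullLongMetric so the already-counted rows are not
// counted a second time when the relation is built outside of a job.)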
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) future { @@ -73,8 +82,15 @@ case class BroadcastHashJoin( SQLExecution.withExecutionId(sparkContext, executionId) { // Note that we use .execute().collect() because we don't want to convert data to Scala // types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size) + val input: Array[InternalRow] = buildPlan.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect() + // The following line doesn't run in a job so we cannot track the metric value. However, we + // have already tracked it in the above lines. So here we can use + // `SQLMetrics.nullLongMetric` to ignore it. + val hashed = HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -85,6 +101,12 @@ case class BroadcastHashJoin( } protected override def doExecute(): RDD[InternalRow] = { + val numStreamedRows = buildSide match { + case BuildLeft => longMetric("numRightRows") + case BuildRight => longMetric("numLeftRows") + } + val numOutputRows = longMetric("numOutputRows") + val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => @@ -95,7 +117,7 @@ case class BroadcastHashJoin( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) case _ => } - hashJoin(streamedIter, hashedRelation) + hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index a3626de49aeab..69a8b95eaa7ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.{InternalAccumulator, TaskContext} /** @@ -45,6 +46,11 @@ case class BroadcastHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + val timeout = { val timeoutValue = sqlContext.conf.broadcastTimeout if (timeoutValue < 0) { @@ -63,6 +69,14 @@ case class BroadcastHashOuterJoin( // for the same query. @transient private lazy val broadcastFuture = { + val numBuildRows = joinType match { + case RightOuter => longMetric("numLeftRows") + case LeftOuter => longMetric("numRightRows") + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. 
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) future { @@ -71,8 +85,15 @@ case class BroadcastHashOuterJoin( SQLExecution.withExecutionId(sparkContext, executionId) { // Note that we use .execute().collect() because we don't want to convert data to Scala // types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) + val input: Array[InternalRow] = buildPlan.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect() + // The following line doesn't run in a job so we cannot track the metric value. However, we + // have already tracked it in the above lines. So here we can use + // `SQLMetrics.nullLongMetric` to ignore it. + val hashed = HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -83,6 +104,15 @@ case class BroadcastHashOuterJoin( } override def doExecute(): RDD[InternalRow] = { + val numStreamedRows = joinType match { + case RightOuter => longMetric("numRightRows") + case LeftOuter => longMetric("numLeftRows") + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + val numOutputRows = longMetric("numOutputRows") + val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => @@ -101,16 +131,18 @@ case class BroadcastHashOuterJoin( joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { + numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) }) case RightOuter => streamedIter.flatMap(currentRow => { + numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) }) case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 5bd06fbdca605..78a8c16c62bca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -37,18 +38,31 @@ case class BroadcastLeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - val input = right.execute().map(_.copy()).collect() + val numLeftRows = longMetric("numLeftRows") + 
val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + val input = right.execute().map { row => + numRightRows += 1 + row.copy() + }.collect() if (condition.isEmpty) { - val hashSet = buildKeyHashSet(input.toIterator) + val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value) + hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows) } } else { - val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size) + val hashRelation = + HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => @@ -59,7 +73,7 @@ case class BroadcastLeftSemiJoinHash( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) case _ => } - hashSemiJoin(streamIter, hashedRelation) + hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 017a44b9ca863..28c88b1b03d02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.CompactBuffer /** @@ -38,6 +39,11 @@ case class BroadcastNestedLoopJoin( condition: Option[Expression]) extends BinaryNode { // TODO: Override requiredChildDistribution. + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + /** BuildRight means the right relation <=> the broadcast relation. */ private val (streamed, broadcast) = buildSide match { case BuildRight => (left, right) @@ -75,9 +81,17 @@ case class BroadcastNestedLoopJoin( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { + val (numStreamedRows, numBuildRows) = buildSide match { + case BuildRight => (longMetric("numLeftRows"), longMetric("numRightRows")) + case BuildLeft => (longMetric("numRightRows"), longMetric("numLeftRows")) + } + val numOutputRows = longMetric("numOutputRows") + val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()) - .collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect().toIndexedSeq) /** All rows that either match both-way, or rows from streamed joined with nulls. 
*/ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => @@ -94,6 +108,7 @@ case class BroadcastNestedLoopJoin( streamedIter.foreach { streamedRow => var i = 0 var streamRowMatched = false + numStreamedRows += 1 while (i < broadcastedRelation.value.size) { val broadcastedRow = broadcastedRelation.value(i) @@ -162,6 +177,12 @@ case class BroadcastNestedLoopJoin( // TODO: Breaks lineage. sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + matchesOrStreamedRowsWithNulls.flatMap(_._1), + sparkContext.makeRDD(broadcastRowsWithNulls) + ).map { row => + // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here. + numOutputRows += 1 + row + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 261b4724159fb..2115f40702286 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -30,13 +31,31 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + val leftResults = left.execute().map { row => + numLeftRows += 1 + row.copy() + } + val rightResults = right.execute().map { row => + numRightRows += 1 + row.copy() + } leftResults.cartesian(rightResults).mapPartitions { iter => val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) + iter.map { r => + numOutputRows += 1 + joinedRow(r._1, r._2) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 22d46d1c3e3b7..7ce4a517838cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric trait HashJoin { @@ -69,7 +70,9 @@ trait HashJoin { protected def hashJoin( streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = + numStreamRows: 
LongSQLMetric, + hashedRelation: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ @@ -98,6 +101,7 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 + numOutputRows += 1 resultProjection(ret) } @@ -113,6 +117,7 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() + numStreamRows += 1 val key = joinKeys(currentStreamedRow) if (!key.anyNull) { currentHashMatches = hashedRelation.get(key) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 701bd3cd86372..66903347c88c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.util.collection.CompactBuffer @DeveloperApi @@ -114,22 +115,28 @@ trait HashOuterJoin { key: InternalRow, joinedRow: JoinedRow, rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (rightIter != null) { rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + case r if boundCondition(joinedRow.withRight(r)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } } } else { List.empty } if (temp.isEmpty) { + numOutputRows += 1 resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } else { temp } } else { + numOutputRows += 1 resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } } @@ -140,22 +147,28 @@ trait HashOuterJoin { key: InternalRow, leftIter: Iterable[InternalRow], joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (leftIter != null) { leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => resultProjection(joinedRow).copy() + case l if boundCondition(joinedRow.withLeft(l)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } } } else { List.empty } if (temp.isEmpty) { + numOutputRows += 1 resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } else { temp } } else { + numOutputRows += 1 resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } } @@ -164,7 +177,7 @@ trait HashOuterJoin { protected[this] def fullOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], - joinedRow: JoinedRow): Iterator[InternalRow] = { + joinedRow: JoinedRow, numOutputRows: LongSQLMetric): Iterator[InternalRow] = { if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. 
@@ -177,6 +190,7 @@ trait HashOuterJoin { // append them directly case (r, idx) if boundCondition(joinedRow.withRight(r)) => + numOutputRows += 1 matched = true // if the row satisfy the join condition, add its index into the matched set rightMatchedSet.add(idx) @@ -189,6 +203,7 @@ trait HashOuterJoin { // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. + numOutputRows += 1 joinedRow.withRight(rightNullRow).copy() }) } ++ rightIter.zipWithIndex.collect { @@ -197,12 +212,15 @@ trait HashOuterJoin { // Re-visiting the records in right, and append additional row with empty left, if its not // in the matched set. case (r, idx) if !rightMatchedSet.contains(idx) => + numOutputRows += 1 joinedRow(leftNullRow, r).copy() } } else { leftIter.iterator.map[InternalRow] { l => + numOutputRows += 1 joinedRow(l, rightNullRow).copy() } ++ rightIter.iterator.map[InternalRow] { r => + numOutputRows += 1 joinedRow(leftNullRow, r).copy() } } @@ -211,10 +229,12 @@ trait HashOuterJoin { // This is only used by FullOuter protected[this] def buildHashTable( iter: Iterator[InternalRow], + numIterRows: LongSQLMetric, keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() while (iter.hasNext) { val currentRow = iter.next() + numIterRows += 1 val rowKey = keyGenerator(currentRow) var existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 82dd6eb7e7ed0..beb141ade616d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric trait HashSemiJoin { @@ -61,13 +62,15 @@ trait HashSemiJoin { @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { + protected def buildKeyHashSet( + buildIter: Iterator[InternalRow], numBuildRows: LongSQLMetric): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() // Create a Hash set of buildKeys val rightKey = rightKeyGenerator while (buildIter.hasNext) { val currentRow = buildIter.next() + numBuildRows += 1 val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) @@ -82,25 +85,35 @@ trait HashSemiJoin { protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + numStreamRows: LongSQLMetric, + hashSet: java.util.Set[InternalRow], + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator streamIter.filter(current => { + numStreamRows += 1 val key = joinKeys(current) - !key.anyNull && hashSet.contains(key) + val r = !key.anyNull && hashSet.contains(key) + if (r) numOutputRows += 1 + r }) } protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashedRelation: 
HashedRelation): Iterator[InternalRow] = { + numStreamRows: LongSQLMetric, + hashedRelation: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow streamIter.filter { current => + numStreamRows += 1 val key = joinKeys(current) lazy val rowBuffer = hashedRelation.get(key) - !key.anyNull && rowBuffer != null && rowBuffer.exists { + val r = !key.anyNull && rowBuffer != null && rowBuffer.exists { (row: InternalRow) => boundCondition(joinedRow(current, row)) } + if (r) numOutputRows += 1 + r } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 63d35d0f02622..c1bc7947aa39c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,6 +25,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -112,11 +113,13 @@ private[joins] object HashedRelation { def apply( input: Iterator[InternalRow], + numInputRows: LongSQLMetric, keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { if (keyGenerator.isInstanceOf[UnsafeProjection]) { - return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + return UnsafeHashedRelation( + input, numInputRows, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) } // TODO: Use Spark's HashMap implementation. 
@@ -130,6 +133,7 @@ private[joins] object HashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { currentRow = input.next() + numInputRows += 1 val rowKey = keyGenerator(currentRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) @@ -331,6 +335,7 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], + numInputRows: LongSQLMetric, keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { @@ -340,6 +345,7 @@ private[joins] object UnsafeHashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] + numInputRows += 1 val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 4443455ef11fe..ad6362542f2ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -35,6 +36,11 @@ case class LeftSemiJoinBNL( extends BinaryNode { // TODO: Override requiredChildDistribution. + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = left.output @@ -52,13 +58,21 @@ case class LeftSemiJoinBNL( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map { row => + numRightRows += 1 + row.copy() + }.collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow streamedIter.filter(streamedRow => { + numLeftRows += 1 var i = 0 var matched = false @@ -69,6 +83,9 @@ case class LeftSemiJoinBNL( } i += 1 } + if (matched) { + numOutputRows += 1 + } matched }) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 68ccd34d8ed9b..18808adaac63f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} import 
org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -37,19 +38,28 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter) - hashSemiJoin(streamIter, hashSet) + val hashSet = buildKeyHashSet(buildIter, numRightRows) + hashSemiJoin(streamIter, numLeftRows, hashSet, numOutputRows) } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) - hashSemiJoin(streamIter, hashRelation) + val hashRelation = HashedRelation(buildIter, numRightRows, rightKeyGenerator) + hashSemiJoin(streamIter, numLeftRows, hashRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index c923dc837c449..fc8c9439a6f07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -38,7 +39,10 @@ case class ShuffledHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override protected[sql] val trackNumOfRowsEnabled = true + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def outputPartitioning: Partitioning = PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) @@ -47,9 +51,15 @@ case class ShuffledHashJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { + val (numBuildRows, numStreamedRows) = buildSide match { + case BuildLeft => (longMetric("numLeftRows"), longMetric("numRightRows")) + case BuildRight => (longMetric("numRightRows"), longMetric("numLeftRows")) + } + val numOutputRows = longMetric("numOutputRows") + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, buildSideKeyGenerator) - hashJoin(streamIter, hashed) + val hashed = 
HashedRelation(buildIter, numBuildRows, buildSideKeyGenerator) + hashJoin(streamIter, numStreamedRows, hashed, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 6a8c35efca8f4..ed282f98b7d71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -41,6 +42,11 @@ case class ShuffledHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -53,39 +59,48 @@ case class ShuffledHashOuterJoin( } protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // TODO this probably can be replaced by external sort (sort merged join?) 
joinType match { case LeftOuter => - val hashed = HashedRelation(rightIter, buildKeyGenerator) + val hashed = HashedRelation(rightIter, numRightRows, buildKeyGenerator) val keyGenerator = streamedKeyGenerator val resultProj = resultProjection leftIter.flatMap( currentRow => { + numLeftRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows) }) case RightOuter => - val hashed = HashedRelation(leftIter, buildKeyGenerator) + val hashed = HashedRelation(leftIter, numLeftRows, buildKeyGenerator) val keyGenerator = streamedKeyGenerator val resultProj = resultProjection rightIter.flatMap ( currentRow => { + numRightRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows) }) case FullOuter => // TODO(davies): use UnsafeRow - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val leftHashTable = + buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output)) + val rightHashTable = + buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => fullOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST), - joinedRow) + joinedRow, + numOutputRows) } case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 6d656ea2849a9..6b7322671d6b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} /** * :: DeveloperApi :: @@ -37,6 +38,11 @@ case class SortMergeJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def output: Seq[Attribute] = left.output ++ right.output override def outputPartitioning: Partitioning = @@ -70,6 +76,10 @@ case class SortMergeJoin( } protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new RowIterator { // An ordering that can be used to compare keys from both sides. 
@@ -82,7 +92,9 @@ case class SortMergeJoin( rightKeyGenerator, keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + numLeftRows, + RowIterator.fromScala(rightIter), + numRightRows ) private[this] val joinRow = new JoinedRow private[this] val resultProjection: (InternalRow) => InternalRow = { @@ -108,6 +120,7 @@ case class SortMergeJoin( if (currentLeftRow != null) { joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) currentMatchIdx += 1 + numOutputRows += 1 true } else { false @@ -144,7 +157,9 @@ private[joins] class SortMergeJoinScanner( bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, - bufferedIter: RowIterator) { + numStreamedRows: LongSQLMetric, + bufferedIter: RowIterator, + numBufferedRows: LongSQLMetric) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -269,6 +284,7 @@ private[joins] class SortMergeJoinScanner( if (streamedIter.advanceNext()) { streamedRow = streamedIter.getRow streamedRowKey = streamedKeyGenerator(streamedRow) + numStreamedRows += 1 true } else { streamedRow = null @@ -286,6 +302,7 @@ private[joins] class SortMergeJoinScanner( while (!foundRow && bufferedIter.advanceNext()) { bufferedRow = bufferedIter.getRow bufferedRowKey = bufferedKeyGenerator(bufferedRow) + numBufferedRows += 1 foundRow = !bufferedRowKey.anyNull } if (!foundRow) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 5326966b07a66..dea9e5e580a1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} /** * :: DeveloperApi :: @@ -40,6 +41,11 @@ case class SortMergeOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def output: Seq[Attribute] = { joinType match { case LeftOuter => @@ -108,6 +114,10 @@ case class SortMergeOuterJoin( } override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // An ordering that can be used to compare keys from both sides. 
val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) @@ -133,10 +143,13 @@ case class SortMergeOuterJoin( bufferedKeyGenerator = createRightKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(leftIter), - bufferedIter = RowIterator.fromScala(rightIter) + numLeftRows, + bufferedIter = RowIterator.fromScala(rightIter), + numRightRows ) val rightNullRow = new GenericInternalRow(right.output.length) - new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala + new LeftOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala case RightOuter => val smjScanner = new SortMergeJoinScanner( @@ -144,10 +157,13 @@ case class SortMergeOuterJoin( bufferedKeyGenerator = createLeftKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(rightIter), - bufferedIter = RowIterator.fromScala(leftIter) + numRightRows, + bufferedIter = RowIterator.fromScala(leftIter), + numLeftRows ) val leftNullRow = new GenericInternalRow(left.output.length) - new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala + new RightOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala case x => throw new IllegalArgumentException( @@ -162,7 +178,8 @@ private class LeftOuterIterator( smjScanner: SortMergeJoinScanner, rightNullRow: InternalRow, boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric ) extends RowIterator { private[this] val joinedRow: JoinedRow = new JoinedRow() private[this] var rightIdx: Int = 0 @@ -198,7 +215,9 @@ private class LeftOuterIterator( } override def advanceNext(): Boolean = { - advanceRightUntilBoundConditionSatisfied() || advanceLeft() + val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft() + if (r) numRows += 1 + r } override def getRow: InternalRow = resultProj(joinedRow) @@ -208,7 +227,8 @@ private class RightOuterIterator( smjScanner: SortMergeJoinScanner, leftNullRow: InternalRow, boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric ) extends RowIterator { private[this] val joinedRow: JoinedRow = new JoinedRow() private[this] var leftIdx: Int = 0 @@ -244,7 +264,9 @@ private class RightOuterIterator( } override def advanceNext(): Boolean = { - advanceLeftUntilBoundConditionSatisfied() || advanceRight() + val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight() + if (r) numRows += 1 + r } override def getRow: InternalRow = resultProj(joinedRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 1b51a5e5c8a8e..7a2a98ec18cb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -112,4 +112,10 @@ private[sql] object SQLMetrics { sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } + + /** + * A metric that its value will be ignored. Use this one when we need a metric parameter but don't + * care about the value. 
+ */ + val nullLongMetric = new LongSQLMetric("null") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 8b1a9b21a96b9..a1fa2c3864bdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -22,6 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer @@ -35,7 +37,8 @@ class HashedRelationSuite extends SparkFunSuite { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -45,11 +48,13 @@ class HashedRelationSuite extends SparkFunSuite { val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) assert(hashed.get(data(2)) === data2) + assert(numDataRows.value.value === data.length) } test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -62,17 +67,19 @@ class HashedRelationSuite extends SparkFunSuite { assert(uniqHashed.getValue(data(1)) === data(1)) assert(uniqHashed.getValue(data(2)) === data(2)) assert(uniqHashed.getValue(InternalRow(10)) === null) + assert(numDataRows.value.value === data.length) } test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()).toArray val buildKey = Seq(BoundReference(0, IntegerType, false)) val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, numDataRows, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) @@ -94,5 +101,6 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) + assert(numDataRows.value.value === data.length) } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 953284c98b208..7383d3f8fe024 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -25,15 +25,24 @@ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} import org.apache.spark.util.Utils +class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { -class SQLMetricsSuite extends SparkFunSuite { + override val sqlContext = TestSQLContext + + import sqlContext.implicits._ test("LongSQLMetric should not box Long") { val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") - val f = () => { l += 1L } + val f = () => { + l += 1L + l.add(1L) + } BoxingFinder.getClassReader(f.getClass).foreach { cl => val boxingFinder = new BoxingFinder() cl.accept(boxingFinder, 0) @@ -51,6 +60,441 @@ class SQLMetricsSuite extends SparkFunSuite { assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") } } + + /** + * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetrics the expected metrics. The format is + * `nodeId -> (operatorName, metric name -> metric value)`. + */ + private def testSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet + df.collect() + TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= expectedNumOfJobs) + if (jobs.size == expectedNumOfJobs) { + // If we can track all jobs, check the metric values + val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId) + val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + expectedMetrics.contains(node.id) + }.map { node => + val nodeMetrics = node.metrics.map { metric => + val metricValue = metricValues(metric.accumulatorId) + (metric.name, metricValue) + }.toMap + (node.id, node.name -> nodeMetrics) + }.toMap + assert(expectedMetrics === actualMetrics) + } else { + // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. + // Since we cannot track all jobs, the metric values could be wrong and we should not check + // them. 
+ logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") + } + } + + test("Project metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) + val df = TestData.person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("Project", Map( + "number of rows" -> 2L))) + ) + } + } + + test("TungstenProject metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = TestData.person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("TungstenProject", Map( + "number of rows" -> 2L))) + ) + } + } + + test("Filter metrics") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) + val df = TestData.person.filter('age < 25) + testSparkPlanMetrics(df, 1, Map( + 0L -> ("Filter", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + } + + test("Aggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) + val df = TestData.testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = TestData.testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortBasedAggregate metrics") { + // Because SortBasedAggregate may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> + // SortBasedAggregate(nodeId = 0) + val df = TestData.testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // Assume the execution plan is + // ... 
-> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) + // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) + // 2 partitions and each partition contains 2 keys + val df2 = TestData.testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 3L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("TungstenAggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> TungstenAggregate(nodeId = 0) + val df = TestData.testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = TestData.testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortMergeJoin metrics") { + // Because SortMergeJoin may skip different rows if the number of partitions is different, this + // test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 4L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("SortMergeOuterJoin metrics") { + // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... 
-> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 8L))) + ) + + val df2 = sqlContext.sql( + "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df2, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 2L, + "number of right rows" -> 6L, + "number of output rows" -> 8L))) + ) + } + } + } + + test("BroadcastHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") + // Assume the execution plan is + // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = df1.join(broadcast(df2), "key") + testSparkPlanMetrics(df, 2, Map( + 1L -> ("BroadcastHashJoin", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("ShuffledHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> ShuffledHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("ShuffledHashJoin", Map( + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("ShuffledHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> ShuffledHashOuterJoin(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(df2, $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + + val df4 = df1.join(df2, $"key" === $"key2", "outer") + testSparkPlanMetrics(df4, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 7L))) + ) + } + } + + test("BroadcastHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... 
-> BroadcastHashOuterJoin(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + } + } + + test("BroadcastNestedLoopJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") + testSparkPlanMetrics(df, 3, Map( + 1L -> ("BroadcastNestedLoopJoin", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 2L, + "number of output rows" -> 12L))) + ) + } + } + } + + test("BroadcastLeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastLeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinHash(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("LeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinBNL metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinBNL(nodeId = 0) + val df = df1.join(df2, $"key" < $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("LeftSemiJoinBNL", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("CartesianProduct metrics") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... 
-> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("CartesianProduct", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 12L, // right is read 6 times + "number of output rows" -> 12L))) + ) + } + } + + test("save metrics") { + withTempPath { file => + val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + TestData.person.select('name).write.format("json").save(file.getAbsolutePath) + TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq(2L)) + } + } + } private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) From 520ad44b17f72e6465bf990f64b4e289f8a83447 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 11 Aug 2015 12:49:47 -0700 Subject: [PATCH 268/340] [SPARK-9750] [MLLIB] Improve equals on SparseMatrix and DenseMatrix Adds unit test for `equals` on `mllib.linalg.Matrix` class and `equals` to both `SparseMatrix` and `DenseMatrix`. Supports equality testing between `SparseMatrix` and `DenseMatrix`. 
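
For illustration only (not part of the original patch): a minimal sketch of the intended semantics, using only the `Matrices` factory, `DenseMatrix.toSparse`, and `Matrix.transpose` that appear in the diff and test below. Equality is decided by the matrices' entries (compared through their Breeze views), not by the backing storage:

    import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices}

    // Column-major 2x2 matrix [[0.0, 2.0], [1.0, 3.0]]
    val dm = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0))
    // Same entries, stored in compressed sparse column format
    val sm = dm.asInstanceOf[DenseMatrix].toSparse

    assert(sm == dm)            // sparse vs. dense with equal shape and values now compare equal
    assert(dm != dm.transpose)  // matrices with different entries remain unequal
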
mengxr Author: Feynman Liang Closes #8042 from feynmanliang/SPARK-9750 and squashes the following commits: bb70d5e [Feynman Liang] Breeze compare for dense matrices as well, in case other is sparse ab6f3c8 [Feynman Liang] Sparse matrix compare for equals 22782df [Feynman Liang] Add equality based on matrix semantics, not representation 78f9426 [Feynman Liang] Add casts 43d28fa [Feynman Liang] Fix failing test 6416fa0 [Feynman Liang] Add failing sparse matrix equals tests --- .../apache/spark/mllib/linalg/Matrices.scala | 8 ++++++-- .../spark/mllib/linalg/MatricesSuite.scala | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 1c858348bf20e..1139ce36d50b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -257,8 +257,7 @@ class DenseMatrix( this(numRows, numCols, values, false) override def equals(o: Any): Boolean = o match { - case m: DenseMatrix => - m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray) + case m: Matrix => toBreeze == m.toBreeze case _ => false } @@ -519,6 +518,11 @@ class SparseMatrix( rowIndices: Array[Int], values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + override def equals(o: Any): Boolean = o match { + case m: Matrix => toBreeze == m.toBreeze + case _ => false + } + private[mllib] def toBreeze: BM[Double] = { if (!isTransposed) { new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a270ba2562db9..bfd6d5495f5e0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -74,6 +74,24 @@ class MatricesSuite extends SparkFunSuite { } } + test("equals") { + val dm1 = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)) + assert(dm1 === dm1) + assert(dm1 !== dm1.transpose) + + val dm2 = Matrices.dense(2, 2, Array(0.0, 2.0, 1.0, 3.0)) + assert(dm1 === dm2.transpose) + + val sm1 = dm1.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm1) + assert(sm1 === dm1) + assert(sm1 !== sm1.transpose) + + val sm2 = dm2.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm2.transpose) + assert(sm1 === dm2.transpose) + } + test("matrix copies are deep copies") { val m = 3 val n = 2 From 2a3be4ddf9d9527353f07ea0ab204ce17dbcba9a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 11 Aug 2015 14:02:23 -0700 Subject: [PATCH 269/340] [SPARK-7726] Add import so Scaladoc doesn't fail. This is another import needed so Scala 2.11 doc generation doesn't fail. See SPARK-7726 for more detail. I tested this locally and the 2.11 install goes from failing to succeeding with this patch. Author: Patrick Wendell Closes #8095 from pwendell/scaladoc. 
--- .../spark/network/shuffle/protocol/mesos/RegisterDriver.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java index 1c28fc1dff246..94a61d6caadc4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java @@ -23,6 +23,9 @@ import org.apache.spark.network.protocol.Encoders; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * A message sent from the driver to register with the MesosExternalShuffleService. */ From 00c02728a6c6c4282c389ca90641dd78dd5e3d32 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 11 Aug 2015 14:04:09 -0700 Subject: [PATCH 270/340] [SPARK-9814] [SQL] EqualNotNull not passing to data sources MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: hyukjinkwon Author: 권혁진 Closes #8096 from HyukjinKwon/master. --- .../sql/execution/datasources/DataSourceStrategy.scala | 5 +++++ .../scala/org/apache/spark/sql/sources/filters.scala | 9 +++++++++ .../org/apache/spark/sql/sources/FilteredScanSuite.scala | 1 + 3 files changed, 15 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 78a4acdf4b1bf..2a4c40db8bb66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -349,6 +349,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case expressions.EqualTo(Literal(v, _), a: Attribute) => Some(sources.EqualTo(a.name, v)) + case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => + Some(sources.EqualNullSafe(a.name, v)) + case expressions.EqualNullSafe(Literal(v, _), a: Attribute) => + Some(sources.EqualNullSafe(a.name, v)) + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => Some(sources.GreaterThan(a.name, v)) case expressions.GreaterThan(Literal(v, _), a: Attribute) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 4d942e4f9287a..3780cbbcc9631 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -36,6 +36,15 @@ abstract class Filter */ case class EqualTo(attribute: String, value: Any) extends Filter +/** + * Performs equality comparison, similar to [[EqualTo]]. However, this differs from [[EqualTo]] + * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false` + * (rather than NULL) if one of the input is NULL and the other is not NULL. + * + * @since 1.5.0 + */ +case class EqualNullSafe(attribute: String, value: Any) extends Filter + /** * A filter that evaluates to `true` iff the attribute evaluates to a value * greater than `value`. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 81b3a0f0c5b3a..5ef365797eace 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -56,6 +56,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL // Predicate test on integer column def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v + case EqualNullSafe("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v From f16bc68dfb25c7b746ae031a57840ace9bafa87f Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 11 Aug 2015 14:06:23 -0700 Subject: [PATCH 271/340] [SPARK-9824] [CORE] Fix the issue that InternalAccumulator leaks WeakReference `InternalAccumulator.create` doesn't call `registerAccumulatorForCleanup` to register itself with ContextCleaner, so `WeakReference`s for these accumulators in `Accumulators.originals` won't be removed. This PR added `registerAccumulatorForCleanup` for internal accumulators to avoid the memory leak. Author: zsxwing Closes #8108 from zsxwing/internal-accumulators-leak. --- .../scala/org/apache/spark/Accumulators.scala | 22 +++++++++++-------- .../org/apache/spark/scheduler/Stage.scala | 2 +- .../org/apache/spark/AccumulatorSuite.scala | 3 ++- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 064246dfa7fc3..c39c8667d013e 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -382,14 +382,18 @@ private[spark] object InternalAccumulator { * add to the same set of accumulators. We do this to report the distribution of accumulator * values across all tasks within each stage. */ - def create(): Seq[Accumulator[Long]] = { - Seq( - // Execution memory refers to the memory used by internal data structures created - // during shuffles, aggregations and joins. The value of this accumulator should be - // approximately the sum of the peak sizes across all such data structures created - // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. - new Accumulator( - 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) - ) ++ maybeTestAccumulator.toSeq + def create(sc: SparkContext): Seq[Accumulator[Long]] = { + val internalAccumulators = Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. 
+ new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + internalAccumulators.foreach { accumulator => + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) + } + internalAccumulators } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index de05ee256dbfc..1cf06856ffbc2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -81,7 +81,7 @@ private[spark] abstract class Stage( * accumulators here again will override partial values from the finished tasks. */ def resetInternalAccumulators(): Unit = { - _internalAccumulators = InternalAccumulator.create() + _internalAccumulators = InternalAccumulator.create(rdd.sparkContext) } /** diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 48f549575f4d1..0eb2293a9d063 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -160,7 +160,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("internal accumulators in TaskContext") { - val accums = InternalAccumulator.create() + sc = new SparkContext("local", "test") + val accums = InternalAccumulator.create(sc) val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) val internalMetricsToAccums = taskContext.internalMetricsToAccumulators val collectedInternalAccums = taskContext.collectInternalAccumulators() From 423cdfd83d7fd02a4f8cf3e714db913fd3f9ca09 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 11 Aug 2015 14:08:09 -0700 Subject: [PATCH 272/340] Closes #1290 Closes #4934 From be3e27164133025db860781bd5cdd3ca233edd21 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 11 Aug 2015 14:21:53 -0700 Subject: [PATCH 273/340] [SPARK-9788] [MLLIB] Fix LDA Binary Compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Add “asymmetricDocConcentration” and revert docConcentration changes. If the (internal) doc concentration vector is a single value, “getDocConcentration" returns it. If it is a constant vector, getDocConcentration returns the first item, and fails otherwise. 2. Give `LDAModel.gammaShape` a default value in `LDAModel` concrete class constructors. 
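
For illustration only (not part of the original patch): a hypothetical sketch of how the two accessors are meant to behave, assuming only the `LDA` setters and getters shown in the diff below.

    import org.apache.spark.mllib.clustering.LDA
    import org.apache.spark.mllib.linalg.Vectors

    val lda = new LDA().setK(3)

    // Symmetric prior: every topic gets the same concentration
    lda.setDocConcentration(Vectors.dense(0.5, 0.5, 0.5))
    lda.getAsymmetricDocConcentration   // full Vector(0.5, 0.5, 0.5), the new accessor
    lda.getDocConcentration             // 0.5, the single Double of the 1.4-style API

    // Asymmetric prior: only the new accessor is valid
    lda.setDocConcentration(Vectors.dense(0.1, 0.5, 0.9))
    lda.getAsymmetricDocConcentration   // Vector(0.1, 0.5, 0.9)
    // lda.getDocConcentration now fails its require() check, by design

The `gammaShape = 100` default in the model constructors matches the value the EM path already assumes for `toLocal` conversion, so call sites that never passed a gamma shape keep working unchanged.
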
jkbradley Author: Feynman Liang Closes #8077 from feynmanliang/SPARK-9788 and squashes the following commits: 6b07bc8 [Feynman Liang] Code review changes 9d6a71e [Feynman Liang] Add asymmetricAlpha alias bf4e685 [Feynman Liang] Asymmetric docConcentration 4cab972 [Feynman Liang] Default gammaShape --- .../apache/spark/mllib/clustering/LDA.scala | 27 ++++++++++++++++-- .../spark/mllib/clustering/LDAModel.scala | 11 ++++---- .../spark/mllib/clustering/LDAOptimizer.scala | 28 +++++++++---------- .../spark/mllib/clustering/LDASuite.scala | 4 +-- 4 files changed, 46 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index ab124e6d77c5e..0fc9b1ac4d716 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -79,7 +79,24 @@ class LDA private ( * * This is the parameter to a Dirichlet distribution. */ - def getDocConcentration: Vector = this.docConcentration + def getAsymmetricDocConcentration: Vector = this.docConcentration + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This method assumes the Dirichlet distribution is symmetric and can be described by a single + * [[Double]] parameter. It should fail if docConcentration is asymmetric. + */ + def getDocConcentration: Double = { + val parameter = docConcentration(0) + if (docConcentration.size == 1) { + parameter + } else { + require(docConcentration.toArray.forall(_ == parameter)) + parameter + } + } /** * Concentration parameter (commonly named "alpha") for the prior placed on documents' @@ -106,18 +123,22 @@ class LDA private ( * [[https://github.com/Blei-Lab/onlineldavb]]. */ def setDocConcentration(docConcentration: Vector): this.type = { + require(docConcentration.size > 0, "docConcentration must have > 0 elements") this.docConcentration = docConcentration this } - /** Replicates Double to create a symmetric prior */ + /** Replicates a [[Double]] docConcentration to create a symmetric prior. 
*/ def setDocConcentration(docConcentration: Double): this.type = { this.docConcentration = Vectors.dense(docConcentration) this } + /** Alias for [[getAsymmetricDocConcentration]] */ + def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration + /** Alias for [[getDocConcentration]] */ - def getAlpha: Vector = getDocConcentration + def getAlpha: Double = getDocConcentration /** Alias for [[setDocConcentration()]] */ def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 33babda69bbb9..5dc637ebdc133 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -27,7 +27,6 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.broadcast.Broadcast import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -190,7 +189,8 @@ class LocalLDAModel private[clustering] ( val topics: Matrix, override val docConcentration: Vector, override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable { + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel with Serializable { override def k: Int = topics.numCols @@ -455,8 +455,9 @@ class DistributedLDAModel private[clustering] ( val vocabSize: Int, override val docConcentration: Vector, override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double, - private[spark] val iterationTimes: Array[Double]) extends LDAModel { + private[spark] val iterationTimes: Array[Double], + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel { import LDA._ @@ -756,7 +757,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize, - docConcentration, topicConcentration, gammaShape, iterationTimes) + docConcentration, topicConcentration, iterationTimes, gammaShape) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index afba2866c7040..a0008f9c99ad7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -95,10 +95,8 @@ final class EMLDAOptimizer extends LDAOptimizer { * Compute bipartite term/doc graph. 
*/ override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { - val docConcentration = lda.getDocConcentration(0) - require({ - lda.getDocConcentration.toArray.forall(_ == docConcentration) - }, "EMLDAOptimizer currently only supports symmetric document-topic priors") + // EMLDAOptimizer currently only supports symmetric document-topic priors + val docConcentration = lda.getDocConcentration val topicConcentration = lda.getTopicConcentration val k = lda.getK @@ -209,11 +207,11 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") this.graphCheckpointer.deleteAllCheckpoints() - // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal - // conversion + // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in + // LDAModel.toLocal conversion new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, - 100, iterationTimes) + iterationTimes) } } @@ -378,18 +376,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer { this.k = lda.getK this.corpusSize = docs.count() this.vocabSize = docs.first()._2.size - this.alpha = if (lda.getDocConcentration.size == 1) { - if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) + this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) { + if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) else { - require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha") - Vectors.dense(Array.fill(k)(lda.getDocConcentration(0))) + require(lda.getAsymmetricDocConcentration(0) >= 0, + s"all entries in alpha must be >=0, got: $alpha") + Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0))) } } else { - require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha") - lda.getDocConcentration.foreachActive { case (_, x) => + require(lda.getAsymmetricDocConcentration.size == k, + s"alpha must have length k, got: $alpha") + lda.getAsymmetricDocConcentration.foreachActive { case (_, x) => require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha") } - lda.getDocConcentration + lda.getAsymmetricDocConcentration } this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration this.randomGenerator = new Random(lda.getSeed) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index fdc2554ab853e..ce6a8eb8e8c46 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -160,8 +160,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("setter alias") { val lda = new LDA().setAlpha(2.0).setBeta(3.0) - assert(lda.getAlpha.toArray.forall(_ === 2.0)) - assert(lda.getDocConcentration.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricAlpha.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0)) assert(lda.getBeta === 3.0) assert(lda.getTopicConcentration === 3.0) } From 017b5de07ef6cff249e984a2ab781c520249ac76 Mon Sep 17 00:00:00 2001 From: Sudhakar Thota Date: Tue, 11 Aug 2015 14:31:51 -0700 
Subject: [PATCH 274/340] [SPARK-8925] [MLLIB] Add @since tags to mllib.util Went thru the history of changes the file MLUtils.scala and picked up the version that the change went in. Author: Sudhakar Thota Author: Sudhakar Thota Closes #7436 from sthota2014/SPARK-8925_thotas. --- .../org/apache/spark/mllib/util/MLUtils.scala | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 7c5cfa7bd84ce..26eb84a8dc0b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -64,6 +64,7 @@ object MLUtils { * feature dimensions. * @param minPartitions min number of partitions * @return labeled data stored as an RDD[LabeledPoint] + * @since 1.0.0 */ def loadLibSVMFile( sc: SparkContext, @@ -113,7 +114,10 @@ object MLUtils { } // Convenient methods for `loadLibSVMFile`. - + + /** + * @since 1.0.0 + */ @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -126,6 +130,7 @@ object MLUtils { /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. + * @since 1.0.0 */ def loadLibSVMFile( sc: SparkContext, @@ -133,6 +138,9 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) + /** + * @since 1.0.0 + */ @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -141,6 +149,9 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures) + /** + * @since 1.0.0 + */ @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -151,6 +162,7 @@ object MLUtils { /** * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. + * @since 1.0.0 */ def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] = loadLibSVMFile(sc, path, -1) @@ -181,12 +193,14 @@ object MLUtils { * @param path file or directory path in any Hadoop-supported file system URI * @param minPartitions min number of partitions * @return vectors stored as an RDD[Vector] + * @since 1.1.0 */ def loadVectors(sc: SparkContext, path: String, minPartitions: Int): RDD[Vector] = sc.textFile(path, minPartitions).map(Vectors.parse) /** * Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions. + * @since 1.1.0 */ def loadVectors(sc: SparkContext, path: String): RDD[Vector] = sc.textFile(path, sc.defaultMinPartitions).map(Vectors.parse) @@ -197,6 +211,7 @@ object MLUtils { * @param path file or directory path in any Hadoop-supported file system URI * @param minPartitions min number of partitions * @return labeled points stored as an RDD[LabeledPoint] + * @since 1.1.0 */ def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = sc.textFile(path, minPartitions).map(LabeledPoint.parse) @@ -204,6 +219,7 @@ object MLUtils { /** * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of * partitions. 
+ * @since 1.1.0 */ def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) @@ -220,6 +236,7 @@ object MLUtils { * * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. + * @since 1.0.0 */ @deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1") def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { @@ -241,6 +258,7 @@ object MLUtils { * * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. + * @since 1.0.0 */ @deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1") def saveLabeledData(data: RDD[LabeledPoint], dir: String) { @@ -253,6 +271,7 @@ object MLUtils { * Return a k element array of pairs of RDDs with the first element of each pair * containing the training data, a complement of the validation data and the second * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. + * @since 1.0.0 */ @Experimental def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { @@ -268,6 +287,7 @@ object MLUtils { /** * Returns a new vector with `1.0` (bias) appended to the input vector. + * @since 1.0.0 */ def appendBias(vector: Vector): Vector = { vector match { From 736af95bd0c41723d455246b634a0fb68b38a7c7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 14:52:52 -0700 Subject: [PATCH 275/340] [HOTFIX] Fix style error caused by 017b5de --- mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 26eb84a8dc0b0..11ed23176fc12 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -114,7 +114,7 @@ object MLUtils { } // Convenient methods for `loadLibSVMFile`. - + /** * @since 1.0.0 */ From 5a5bbc29961630d649d4bd4acd5d19eb537b5fd0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 11 Aug 2015 16:33:08 -0700 Subject: [PATCH 276/340] [SPARK-9074] [LAUNCHER] Allow arbitrary Spark args to be set. This change allows any Spark argument to be added to the app to be started using SparkLauncher. Known arguments are properly validated, while unknown arguments are allowed so that the library can launch newer Spark versions (in case SPARK_HOME points at one). Author: Marcelo Vanzin Closes #7975 from vanzin/SPARK-9074 and squashes the following commits: b5e451a [Marcelo Vanzin] [SPARK-9074] [launcher] Allow arbitrary Spark args to be set. 
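For context, the new API can be exercised roughly as follows (a sketch only: the app resource, class name, and option values below are placeholders, not part of this patch):

    import org.apache.spark.launcher.SparkLauncher

    // Known spark-submit options are validated; unknown ones are passed through
    // untouched so the library can drive newer Spark versions.
    val app = new SparkLauncher()
      .setAppResource("/path/to/app.jar")       // placeholder
      .setMainClass("com.example.MyApp")        // placeholder
      .setMaster("local[2]")
      .addSparkArg("--verbose")                 // known no-value argument, validated
      .addSparkArg("--driver-memory", "2g")     // known argument that takes a value, validated
      .addSparkArg("--future-argument", "x")    // unknown argument, accepted for forward compatibility
      .launch()                                 // returns a java.lang.Process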
--- .../apache/spark/launcher/SparkLauncher.java | 101 +++++++++++++++++- .../launcher/SparkSubmitCommandBuilder.java | 2 +- .../spark/launcher/SparkLauncherSuite.java | 50 +++++++++ 3 files changed, 150 insertions(+), 3 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index c0f89c9230692..03c9358bc865d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -20,12 +20,13 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import static org.apache.spark.launcher.CommandBuilderUtils.*; -/** +/** * Launcher for Spark applications. *

    * Use this class to start Spark applications programmatically. The class uses a builder pattern @@ -57,7 +58,8 @@ public class SparkLauncher { /** Configuration key for the number of executor CPU cores. */ public static final String EXECUTOR_CORES = "spark.executor.cores"; - private final SparkSubmitCommandBuilder builder; + // Visible for testing. + final SparkSubmitCommandBuilder builder; public SparkLauncher() { this(null); @@ -187,6 +189,73 @@ public SparkLauncher setMainClass(String mainClass) { return this; } + /** + * Adds a no-value argument to the Spark invocation. If the argument is known, this method + * validates whether the argument is indeed a no-value argument, and throws an exception + * otherwise. + *

    + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param arg Argument to add. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String arg) { + SparkSubmitOptionParser validator = new ArgumentValidator(false); + validator.parse(Arrays.asList(arg)); + builder.sparkArgs.add(arg); + return this; + } + + /** + * Adds an argument with a value to the Spark invocation. If the argument name corresponds to + * a known argument, the code validates that the argument actually expects a value, and throws + * an exception otherwise. + *

    + * It is safe to add arguments modified by other methods in this class (such as + * {@link #setMaster(String)} - the last invocation will be the one to take effect. + *

    + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param name Name of argument to add. + * @param value Value of the argument. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String name, String value) { + SparkSubmitOptionParser validator = new ArgumentValidator(true); + if (validator.MASTER.equals(name)) { + setMaster(value); + } else if (validator.PROPERTIES_FILE.equals(name)) { + setPropertiesFile(value); + } else if (validator.CONF.equals(name)) { + String[] vals = value.split("=", 2); + setConf(vals[0], vals[1]); + } else if (validator.CLASS.equals(name)) { + setMainClass(value); + } else if (validator.JARS.equals(name)) { + builder.jars.clear(); + for (String jar : value.split(",")) { + addJar(jar); + } + } else if (validator.FILES.equals(name)) { + builder.files.clear(); + for (String file : value.split(",")) { + addFile(file); + } + } else if (validator.PY_FILES.equals(name)) { + builder.pyFiles.clear(); + for (String file : value.split(",")) { + addPyFile(file); + } + } else { + validator.parse(Arrays.asList(name, value)); + builder.sparkArgs.add(name); + builder.sparkArgs.add(value); + } + return this; + } + /** * Adds command line arguments for the application. * @@ -277,4 +346,32 @@ public Process launch() throws IOException { return pb.start(); } + private static class ArgumentValidator extends SparkSubmitOptionParser { + + private final boolean hasValue; + + ArgumentValidator(boolean hasValue) { + this.hasValue = hasValue; + } + + @Override + protected boolean handle(String opt, String value) { + if (value == null && hasValue) { + throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + } + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + // Do not fail on unknown arguments, to support future arguments added to SparkSubmit. + return true; + } + + protected void handleExtraArgs(List extra) { + // No op. 
+ } + + }; + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 87c43aa9980e1..4f354cedee66f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -76,7 +76,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { "spark-internal"); } - private final List sparkArgs; + final List sparkArgs; private final boolean printHelp; /** diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 252d5abae1ca3..d0c26dd05679b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -20,6 +20,7 @@ import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -35,8 +36,54 @@ public class SparkLauncherSuite { private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); + @Test + public void testSparkArgumentHandling() throws Exception { + SparkLauncher launcher = new SparkLauncher() + .setSparkHome(System.getProperty("spark.test.home")); + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + + launcher.addSparkArg(opts.HELP); + try { + launcher.addSparkArg(opts.PROXY_USER); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. + } + + launcher.addSparkArg(opts.PROXY_USER, "someUser"); + try { + launcher.addSparkArg(opts.HELP, "someValue"); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. 
+ } + + launcher.addSparkArg("--future-argument"); + launcher.addSparkArg("--future-argument", "someValue"); + + launcher.addSparkArg(opts.MASTER, "myMaster"); + assertEquals("myMaster", launcher.builder.master); + + launcher.addJar("foo"); + launcher.addSparkArg(opts.JARS, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.jars); + + launcher.addFile("foo"); + launcher.addSparkArg(opts.FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.files); + + launcher.addPyFile("foo"); + launcher.addSparkArg(opts.PY_FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.pyFiles); + + launcher.setConf("spark.foo", "foo"); + launcher.addSparkArg(opts.CONF, "spark.foo=bar"); + assertEquals("bar", launcher.builder.conf.get("spark.foo")); + } + @Test public void testChildProcLauncher() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); Map env = new HashMap(); env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); @@ -44,9 +91,12 @@ public void testChildProcLauncher() throws Exception { .setSparkHome(System.getProperty("spark.test.home")) .setMaster("local") .setAppResource("spark-internal") + .addSparkArg(opts.CONF, + String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Dfoo=bar -Dtest.name=-testChildProcLauncher") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) .addAppArgs("proc"); final Process app = launcher.launch(); From afa757c98c537965007cad4c61c436887f3ac6a6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Aug 2015 18:08:49 -0700 Subject: [PATCH 277/340] [SPARK-9849] [SQL] DirectParquetOutputCommitter qualified name should be backward compatible DirectParquetOutputCommitter was moved in SPARK-9763. However, users can explicitly set the class as a config option, so we must be able to resolve the old committer qualified name. Author: Reynold Xin Closes #8114 from rxin/SPARK-9849. 
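In user-facing terms, a configuration that still references the committer's pre-SPARK-9763 location keeps working; the old name is rewritten to the relocated class when the write job is prepared. A sketch (the SQLContext, DataFrame, and output path are stand-ins, not part of this patch):

    import org.apache.spark.sql.{DataFrame, SQLContext}

    // `sqlContext` and `df` stand for an existing SQLContext and DataFrame.
    def writeWithOldCommitterName(sqlContext: SQLContext, df: DataFrame): Unit = {
      // Old, pre-move qualified name set explicitly in user configuration.
      sqlContext.setConf("spark.sql.parquet.output.committer.class",
        "org.apache.spark.sql.parquet.DirectParquetOutputCommitter")
      // The write still resolves to DirectParquetOutputCommitter at its new location.
      df.write.parquet("/placeholder/output/path")
    }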
--- .../datasources/parquet/ParquetRelation.scala | 7 +++++ .../datasources/parquet/ParquetIOSuite.scala | 27 ++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 4086a139bed72..c71c69b6e80b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -209,6 +209,13 @@ private[sql] class ParquetRelation( override def prepareJobForWrite(job: Job): OutputWriterFactory = { val conf = ContextUtil.getConfiguration(job) + // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible + val committerClassname = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) + if (committerClassname == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") { + conf.set(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[DirectParquetOutputCommitter].getCanonicalName) + } + val committerClass = conf.getClass( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index ee925afe08508..cb166349fdb26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -390,7 +390,32 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } - test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") { + test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { + val clonedConf = new Configuration(configuration) + + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } + + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => val clonedConf = new Configuration(configuration) From ca8f70e9d473d2c81866f3c330cc6545c33bdac7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 20:46:58 -0700 Subject: [PATCH 278/340] [SPARK-9649] Fix flaky test MasterSuite again - disable REST The REST server is not actually used in most tests and so we can disable it. It is a source of flakiness because it tries to bind to a specific port in vain. There was also some code that avoided the shuffle service in tests. This is actually not necessary because the shuffle service is already off by default. Author: Andrew Or Closes #8084 from andrewor14/fix-master-suite-again. 
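For reference, the flag wired into the test JVMs below corresponds to a plain Spark configuration entry; a suite that constructs a standalone Master itself could set the same switch through its SparkConf. A sketch, assuming the suite builds its own conf rather than relying on the system property:

    import org.apache.spark.SparkConf

    val testConf = new SparkConf(false)
      .set("spark.master.rest.enabled", "false")  // don't bind the REST submission port in tests
      .set("spark.ui.enabled", "false")           // matches the existing convention for test JVMs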
--- pom.xml | 1 + project/SparkBuild.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/pom.xml b/pom.xml index 8942836a7da16..cfd7d32563f2a 100644 --- a/pom.xml +++ b/pom.xml @@ -1895,6 +1895,7 @@ ${project.build.directory}/tmp ${spark.test.home} 1 + false false false true diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index cad7067ade8c1..74f815f941d5b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -546,6 +546,7 @@ object TestSettings { javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", + javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", From 3ef0f32928fc383ad3edd5ad167212aeb9eba6e1 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 11 Aug 2015 21:16:48 -0700 Subject: [PATCH 279/340] [SPARK-1517] Refactor release scripts to facilitate nightly publishing This update contains some code changes to the release scripts that allow easier nightly publishing. I've been using these new scripts on Jenkins for cutting and publishing nightly snapshots for the last month or so, and it has been going well. I'd like to get them merged back upstream so this can be maintained by the community. The main changes are: 1. Separates the release tagging from various build possibilities for an already tagged release (`release-tag.sh` and `release-build.sh`). 2. Allow for injecting credentials through the environment, including GPG keys. This is then paired with secure key injection in Jenkins. 3. Support for copying build results to a remote directory, and also "rotating" results, e.g. the ability to keep the last N copies of binary or doc builds. I'm happy if anyone wants to take a look at this - it's not user facing but an internal utility used for generating releases. Author: Patrick Wendell Closes #7411 from pwendell/release-script-updates and squashes the following commits: 74f9beb [Patrick Wendell] Moving maven build command to a variable 233ce85 [Patrick Wendell] [SPARK-1517] Refactor release scripts to facilitate nightly publishing --- dev/create-release/create-release.sh | 267 ---------------------- dev/create-release/release-build.sh | 321 +++++++++++++++++++++++++++ dev/create-release/release-tag.sh | 79 +++++++ 3 files changed, 400 insertions(+), 267 deletions(-) delete mode 100755 dev/create-release/create-release.sh create mode 100755 dev/create-release/release-build.sh create mode 100755 dev/create-release/release-tag.sh diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh deleted file mode 100755 index 4311c8c9e4ca6..0000000000000 --- a/dev/create-release/create-release.sh +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Quick-and-dirty automation of making maven and binary releases. Not robust at all. -# Publishes releases to Maven and packages/copies binary release artifacts. -# Expects to be run in a totally empty directory. -# -# Options: -# --skip-create-release Assume the desired release tag already exists -# --skip-publish Do not publish to Maven central -# --skip-package Do not package and upload binary artifacts -# Would be nice to add: -# - Send output to stderr and have useful logging in stdout - -# Note: The following variables must be set before use! -ASF_USERNAME=${ASF_USERNAME:-pwendell} -ASF_PASSWORD=${ASF_PASSWORD:-XXX} -GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} -GIT_BRANCH=${GIT_BRANCH:-branch-1.0} -RELEASE_VERSION=${RELEASE_VERSION:-1.2.0} -# Allows publishing under a different version identifier than -# was present in the actual release sources (e.g. rc-X) -PUBLISH_VERSION=${PUBLISH_VERSION:-$RELEASE_VERSION} -NEXT_VERSION=${NEXT_VERSION:-1.2.1} -RC_NAME=${RC_NAME:-rc2} - -M2_REPO=~/.m2/repository -SPARK_REPO=$M2_REPO/org/apache/spark -NEXUS_ROOT=https://repository.apache.org/service/local/staging -NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads - -if [ -z "$JAVA_HOME" ]; then - echo "Error: JAVA_HOME is not set, cannot proceed." - exit -1 -fi -JAVA_7_HOME=${JAVA_7_HOME:-$JAVA_HOME} - -set -e - -GIT_TAG=v$RELEASE_VERSION-$RC_NAME - -if [[ ! "$@" =~ --skip-create-release ]]; then - echo "Creating release commit and publishing to Apache repository" - # Artifact publishing - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \ - -b $GIT_BRANCH - pushd spark - export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" - - # Create release commits and push them to github - # NOTE: This is done "eagerly" i.e. we don't check if we can succesfully build - # or before we coin the release commit. This helps avoid races where - # other people add commits to this branch while we are in the middle of building. - cur_ver="${RELEASE_VERSION}-SNAPSHOT" - rel_ver="${RELEASE_VERSION}" - next_ver="${NEXT_VERSION}-SNAPSHOT" - - old="^\( \{2,4\}\)${cur_ver}<\/version>$" - new="\1${rel_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - git commit -a -m "Preparing Spark release $GIT_TAG" - echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH" - git tag $GIT_TAG - - old="^\( \{2,4\}\)${rel_ver}<\/version>$" - new="\1${next_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/$old/$new/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - git commit -a -m "Preparing development version $next_ver" - git push origin $GIT_TAG - git push origin HEAD:$GIT_BRANCH - popd - rm -rf spark -fi - -if [[ ! 
"$@" =~ --skip-publish ]]; then - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git - pushd spark - git checkout --force $GIT_TAG - - # Substitute in case published version is different than released - old="^\( \{2,4\}\)${RELEASE_VERSION}<\/version>$" - new="\1${PUBLISH_VERSION}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - # Using Nexus API documented here: - # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" - - rm -rf $SPARK_REPO - - build/mvn -DskipTests -Pyarn -Phive \ - -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-scala-version.sh 2.11 - - build/mvn -DskipTests -Pyarn -Phive \ - -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-scala-version.sh 2.10 - - pushd $SPARK_REPO - - # Remove any extra files generated during install - find . -type f |grep -v \.jar |grep -v \.pom | xargs rm - - echo "Creating hash and signature files" - for file in $(find . -type f) - do - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi - shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 - done - - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done - - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" - - popd - popd - rm -rf spark -fi - -if [[ ! "$@" =~ --skip-package ]]; then - # Source and binary tarballs - echo "Packaging release tarballs" - git clone https://git-wip-us.apache.org/repos/asf/spark.git - cd spark - git checkout --force $GIT_TAG - release_hash=`git rev-parse HEAD` - - rm .gitignore - rm -rf .git - cd .. 
- - cp -r spark spark-$RELEASE_VERSION - tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.sha - rm -rf spark-$RELEASE_VERSION - - # Updated for each binary build - make_binary_release() { - NAME=$1 - FLAGS=$2 - ZINC_PORT=$3 - cp -r spark spark-$RELEASE_VERSION-bin-$NAME - - cd spark-$RELEASE_VERSION-bin-$NAME - - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-scala-version.sh 2.11 - fi - - export ZINC_PORT=$ZINC_PORT - echo "Creating distribution: $NAME ($FLAGS)" - ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ - ../binary-release-$NAME.log - cd .. - cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . - - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \ - --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.sha - } - - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds - # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & - wait - rm -rf spark-$RELEASE_VERSION-bin-*/ - - # Copy data - echo "Copying release tarballs" - rc_folder=spark-$RELEASE_VERSION-$RC_NAME - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_folder - scp spark-* \ - $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ - - # Docs - cd spark - sbt/sbt clean - cd docs - # Compile docs with Java 7 to use nicer format - JAVA_HOME="$JAVA_7_HOME" PRODUCTION=1 RELEASE_VERSION="$RELEASE_VERSION" jekyll build - echo "Copying release documentation" - rc_docs_folder=${rc_folder}-docs - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder - rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder - - echo "Release $RELEASE_VERSION completed:" - echo "Git tag:\t $GIT_TAG" - echo "Release commit:\t $release_hash" - echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" - echo "Doc location:\t 
http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" -fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh new file mode 100755 index 0000000000000..399c73e7bf6bc --- /dev/null +++ b/dev/create-release/release-build.sh @@ -0,0 +1,321 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: release-build.sh +Creates build deliverables from a Spark commit. + +Top level targets are + package: Create binary packages and copy them to people.apache + docs: Build docs and copy them to people.apache + publish-snapshot: Publish snapshot release to Apache snapshots + publish-release: Publish a release to Apache release repo + +All other inputs are environment variables + +GIT_REF - Release tag or commit to build from +SPARK_VERSION - Release identifier used when publishing +SPARK_PACKAGE_VERSION - Release identifier in top level package directory +REMOTE_PARENT_DIR - Parent in which to create doc or release builds. +REMOTE_PARENT_MAX_LENGTH - If set, parent directory will be cleaned to only + have this number of subdirectories (by deleting old ones). WARNING: This deletes data. 
+ +ASF_USERNAME - Username of ASF committer account +ASF_PASSWORD - Password of ASF committer account +ASF_RSA_KEY - RSA private key file for ASF committer account + +GPG_KEY - GPG key used to sign release artifacts +GPG_PASSPHRASE - Passphrase for GPG key +EOF + exit 1 +} + +set -e + +if [ $# -eq 0 ]; then + exit_with_usage +fi + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_RSA_KEY GPG_PASSPHRASE GPG_KEY; do + if [ -z "${!env}" ]; then + echo "ERROR: $env must be set to run this script" + exit_with_usage + fi +done + +# Commit ref to checkout when building +GIT_REF=${GIT_REF:-master} + +# Destination directory parent on remote server +REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} + +SSH="ssh -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" +GPG="gpg --no-tty --batch" +NEXUS_ROOT=https://repository.apache.org/service/local/staging +NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads +BASE_DIR=$(pwd) + +MVN="build/mvn --force" +PUBLISH_PROFILES="-Pyarn -Phive -Phadoop-2.2" +PUBLISH_PROFILES="$PUBLISH_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" + +rm -rf spark +git clone https://git-wip-us.apache.org/repos/asf/spark.git +cd spark +git checkout $GIT_REF +git_hash=`git rev-parse --short HEAD` +echo "Checked out Spark git hash $git_hash" + +if [ -z "$SPARK_VERSION" ]; then + SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ + | grep -v INFO | grep -v WARNING | grep -v Download) +fi + +if [ -z "$SPARK_PACKAGE_VERSION" ]; then + SPARK_PACKAGE_VERSION="${SPARK_VERSION}-$(date +%Y_%m_%d_%H_%M)-${git_hash}" +fi + +DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" +USER_HOST="$ASF_USERNAME@people.apache.org" + +rm .gitignore +rm -rf .git +cd .. + +if [ -n "$REMOTE_PARENT_MAX_LENGTH" ]; then + old_dirs=$($SSH $USER_HOST ls -t $REMOTE_PARENT_DIR | tail -n +$REMOTE_PARENT_MAX_LENGTH) + for old_dir in $old_dirs; do + echo "Removing directory: $old_dir" + $SSH $USER_HOST rm -r $REMOTE_PARENT_DIR/$old_dir + done +fi + +if [[ "$1" == "package" ]]; then + # Source and binary tarballs + echo "Packaging release tarballs" + cp -r spark spark-$SPARK_VERSION + tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ + --detach-sig spark-$SPARK_VERSION.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ + spark-$SPARK_VERSION.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha + rm -rf spark-$SPARK_VERSION + + # Updated for each binary build + make_binary_release() { + NAME=$1 + FLAGS=$2 + ZINC_PORT=$3 + cp -r spark spark-$SPARK_VERSION-bin-$NAME + + cd spark-$SPARK_VERSION-bin-$NAME + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-scala-version.sh 2.11 + fi + + export ZINC_PORT=$ZINC_PORT + echo "Creating distribution: $NAME ($FLAGS)" + ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ + ../binary-release-$NAME.log + cd .. + cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . 
+ + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ + --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.sha + } + + # TODO: Check exit codes of children here: + # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process + + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds + # share the same Zinc server. + make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn" "3038" & + wait + rm -rf spark-$SPARK_VERSION-bin-*/ + + # Copy data + dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin" + echo "Copying release tarballs to $dest_dir" + $SSH $USER_HOST mkdir $dest_dir + rsync -e "$SSH" spark-* $USER_HOST:$dest_dir + echo "Linking /latest to $dest_dir" + $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" + $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + exit 0 +fi + +if [[ "$1" == "docs" ]]; then + # Documentation + cd spark + echo "Building Spark docs" + dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-docs" + cd docs + # Compile docs with Java 7 to use nicer format + # TODO: Make configurable to add this: PRODUCTION=1 + PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build + echo "Copying release documentation to $dest_dir" + $SSH $USER_HOST mkdir $dest_dir + echo "Linking /latest to $dest_dir" + $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" + $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + rsync -e "$SSH" -r _site/* $USER_HOST:$dest_dir + cd .. + exit 0 +fi + +if [[ "$1" == "publish-snapshot" ]]; then + cd spark + # Publish Spark to Maven release repo + echo "Deploying Spark SNAPSHOT at '$GIT_REF' ($git_hash)" + echo "Publish version is $SPARK_VERSION" + if [[ ! 
$SPARK_VERSION == *"SNAPSHOT"* ]]; then + echo "ERROR: Snapshots must have a version containing SNAPSHOT" + echo "ERROR: You gave version '$SPARK_VERSION'" + exit 1 + fi + # Coerce the requested version + $MVN versions:set -DnewVersion=$SPARK_VERSION + tmp_settings="tmp-settings.xml" + echo "" > $tmp_settings + echo "apache.snapshots.https$ASF_USERNAME" >> $tmp_settings + echo "$ASF_PASSWORD" >> $tmp_settings + echo "" >> $tmp_settings + + # Generate random point for Zinc + export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + + $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \ + -Phive-thriftserver deploy + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \ + -DskipTests $PUBLISH_PROFILES deploy + + # Clean-up Zinc nailgun process + /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + + rm $tmp_settings + cd .. + exit 0 +fi + +if [[ "$1" == "publish-release" ]]; then + cd spark + # Publish Spark to Maven release repo + echo "Publishing Spark checkout at '$GIT_REF' ($git_hash)" + echo "Publish version is $SPARK_VERSION" + # Coerce the requested version + $MVN versions:set -DnewVersion=$SPARK_VERSION + + # Using Nexus API documented here: + # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API + echo "Creating Nexus staging repository" + repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + + tmp_repo=$(mktemp -d spark-repo-XXXXX) + + # Generate random point for Zinc + export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \ + -Phive-thriftserver clean install + + ./dev/change-scala-version.sh 2.11 + + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \ + -DskipTests $PUBLISH_PROFILES clean install + + # Clean-up Zinc nailgun process + /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + + ./dev/change-version-to-2.10.sh + + pushd $tmp_repo/org/apache/spark + + # Remove any extra files generated during install + find . -type f |grep -v \.jar |grep -v \.pom | xargs rm + + echo "Creating hash and signature files" + for file in $(find . -type f) + do + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ + --detach-sig --armour $file; + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi + sha1sum $file | cut -f1 -d' ' > $file.sha1 + done + + nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id + echo "Uplading files to $nexus_upload" + for file in $(find . 
-type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$nexus_upload/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + popd + rm -rf $tmp_repo + cd .. + exit 0 +fi + +cd .. +rm -rf spark +echo "ERROR: expects to be called with 'package', 'docs', 'publish-release' or 'publish-snapshot'" diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh new file mode 100755 index 0000000000000..b0a3374becc6a --- /dev/null +++ b/dev/create-release/release-tag.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: tag-release.sh +Tags a Spark release on a particular branch. + +Inputs are specified with the following environment variables: +ASF_USERNAME - Apache Username +ASF_PASSWORD - Apache Password +GIT_NAME - Name to use with git +GIT_EMAIL - E-mail address to use with git +GIT_BRANCH - Git branch on which to make release +RELEASE_VERSION - Version used in pom files for release +RELEASE_TAG - Name of release tag +NEXT_VERSION - Development version after release +EOF + exit 1 +} + +set -e + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GIT_EMAIL GIT_NAME GIT_BRANCH; do + if [ -z "${!env}" ]; then + echo "$env must be set to run this script" + exit 1 + fi +done + +ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" +MVN="build/mvn --force" + +rm -rf spark +git clone https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO -b $GIT_BRANCH +cd spark + +git config user.name "$GIT_NAME" +git config user.email $GIT_EMAIL + +# Create release version +$MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing Spark release $RELEASE_TAG" +echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" +git tag $RELEASE_TAG + +# TODO: It would be nice to do some verifications here +# i.e. check whether ec2 scripts have the new version + +# Create next version +$MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing development version $NEXT_VERSION" + +# Push changes +git push origin $RELEASE_TAG +git push origin HEAD:$GIT_BRANCH + +cd .. 
+rm -rf spark From 74a293f4537c6982345166f8883538f81d850872 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 11 Aug 2015 21:26:03 -0700 Subject: [PATCH 280/340] [SPARK-9713] [ML] Document SparkR MLlib glm() integration in Spark 1.5 This documents the use of R model formulae in the SparkR guide. Also fixes some bugs in the R api doc. mengxr Author: Eric Liang Closes #8085 from ericl/docs. --- R/pkg/R/generics.R | 4 ++-- R/pkg/R/mllib.R | 8 ++++---- docs/sparkr.md | 37 ++++++++++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index c43b947129e87..379a78b1d833e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -535,8 +535,8 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) #' @export setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) -##' rdname summary -##' @export +#' @rdname summary +#' @export setGeneric("summary", function(x, ...) { standardGeneric("summary") }) # @rdname tojson diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b524d1fd87496..cea3d760d05fe 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -56,10 +56,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram #' #' Makes predictions from a model produced by glm(), similarly to R's predict(). #' -#' @param model A fitted MLlib model +#' @param object A fitted MLlib model #' @param newData DataFrame for testing #' @return DataFrame containing predicted values -#' @rdname glm +#' @rdname predict #' @export #' @examples #'\dontrun{ @@ -76,10 +76,10 @@ setMethod("predict", signature(object = "PipelineModel"), #' #' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param model A fitted MLlib model +#' @param x A fitted MLlib model #' @return a list with a 'coefficient' component, which is the matrix of coefficients. See #' summary.glm for more information. -#' @rdname glm +#' @rdname summary #' @export #' @examples #'\dontrun{ diff --git a/docs/sparkr.md b/docs/sparkr.md index 4385a4eeacd5c..7139d16b4a068 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -11,7 +11,8 @@ title: SparkR (R on Spark) SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that supports operations like selection, filtering, aggregation etc. (similar to R data frames, -[dplyr](https://github.com/hadley/dplyr)) but on large datasets. +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. SparkR also supports distributed +machine learning using MLlib. # SparkR DataFrames @@ -230,3 +231,37 @@ head(teenagers) {% endhighlight %}

    + +# Machine Learning + +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) + +# Fit a linear model over the dataset. +model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) 2.2513930 +##Sepal_Width 0.8035609 +##Species_versicolor 1.4587432 +##Species_virginica 1.9468169 + +# Make predictions based on the model. +predictions <- predict(model, newData = df) +head(select(predictions, "Sepal_Length", "prediction")) +## Sepal_Length prediction +##1 5.1 5.063856 +##2 4.9 4.662076 +##3 4.7 4.822788 +##4 4.6 4.742432 +##5 5.0 5.144212 +##6 5.4 5.385281 +{% endhighlight %} +
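For comparison with the R code above, the same formula-driven fit can be expressed against the ML Pipeline API on the Scala side. The sketch below assumes the formula handling is backed by ml.feature.RFormula and that the input DataFrame carries the iris columns as named in the guide:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.RFormula
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.sql.DataFrame

    // `irisDF` stands for a DataFrame with columns named as in the guide above.
    def fitIrisModel(irisDF: DataFrame): DataFrame = {
      // RFormula expands the R-style formula into "features" and "label" columns,
      // which LinearRegression consumes by default.
      val formula = new RFormula().setFormula("Sepal_Length ~ Sepal_Width + Species")
      val pipeline = new Pipeline().setStages(Array(formula, new LinearRegression()))
      val model = pipeline.fit(irisDF)
      model.transform(irisDF).select("Sepal_Length", "prediction")
    }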
    From c3e9a120e33159fb45cd99f3a55fc5cf16cd7c6c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 11 Aug 2015 22:45:18 -0700 Subject: [PATCH 281/340] [SPARK-9831] [SQL] fix serialization with empty broadcast Author: Davies Liu Closes #8117 from davies/fix_serialization and squashes the following commits: d21ac71 [Davies Liu] fix serialization with empty broadcast --- .../sql/execution/joins/HashedRelation.scala | 2 +- .../execution/joins/HashedRelationSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index c1bc7947aa39c..076afe6e4e960 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -299,7 +299,7 @@ private[joins] final class UnsafeHashedRelation( binaryMap = new BytesToBytesMap( taskMemoryManager, shuffleMemoryManager, - nKeys * 2, // reduce hash collision + (nKeys * 1.5 + 1).toInt, // reduce hash collision pageSizeBytes) var i = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index a1fa2c3864bdb..c635b2d51f464 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -103,4 +103,21 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed2.get(unsafeData(2)) === data2) assert(numDataRows.value.value === data.length) } + + test("test serialization empty hash map") { + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + val hashed = new UnsafeHashedRelation( + new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + hashed.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) + + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val toUnsafe = UnsafeProjection.create(schema) + val row = toUnsafe(InternalRow(0)) + assert(hashed2.get(row) === null) + } } From b1581ac28840a4d2209ef8bb5c9f8700b4c1b286 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 11 Aug 2015 22:46:59 -0700 Subject: [PATCH 282/340] [SPARK-9854] [SQL] RuleExecutor.timeMap should be thread-safe `RuleExecutor.timeMap` is currently a non-thread-safe mutable HashMap; this can lead to infinite loops if multiple threads are concurrently modifying the map. I believe that this is responsible for some hangs that I've observed in HiveQuerySuite. This patch addresses this by using a Guava `AtomicLongMap`. Author: Josh Rosen Closes #8120 from JoshRosen/rule-executor-time-map-fix. 
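The core of the change, as a standalone sketch (the rule name and timing are made-up values for illustration):

    import com.google.common.util.concurrent.AtomicLongMap

    val timings = AtomicLongMap.create[String]()

    // Safe to call from multiple threads: each call atomically adds to the
    // per-rule counter, unlike a read-modify-write on a shared mutable HashMap,
    // which can corrupt the map under concurrent updates.
    def recordRunTime(ruleName: String, nanos: Long): Unit = {
      timings.addAndGet(ruleName, nanos)
    }

    recordRunTime("SomeRule", 1234L)   // illustrative only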
--- .../spark/sql/catalyst/rules/RuleExecutor.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 8b824511a79da..f80d2a93241d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,22 +17,25 @@ package org.apache.spark.sql.catalyst.rules +import scala.collection.JavaConverters._ + +import com.google.common.util.concurrent.AtomicLongMap + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide -import scala.collection.mutable - object RuleExecutor { - protected val timeMap = new mutable.HashMap[String, Long].withDefault(_ => 0) + protected val timeMap = AtomicLongMap.create[String]() /** Resets statistics about time spent running specific rules */ def resetTime(): Unit = timeMap.clear() /** Dump statistics about time spent running specific rules. */ def dumpTimeSpent(): String = { - val maxSize = timeMap.keys.map(_.toString.length).max - timeMap.toSeq.sortBy(_._2).reverseMap { case (k, v) => + val map = timeMap.asMap().asScala + val maxSize = map.keys.map(_.toString.length).max + map.toSeq.sortBy(_._2).reverseMap { case (k, v) => s"${k.padTo(maxSize, " ").mkString} $v" }.mkString("\n") } @@ -79,7 +82,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime - RuleExecutor.timeMap(rule.ruleName) = RuleExecutor.timeMap(rule.ruleName) + runTime + RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) if (!result.fastEquals(plan)) { logTrace( From b85f9a242a12e8096e331fa77d5ebd16e93c844d Mon Sep 17 00:00:00 2001 From: xutingjun Date: Tue, 11 Aug 2015 23:19:35 -0700 Subject: [PATCH 283/340] [SPARK-8366] maxNumExecutorsNeeded should properly handle failed tasks Author: xutingjun Author: meiyoula <1039320815@qq.com> Closes #6817 from XuTingjun/SPARK-8366. 
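The sizing logic this protects can be paraphrased as follows (a simplified sketch, not the literal method in ExecutorAllocationManager): executors are requested for running plus pending tasks, so a failed task that will be resubmitted has to count as pending again or the target undershoots.

    // Simplified paraphrase of the allocation manager's computation.
    def maxNumExecutorsNeeded(pendingTasks: Int, runningTasks: Int, tasksPerExecutor: Int): Int =
      (pendingTasks + runningTasks + tasksPerExecutor - 1) / tasksPerExecutor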
--- .../spark/ExecutorAllocationManager.scala | 22 ++++++++++++------- .../ExecutorAllocationManagerSuite.scala | 22 +++++++++++++++++-- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 1877aaf2cac55..b93536e6536e2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -599,14 +599,8 @@ private[spark] class ExecutorAllocationManager( // If this is the last pending task, mark the scheduler queue as empty stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex - val numTasksScheduled = stageIdToTaskIndices(stageId).size - val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) - if (numTasksScheduled == numTasksTotal) { - // No more pending tasks for this stage - stageIdToNumTasks -= stageId - if (stageIdToNumTasks.isEmpty) { - allocationManager.onSchedulerQueueEmpty() - } + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerQueueEmpty() } // Mark the executor on which this task is scheduled as busy @@ -618,6 +612,8 @@ private[spark] class ExecutorAllocationManager( override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId + val taskIndex = taskEnd.taskInfo.index + val stageId = taskEnd.stageId allocationManager.synchronized { numRunningTasks -= 1 // If the executor is no longer running any scheduled tasks, mark it as idle @@ -628,6 +624,16 @@ private[spark] class ExecutorAllocationManager( allocationManager.onExecutorIdle(executorId) } } + + // If the task failed, we expect it to be resubmitted later. 
To ensure we have + // enough resources to run the resubmitted task, we need to mark the scheduler + // as backlogged again if it's not already marked as such (SPARK-8366) + if (taskEnd.reason != Success) { + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerBacklogged() + } + stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) } + } } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 34caca892891c..f374f97f87448 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -206,8 +206,8 @@ class ExecutorAllocationManagerSuite val task2Info = createTaskInfo(1, 0, "executor-1") sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task1Info, null)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task2Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) assert(adjustRequestedExecutors(manager) === -1) } @@ -787,6 +787,24 @@ class ExecutorAllocationManagerSuite Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) } + test("SPARK-8366: maxNumExecutorsNeeded should properly handle failed tasks") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(maxNumExecutorsNeeded(manager) === 0) + + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) + assert(maxNumExecutorsNeeded(manager) === 1) + + val taskInfo = createTaskInfo(1, 1, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + assert(maxNumExecutorsNeeded(manager) === 1) + + // If the task is failed, we expect it to be resubmitted later. + val taskEndReason = ExceptionFailure(null, null, null, null, null) + sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, From a807fcbe50b2ce18751d80d39e9d21842f7da32a Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Tue, 11 Aug 2015 23:20:39 -0700 Subject: [PATCH 284/340] [SPARK-9806] [WEB UI] Don't share ReplayListenerBus between multiple applications Author: Rohit Agarwal Closes #8088 from mindprince/SPARK-9806. 
--- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index e3060ac3fa1a9..53c18ca3ff50c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -272,9 +272,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Replay the log files in the list and merge the list of old applications with new ones */ private def mergeApplicationListing(logs: Seq[FileStatus]): Unit = { - val bus = new ReplayListenerBus() val newAttempts = logs.flatMap { fileStatus => try { + val bus = new ReplayListenerBus() val res = replay(fileStatus, bus) res match { case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") From 4e3f4b934f74e8c7c06f4940d6381343f9fd4918 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 11 Aug 2015 23:23:17 -0700 Subject: [PATCH 285/340] [SPARK-9829] [WEBUI] Display the update value for peak execution memory The peak execution memory is not correct because it shows the sum of finished tasks' values when a task finishes. This PR fixes it by using the update value rather than the accumulator value. Author: zsxwing Closes #8121 from zsxwing/SPARK-9829. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0c94204df6530..fb4556b836859 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -860,7 +860,7 @@ private[ui] class TaskDataSource( } val peakExecutionMemoryUsed = taskInternalAccumulables .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.value.toLong } + .map { acc => acc.update.getOrElse("0").toLong } .getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) From bab89232854de7554e88f29cab76f1a1c349edc1 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Tue, 11 Aug 2015 23:25:02 -0700 Subject: [PATCH 286/340] [SPARK-9426] [WEBUI] Job page DAG visualization is not shown To reproduce the issue, go to the stage page and click DAG Visualization once, then go to the job page to show the job DAG visualization. You will only see the first stage of the job. Root cause: the JavaScript uses local storage to remember your selection. Once you click the stage DAG visualization, local storage sets `expand-dag-viz-arrow-stage` to true. When you go to the job page, the JS checks `expand-dag-viz-arrow-stage` in local storage first and tries to show the stage DAG visualization on the job page. To fix this, I set an id on the DAG span to distinguish the job page from the stage page. In the JS code, we check the id and local storage together to make sure we show the correct DAG visualization. Author: Carson Wang Closes #8104 from carsonwang/SPARK-9426.
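To make the distinction behind SPARK-9829 above concrete: in a task-end event's AccumulableInfo, update carries only that finished task's contribution, while value is the running total across tasks, so summing value over tasks double-counts. A hedged sketch (the helper object is invented for illustration and is not Spark code):

    import org.apache.spark.scheduler.AccumulableInfo

    // Illustrative helper: read a single task's peak-execution-memory contribution from its
    // accumulable info using the per-task 'update' field rather than the accumulated 'value'.
    object PeakMemory {
      def forTask(accumulables: Seq[AccumulableInfo], name: String): Long =
        accumulables
          .find(_.name == name)                  // e.g. the peak execution memory accumulator
          .map(_.update.getOrElse("0").toLong)   // per-task delta, not the running total
          .getOrElse(0L)
    }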
--- .../resources/org/apache/spark/ui/static/spark-dag-viz.js | 8 ++++---- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 4a893bc0189aa..83dbea40b63f3 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -109,13 +109,13 @@ function toggleDagViz(forJob) { } $(function (){ - if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { + if ($("#stage-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { // Set it to false so that the click function can revert it window.localStorage.setItem(expandDagVizArrowKey(false), "false"); toggleDagViz(false); - } - - if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { + } else if ($("#job-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { // Set it to false so that the click function can revert it window.localStorage.setItem(expandDagVizArrowKey(true), "false"); toggleDagViz(true); diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 718aea7e1dc22..f2da417724104 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -352,7 +352,8 @@ private[spark] object UIUtils extends Logging { */ private def showDagViz(graphs: Seq[RDDOperationGraph], forJob: Boolean): Seq[Node] = {
    - + From 5c99d8bf98cbf7f568345d02a814fc318cbfca75 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 11 Aug 2015 23:26:33 -0700 Subject: [PATCH 287/340] [SPARK-8798] [MESOS] Allow additional uris to be fetched with mesos Some users like to download additional files in their sandbox that they can refer to from their spark program, or even later mount these files to another directory. Author: Timothy Chen Closes #7195 from tnachen/mesos_files. --- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 5 +++++ .../scheduler/cluster/mesos/MesosClusterScheduler.scala | 3 +++ .../scheduler/cluster/mesos/MesosSchedulerBackend.scala | 5 +++++ .../scheduler/cluster/mesos/MesosSchedulerUtils.scala | 6 ++++++ docs/running-on-mesos.md | 8 ++++++++ 5 files changed, 27 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 15a0915708c7c..d6e1e9e5bebc2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -194,6 +194,11 @@ private[spark] class CoarseMesosSchedulerBackend( s" --app-id $appId") command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } + + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) + } + command.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index f078547e71352..64ec2b8e3db15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -403,6 +403,9 @@ private[spark] class MesosClusterScheduler( } builder.setValue(s"$executable $cmdOptions $jar $appArguments") builder.setEnvironment(envBuilder.build()) + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } builder.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 3f63ec1c5832f..5c20606d58715 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -133,6 +133,11 @@ private[spark] class MesosSchedulerBackend( builder.addAllResources(usedCpuResources) builder.addAllResources(usedMemResources) + + sc.conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) + } + val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index c04920e4f5873..5b854aa5c2754 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -331,4 +331,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging { sc.executorMemory } + def setupUris(uris: String, builder: CommandInfo.Builder): Unit 
= { + uris.split(",").foreach { uri => + builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim())) + } + } + } diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index debdd2adf22d6..55e6d4e83a725 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -306,6 +306,14 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.uris + (none) + + A list of URIs to be downloaded to the sandbox when driver or executor is launched by Mesos. + This applies to both coarse-grain and fine-grain mode. + + spark.mesos.principal Framework principal to authenticate to Mesos From 741a29f98945538a475579ccc974cd42c1613be4 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 11 Aug 2015 23:33:22 -0700 Subject: [PATCH 288/340] [SPARK-9575] [MESOS] Add documentation around Mesos shuffle service. andrewor14 Author: Timothy Chen Closes #7907 from tnachen/mesos_shuffle. --- docs/running-on-mesos.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 55e6d4e83a725..cfd219ab02e26 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -216,6 +216,20 @@ node. Please refer to [Hadoop on Mesos](https://github.com/mesos/hadoop). In either case, HDFS runs separately from Hadoop MapReduce, without being scheduled through Mesos. +# Dynamic Resource Allocation with Mesos + +Mesos supports dynamic allocation only with coarse grain mode, which can resize the number of executors based on statistics +of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse grain scheduler only supports scaling down +since it is already designed to run one executor per slave with the configured amount of resources. However, after scaling down the number of executors the coarse grain scheduler +can scale back up to the same amount of executors when Spark signals more executors are needed. + +Users that like to utilize this feature should launch the Mesos Shuffle Service that +provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's +termination. To launch/stop the Mesos Shuffle Service please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh +scripts accordingly. + +The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos +is to launch the Shuffle Service with Marathon with a unique host constraint. # Configuration From 9d0822455ddc8d765440d58c463367a4d67ef456 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 12 Aug 2015 19:54:00 +0800 Subject: [PATCH 289/340] [SPARK-9182] [SQL] Filters are not passed through to jdbc source This PR fixes the inability to push filters down to a JDBC source when a `Cast` gets in the way of pattern matching. When columns of different types are compared, a cast is often added around the column, so the predicate no longer matches the pattern directly on the Attribute and fails to be pushed down. Author: Yijie Shen Closes #8049 from yjshen/jdbc_pushdown.
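The affected queries are ones where the filter literal's type differs from the column's type, which makes the analyzer wrap the column in a Cast. A hedged, self-contained sketch of the symptom (the JDBC URL, table and column names below are hypothetical, loosely modeled on the test tables in this patch):

    import java.util.Properties
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    // Comparing the INT column A with the string literal "15" inserts a Cast around A, so
    // before this patch the predicate missed the Attribute-only patterns and was evaluated
    // in Spark instead of being compiled into the JDBC WHERE clause.
    object JdbcPushdownExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("jdbc-pushdown").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)

        val props = new Properties()
        props.setProperty("user", "testUser")
        props.setProperty("password", "testPass")

        val df = sqlContext.read.jdbc("jdbc:h2:mem:testdb", "TEST.INTTYPES", props)
        df.filter(df("A") > "15").show()   // with this patch, pushed down as WHERE A > 15
        sc.stop()
      }
    }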
--- .../datasources/DataSourceStrategy.scala | 30 ++++++++++++++-- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 34 +++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2a4c40db8bb66..9eea2b0382535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql.catalyst.{InternalRow, expressions} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{TimestampType, DateType, StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -343,11 +343,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { * and convert them. */ protected[sql] def selectFilters(filters: Seq[Expression]) = { + import CatalystTypeConverters._ + def translate(predicate: Expression): Option[Filter] = predicate match { case expressions.EqualTo(a: Attribute, Literal(v, _)) => Some(sources.EqualTo(a.name, v)) case expressions.EqualTo(Literal(v, _), a: Attribute) => Some(sources.EqualTo(a.name, v)) + case expressions.EqualTo(Cast(a: Attribute, _), l: Literal) => + Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.EqualTo(l: Literal, Cast(a: Attribute, _)) => + Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => Some(sources.EqualNullSafe(a.name, v)) @@ -358,21 +364,41 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { Some(sources.GreaterThan(a.name, v)) case expressions.GreaterThan(Literal(v, _), a: Attribute) => Some(sources.LessThan(a.name, v)) + case expressions.GreaterThan(Cast(a: Attribute, _), l: Literal) => + Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.GreaterThan(l: Literal, Cast(a: Attribute, _)) => + Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThan(a: Attribute, Literal(v, _)) => Some(sources.LessThan(a.name, v)) case expressions.LessThan(Literal(v, _), a: Attribute) => Some(sources.GreaterThan(a.name, v)) + case expressions.LessThan(Cast(a: Attribute, _), l: Literal) => + Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.LessThan(l: Literal, Cast(a: Attribute, _)) => + Some(sources.GreaterThan(a.name, convertToScala(Cast(l, 
a.dataType).eval(), a.dataType))) case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.GreaterThanOrEqual(a.name, v)) case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.LessThanOrEqual(a.name, v)) + case expressions.GreaterThanOrEqual(Cast(a: Attribute, _), l: Literal) => + Some(sources.GreaterThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.GreaterThanOrEqual(l: Literal, Cast(a: Attribute, _)) => + Some(sources.LessThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.LessThanOrEqual(a.name, v)) case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.GreaterThanOrEqual(a.name, v)) + case expressions.LessThanOrEqual(Cast(a: Attribute, _), l: Literal) => + Some(sources.LessThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.LessThanOrEqual(l: Literal, Cast(a: Attribute, _)) => + Some(sources.GreaterThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.InSet(a: Attribute, set) => Some(sources.In(a.name, set.toArray)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 8eab6a0adccc4..281943e23fcff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -284,7 +284,7 @@ private[sql] class JDBCRDD( /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - private val filterWhereClause: String = { + val filterWhereClause: String = { val filterStrings = filters map compileFilter filter (_ != null) if (filterStrings.size > 0) { val sb = new StringBuilder("WHERE ") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 42f2449afb0f9..b9cfae51e809c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,6 +25,8 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -148,6 +150,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("create table test.decimals (a DECIMAL(7, 2), b DECIMAL(4, 0))"). 
+ executeUpdate() + conn.prepareStatement("insert into test.decimals values (12345.67, 1234)").executeUpdate() + conn.prepareStatement("insert into test.decimals values (34567.89, 1428)").executeUpdate() + conn.commit() + sql( + s""" + |CREATE TEMPORARY TABLE decimals + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.DECIMALS', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement( s""" |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), @@ -445,4 +459,24 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } + test("SPARK-9182: filters are not passed through to jdbc source") { + def checkPushedFilter(query: String, filterStr: String): Unit = { + val rddOpt = sql(query).queryExecution.executedPlan.collectFirst { + case PhysicalRDD(_, rdd: JDBCRDD, _) => rdd + } + assert(rddOpt.isDefined) + val pushedFilterStr = rddOpt.get.filterWhereClause + assert(pushedFilterStr.contains(filterStr), + s"Expected to push [$filterStr], actually we pushed [$pushedFilterStr]") + } + + checkPushedFilter("select * from foobar where NAME = 'fred'", "NAME = 'fred'") + checkPushedFilter("select * from inttypes where A > '15'", "A > 15") + checkPushedFilter("select * from inttypes where C <= 20", "C <= 20") + checkPushedFilter("select * from decimals where A > 1000", "A > 1000.00") + checkPushedFilter("select * from decimals where A > 1000 AND A < 2000", + "A > 1000.00 AND A < 2000.00") + checkPushedFilter("select * from decimals where A = 2000 AND B > 20", "A = 2000.00 AND B > 20") + checkPushedFilter("select * from timetypes where B > '1998-09-10'", "B > 1998-09-10") + } } From 3ecb3794302dc12d0989f8d725483b2cc37762cf Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 12 Aug 2015 20:01:34 +0800 Subject: [PATCH 290/340] [SPARK-9407] [SQL] Relaxes Parquet ValidTypeMap to allow ENUM predicates to be pushed down This PR adds a hacky workaround for PARQUET-201, and should be removed once we upgrade to parquet-mr 1.8.1 or higher versions. In Parquet, not all types of columns can be used for filter push-down optimization. The set of valid column types is controlled by `ValidTypeMap`. Unfortunately, in parquet-mr 1.7.0 and prior versions, this limitation is too strict, and doesn't allow `BINARY (ENUM)` columns to be pushed down. On the other hand, `BINARY (ENUM)` is commonly seen in Parquet files written by libraries like `parquet-avro`. This restriction is problematic for Spark SQL, because Spark SQL doesn't have a type that maps to Parquet `BINARY (ENUM)` directly, and always converts `BINARY (ENUM)` to Catalyst `StringType`. Thus, a predicate involving a `BINARY (ENUM)` is recognized as one involving a string field instead and can be pushed down by the query optimizer. Such predicates are actually perfectly legal except that it fails the `ValidTypeMap` check. The workaround added here is relaxing `ValidTypeMap` to include `BINARY (ENUM)`. I also took the chance to simplify `ParquetCompatibilityTest` a little bit when adding regression test. Author: Cheng Lian Closes #8107 from liancheng/spark-9407/parquet-enum-filter-push-down. 
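Concretely, the behaviour this unlocks looks like the following hedged sketch, which mirrors the regression test added later in this patch (the input path is hypothetical and assumed to contain a parquet-avro file with an enum column named suit):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    // parquet-avro stores the Avro enum as BINARY (ENUM), but Spark SQL reads the column as a
    // string, so a string equality filter on it can now go through filter push-down instead of
    // being rejected by the ValidTypeMap check.
    object EnumPushdownExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("enum-pushdown").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val df = sqlContext.read.parquet("/tmp/parquet-enum-sample")  // assumed parquet-avro output
        df.filter('suit === "SPADES").show()                          // ENUM column filtered as a string
        sc.stop()
      }
    }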
--- .../datasources/parquet/ParquetFilters.scala | 38 ++++- .../datasources/parquet/ParquetRelation.scala | 2 +- sql/core/src/test/README.md | 16 +- sql/core/src/test/avro/parquet-compat.avdl | 13 +- sql/core/src/test/avro/parquet-compat.avpr | 13 +- .../parquet/test/avro/CompatibilityTest.java | 2 +- .../datasources/parquet/test/avro/Nested.java | 4 +- .../parquet/test/avro/ParquetAvroCompat.java | 4 +- .../parquet/test/avro/ParquetEnum.java | 142 ++++++++++++++++++ .../datasources/parquet/test/avro/Suit.java | 13 ++ .../ParquetAvroCompatibilitySuite.scala | 105 +++++++------ .../parquet/ParquetCompatibilityTest.scala | 33 +--- .../test/scripts/{gen-code.sh => gen-avro.sh} | 13 +- sql/core/src/test/scripts/gen-thrift.sh | 27 ++++ .../src/test/thrift/parquet-compat.thrift | 2 +- .../hive/ParquetHiveCompatibilitySuite.scala | 83 +++++----- 16 files changed, 374 insertions(+), 136 deletions(-) create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java rename sql/core/src/test/scripts/{gen-code.sh => gen-avro.sh} (76%) create mode 100755 sql/core/src/test/scripts/gen-thrift.sh diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 9e2e232f50167..63915e0a28655 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -25,9 +25,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.compat.FilterCompat._ import org.apache.parquet.filter2.predicate.FilterApi._ -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Statistics} -import org.apache.parquet.filter2.predicate.UserDefinedPredicate +import org.apache.parquet.filter2.predicate._ import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.OriginalType +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ @@ -197,6 +198,8 @@ private[sql] object ParquetFilters { def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + relaxParquetValidTypeMap + // NOTE: // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, @@ -239,6 +242,37 @@ private[sql] object ParquetFilters { } } + // !! HACK ALERT !! + // + // This lazy val is a workaround for PARQUET-201, and should be removed once we upgrade to + // parquet-mr 1.8.1 or higher versions. + // + // In Parquet, not all types of columns can be used for filter push-down optimization. The set + // of valid column types is controlled by `ValidTypeMap`. Unfortunately, in parquet-mr 1.7.0 and + // prior versions, the limitation is too strict, and doesn't allow `BINARY (ENUM)` columns to be + // pushed down. + // + // This restriction is problematic for Spark SQL, because Spark SQL doesn't have a type that maps + // to Parquet original type `ENUM` directly, and always converts `ENUM` to `StringType`. 
Thus, + // a predicate involving a `ENUM` field can be pushed-down as a string column, which is perfectly + // legal except that it fails the `ValidTypeMap` check. + // + // Here we add `BINARY (ENUM)` into `ValidTypeMap` lazily via reflection to workaround this issue. + private lazy val relaxParquetValidTypeMap: Unit = { + val constructor = Class + .forName(classOf[ValidTypeMap].getCanonicalName + "$FullTypeDescriptor") + .getDeclaredConstructor(classOf[PrimitiveTypeName], classOf[OriginalType]) + + constructor.setAccessible(true) + val enumTypeDescriptor = constructor + .newInstance(PrimitiveTypeName.BINARY, OriginalType.ENUM) + .asInstanceOf[AnyRef] + + val addMethod = classOf[ValidTypeMap].getDeclaredMethods.find(_.getName == "add").get + addMethod.setAccessible(true) + addMethod.invoke(null, classOf[Binary], enumTypeDescriptor) + } + /** * Converts Catalyst predicate expressions to Parquet filter predicates. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index c71c69b6e80b1..52fac18ba187a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -678,7 +678,7 @@ private[sql] object ParquetRelation extends Logging { val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration) - // HACK ALERT: + // !! HACK ALERT !! // // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable` diff --git a/sql/core/src/test/README.md b/sql/core/src/test/README.md index 3dd9861b4896d..421c2ea4f7aed 100644 --- a/sql/core/src/test/README.md +++ b/sql/core/src/test/README.md @@ -6,23 +6,19 @@ The following directories and files are used for Parquet compatibility tests: . ├── README.md # This file ├── avro -│   ├── parquet-compat.avdl # Testing Avro IDL -│   └── parquet-compat.avpr # !! NO TOUCH !! Protocol file generated from parquet-compat.avdl +│   ├── *.avdl # Testing Avro IDL(s) +│   └── *.avpr # !! NO TOUCH !! Protocol files generated from Avro IDL(s) ├── gen-java # !! NO TOUCH !! Generated Java code ├── scripts -│   └── gen-code.sh # Script used to generate Java code for Thrift and Avro +│   ├── gen-avro.sh # Script used to generate Java code for Avro +│   └── gen-thrift.sh # Script used to generate Java code for Thrift └── thrift - └── parquet-compat.thrift # Testing Thrift schema + └── *.thrift # Testing Thrift schema(s) ``` -Generated Java code are used in the following test suites: - -- `org.apache.spark.sql.parquet.ParquetAvroCompatibilitySuite` -- `org.apache.spark.sql.parquet.ParquetThriftCompatibilitySuite` - To avoid code generation during build time, Java code generated from testing Thrift schema and Avro IDL are also checked in. -When updating the testing Thrift schema and Avro IDL, please run `gen-code.sh` to update all the generated Java code. +When updating the testing Thrift schema and Avro IDL, please run `gen-avro.sh` and `gen-thrift.sh` accordingly to update generated Java code. 
## Prerequisites diff --git a/sql/core/src/test/avro/parquet-compat.avdl b/sql/core/src/test/avro/parquet-compat.avdl index 24729f6143e6c..8070d0a9170a3 100644 --- a/sql/core/src/test/avro/parquet-compat.avdl +++ b/sql/core/src/test/avro/parquet-compat.avdl @@ -16,8 +16,19 @@ */ // This is a test protocol for testing parquet-avro compatibility. -@namespace("org.apache.spark.sql.parquet.test.avro") +@namespace("org.apache.spark.sql.execution.datasources.parquet.test.avro") protocol CompatibilityTest { + enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS + } + + record ParquetEnum { + Suit suit; + } + record Nested { array nested_ints_column; string nested_string_column; diff --git a/sql/core/src/test/avro/parquet-compat.avpr b/sql/core/src/test/avro/parquet-compat.avpr index a83b7c990dd2e..060391765034b 100644 --- a/sql/core/src/test/avro/parquet-compat.avpr +++ b/sql/core/src/test/avro/parquet-compat.avpr @@ -1,7 +1,18 @@ { "protocol" : "CompatibilityTest", - "namespace" : "org.apache.spark.sql.parquet.test.avro", + "namespace" : "org.apache.spark.sql.execution.datasources.parquet.test.avro", "types" : [ { + "type" : "enum", + "name" : "Suit", + "symbols" : [ "SPADES", "HEARTS", "DIAMONDS", "CLUBS" ] + }, { + "type" : "record", + "name" : "ParquetEnum", + "fields" : [ { + "name" : "suit", + "type" : "Suit" + } ] + }, { "type" : "record", "name" : "Nested", "fields" : [ { diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java index 70dec1a9d3c92..2368323cb36b9 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -8,7 +8,7 @@ @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public interface CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = 
org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"types\":[{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); + public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"types\":[{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},{\"type\":\"record\",\"name\":\"ParquetEnum\",\"fields\":[{\"name\":\"suit\",\"type\":\"Suit\"}]},{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":
{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); @SuppressWarnings("all") public interface Callback extends CompatibilityTest { diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java index a0a406bcd10c1..a7bf4841919c5 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -3,11 +3,11 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.execution.datasources.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } @Deprecated public java.util.List nested_ints_column; @Deprecated public java.lang.String nested_string_column; diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java index 6198b00b1e3ca..681cacbd12c7c 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -3,11 +3,11 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.execution.datasources.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static final org.apache.avro.Schema SCHEMA$ = new 
org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } @Deprecated public boolean bool_column; @Deprecated public int int_column; diff --git 
a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java new file mode 100644 index 0000000000000..05fefe4cee754 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetEnum extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetEnum\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"suit\",\"type\":{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetEnum() {} + + /** + * All-args constructor. + */ + public ParquetEnum(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit) { + this.suit = suit; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return suit; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: suit = (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'suit' field. + */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** + * Sets the value of the 'suit' field. + * @param value the value to set. 
+ */ + public void setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + this.suit = value; + } + + /** Creates a new ParquetEnum RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing ParquetEnum instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** + * RecordBuilder for ParquetEnum instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + super(other); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing ParquetEnum instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** Sets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + validate(fields()[0], value); + this.suit = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'suit' field has been set */ + public boolean hasSuit() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder clearSuit() { + suit = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public ParquetEnum build() { + try { + ParquetEnum record = new ParquetEnum(); + record.suit = fieldSetFlags()[0] ? 
this.suit : (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java new file mode 100644 index 0000000000000..00711a0c2a267 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java @@ -0,0 +1,13 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public enum Suit { + SPADES, HEARTS, DIAMONDS, CLUBS ; + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"Suit\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 4d9c07bb7a570..866a975ad5404 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -22,10 +22,12 @@ import java.util.{List => JList, Map => JMap} import scala.collection.JavaConversions._ +import org.apache.avro.Schema +import org.apache.avro.generic.IndexedRecord import org.apache.hadoop.fs.Path import org.apache.parquet.avro.AvroParquetWriter -import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat} +import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{Row, SQLContext} @@ -34,52 +36,55 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { override val sqlContext: SQLContext = TestSQLContext - override protected def beforeAll(): Unit = { - super.beforeAll() - - val writer = - new AvroParquetWriter[ParquetAvroCompat]( - new Path(parquetStore.getCanonicalPath), - ParquetAvroCompat.getClassSchema) - - (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) - writer.close() + private def withWriter[T <: IndexedRecord] + (path: String, schema: Schema) + (f: AvroParquetWriter[T] => Unit) = { + val writer = new AvroParquetWriter[T](new Path(path), schema) + try f(writer) finally writer.close() } test("Read Parquet file generated by parquet-avro") { - logInfo( - s"""Schema of the Parquet file written by parquet-avro: - |${readParquetSchema(parquetStore.getCanonicalPath)} + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetAvroCompat](path, ParquetAvroCompat.getClassSchema) { writer => + (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + } + + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |${readParquetSchema(path)} """.stripMargin) - 
checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) - - Row( - i % 2 == 0, - i, - i.toLong * 10, - i.toFloat + 0.1f, - i.toDouble + 0.2d, - s"val_$i".getBytes, - s"val_$i", - - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i: Integer), - nullable(i.toLong: java.lang.Long), - nullable(i.toFloat + 0.1f: java.lang.Float), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(s"val_$i".getBytes), - nullable(s"val_$i"), - - Seq.tabulate(3)(n => s"arr_${i + n}"), - Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, - Seq.tabulate(3) { n => - (i + n).toString -> Seq.tabulate(3) { m => - Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") - } - }.toMap) - }) + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes, + s"val_$i", + + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i: Integer), + nullable(i.toLong: java.lang.Long), + nullable(i.toFloat + 0.1f: java.lang.Float), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(s"val_$i".getBytes), + nullable(s"val_$i"), + + Seq.tabulate(3)(n => s"arr_${i + n}"), + Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } } def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { @@ -122,4 +127,20 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { .build() } + + test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") { + import sqlContext.implicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetEnum](path, ParquetEnum.getClassSchema) { writer => + (0 until 4).foreach { i => + writer.write(ParquetEnum.newBuilder().setSuit(Suit.values.apply(i)).build()) + } + } + + checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 68f35b1f3aa83..0ea64aa2a509b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -16,45 +16,28 @@ */ package org.apache.spark.sql.execution.datasources.parquet -import java.io.File import scala.collection.JavaConversions._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.QueryTest -import org.apache.spark.util.Utils abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { - protected var parquetStore: File = _ - - /** - * Optional path to a staging subdirectory which may be created during query processing - * (Hive does this). - * Parquet files under this directory will be ignored in [[readParquetSchema()]] - * @return an optional staging directory to ignore when scanning for parquet files. 
- */ - protected def stagingDir: Option[String] = None - - override protected def beforeAll(): Unit = { - parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") - parquetStore.delete() - } - - override protected def afterAll(): Unit = { - Utils.deleteRecursively(parquetStore) + def readParquetSchema(path: String): MessageType = { + readParquetSchema(path, { path => !path.getName.startsWith("_") }) } - def readParquetSchema(path: String): MessageType = { + def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { val fsPath = new Path(path) val fs = fsPath.getFileSystem(configuration) - val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot { status => - status.getPath.getName.startsWith("_") || - stagingDir.map(status.getPath.getName.startsWith).getOrElse(false) - } + val parquetFiles = fs.listStatus(fsPath, new PathFilter { + override def accept(path: Path): Boolean = pathFilter(path) + }).toSeq + val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) footers.head.getParquetMetadata.getFileMetaData.getSchema } diff --git a/sql/core/src/test/scripts/gen-code.sh b/sql/core/src/test/scripts/gen-avro.sh similarity index 76% rename from sql/core/src/test/scripts/gen-code.sh rename to sql/core/src/test/scripts/gen-avro.sh index 5d8d8ad08555c..48174b287fd7c 100755 --- a/sql/core/src/test/scripts/gen-code.sh +++ b/sql/core/src/test/scripts/gen-avro.sh @@ -22,10 +22,9 @@ cd - rm -rf $BASEDIR/gen-java mkdir -p $BASEDIR/gen-java -thrift\ - --gen java\ - -out $BASEDIR/gen-java\ - $BASEDIR/thrift/parquet-compat.thrift - -avro-tools idl $BASEDIR/avro/parquet-compat.avdl > $BASEDIR/avro/parquet-compat.avpr -avro-tools compile -string protocol $BASEDIR/avro/parquet-compat.avpr $BASEDIR/gen-java +for input in `ls $BASEDIR/avro/*.avdl`; do + filename=$(basename "$input") + filename="${filename%.*}" + avro-tools idl $input> $BASEDIR/avro/${filename}.avpr + avro-tools compile -string protocol $BASEDIR/avro/${filename}.avpr $BASEDIR/gen-java +done diff --git a/sql/core/src/test/scripts/gen-thrift.sh b/sql/core/src/test/scripts/gen-thrift.sh new file mode 100755 index 0000000000000..ada432c68ab95 --- /dev/null +++ b/sql/core/src/test/scripts/gen-thrift.sh @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. 
+BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +for input in `ls $BASEDIR/thrift/*.thrift`; do + thrift --gen java -out $BASEDIR/gen-java $input +done diff --git a/sql/core/src/test/thrift/parquet-compat.thrift b/sql/core/src/test/thrift/parquet-compat.thrift index fa5ed8c62306a..98bf778aec5d6 100644 --- a/sql/core/src/test/thrift/parquet-compat.thrift +++ b/sql/core/src/test/thrift/parquet-compat.thrift @@ -15,7 +15,7 @@ * limitations under the License. */ -namespace java org.apache.spark.sql.parquet.test.thrift +namespace java org.apache.spark.sql.execution.datasources.parquet.test.thrift enum Suit { SPADES, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 80eb9f122ad90..251e0324bfa5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -32,53 +32,54 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { * Set the staging directory (and hence path to ignore Parquet files under) * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. */ - override val stagingDir: Option[String] = - Some(new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR)) + private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) - override protected def beforeAll(): Unit = { - super.beforeAll() + test("Read Parquet file generated by parquet-hive") { + withTable("parquet_compat") { + withTempPath { dir => + val path = dir.getCanonicalPath - withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { - withTempTable("data") { - sqlContext.sql( - s"""CREATE TABLE parquet_compat( - | bool_column BOOLEAN, - | byte_column TINYINT, - | short_column SMALLINT, - | int_column INT, - | long_column BIGINT, - | float_column FLOAT, - | double_column DOUBLE, - | - | strings_column ARRAY, - | int_to_string_column MAP - |) - |STORED AS PARQUET - |LOCATION '${parquetStore.getCanonicalPath}' - """.stripMargin) + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + withTempTable("data") { + sqlContext.sql( + s"""CREATE TABLE parquet_compat( + | bool_column BOOLEAN, + | byte_column TINYINT, + | short_column SMALLINT, + | int_column INT, + | long_column BIGINT, + | float_column FLOAT, + | double_column DOUBLE, + | + | strings_column ARRAY, + | int_to_string_column MAP + |) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) - val schema = sqlContext.table("parquet_compat").schema - val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") - sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") - } - } - } + val schema = sqlContext.table("parquet_compat").schema + val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") + sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + } + } - override protected def afterAll(): Unit = { - sqlContext.sql("DROP TABLE parquet_compat") - } + val schema = readParquetSchema(path, { path => + !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) + }) - test("Read Parquet file generated by parquet-hive") { - logInfo( - s"""Schema of the Parquet file written by parquet-hive: - 
|${readParquetSchema(parquetStore.getCanonicalPath)} - """.stripMargin) + logInfo( + s"""Schema of the Parquet file written by parquet-hive: + |$schema + """.stripMargin) - // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. - // Have to assume all BINARY values are strings here. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), makeRows) + // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. + // Have to assume all BINARY values are strings here. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer(sqlContext.read.parquet(path), makeRows) + } + } } } From 2e680668f7b6fc158aa068aedd19c1878ecf759e Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 12 Aug 2015 10:06:27 -0500 Subject: [PATCH 291/340] [SPARK-8625] [CORE] Propagate user exceptions in tasks back to driver This allows clients to retrieve the original exception from the cause field of the SparkException that is thrown by the driver. If the original exception is not in fact Serializable then it will not be returned, but the message and stacktrace will be. (All Java Throwables implement the Serializable interface, but this is no guarantee that a particular implementation can actually be serialized.) Author: Tom White Closes #7014 from tomwhite/propagate-user-exceptions. --- .../org/apache/spark/TaskEndReason.scala | 44 ++++++++++++- .../org/apache/spark/executor/Executor.scala | 14 +++- .../apache/spark/scheduler/DAGScheduler.scala | 44 ++++++++----- .../spark/scheduler/DAGSchedulerEvent.scala | 3 +- .../spark/scheduler/TaskSetManager.scala | 12 ++-- .../org/apache/spark/util/JsonProtocol.scala | 2 +- .../ExecutorAllocationManagerSuite.scala | 2 +- .../scala/org/apache/spark/FailureSuite.scala | 66 ++++++++++++++++++- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../spark/scheduler/TaskSetManagerSuite.scala | 5 +- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 3 +- 12 files changed, 165 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 48fd3e7e23d52..934d00dc708b9 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,6 +17,8 @@ package org.apache.spark +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -90,6 +92,10 @@ case class FetchFailed( * * `fullStackTrace` is a better representation of the stack trace because it contains the whole * stack trace including the exception and its causes + * + * `exception` is the actual exception that caused the task to fail. It may be `None` in + * the case that the exception is not in fact serializable. If a task fails more than + * once (due to retries), `exception` is that one that caused the last failure. 
*/ @DeveloperApi case class ExceptionFailure( @@ -97,11 +103,26 @@ case class ExceptionFailure( description: String, stackTrace: Array[StackTraceElement], fullStackTrace: String, - metrics: Option[TaskMetrics]) + metrics: Option[TaskMetrics], + private val exceptionWrapper: Option[ThrowableSerializationWrapper]) extends TaskFailedReason { + /** + * `preserveCause` is used to keep the exception itself so it is available to the + * driver. This may be set to `false` in the event that the exception is not in fact + * serializable. + */ + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics, + if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None) + } + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { - this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + this(e, metrics, preserveCause = true) + } + + def exception: Option[Throwable] = exceptionWrapper.flatMap { + (w: ThrowableSerializationWrapper) => Option(w.exception) } override def toErrorString: String = @@ -127,6 +148,25 @@ case class ExceptionFailure( } } +/** + * A class for recovering from exceptions when deserializing a Throwable that was + * thrown in user task code. If the Throwable cannot be deserialized it will be null, + * but the stacktrace and message will be preserved correctly in SparkException. + */ +private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends + Serializable with Logging { + private def writeObject(out: ObjectOutputStream): Unit = { + out.writeObject(exception) + } + private def readObject(in: ObjectInputStream): Unit = { + try { + exception = in.readObject().asInstanceOf[Throwable] + } catch { + case e : Exception => log.warn("Task exception could not be deserialized", e) + } + } +} + /** * :: DeveloperApi :: * The task finished successfully, but the result was lost from the executor's block manager before diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5d78a9dc8885e..42a85e42ea2b6 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import java.io.File +import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer @@ -305,8 +305,16 @@ private[spark] class Executor( m } } - val taskEndReason = new ExceptionFailure(t, metrics) - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason)) + val serializedTaskEndReason = { + try { + ser.serialize(new ExceptionFailure(t, metrics)) + } catch { + case _: NotSerializableException => + // t is not serializable so just send the stacktrace + ser.serialize(new ExceptionFailure(t, metrics, false)) + } + } + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index bb489c6b6e98f..7ab5ccf50adb7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -200,8 +200,8 @@ class DAGScheduler( // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit = { - eventProcessLoop.post(TaskSetFailed(taskSet, reason)) + def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { + eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } private[scheduler] @@ -677,8 +677,11 @@ class DAGScheduler( submitWaitingStages() } - private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) { - stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) } + private[scheduler] def handleTaskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } submitWaitingStages() } @@ -762,7 +765,7 @@ class DAGScheduler( } } } else { - abortStage(stage, "No active job for stage " + stage.id) + abortStage(stage, "No active job for stage " + stage.id, None) } } @@ -816,7 +819,7 @@ class DAGScheduler( case NonFatal(e) => stage.makeNewStageAttempt(partitionsToCompute.size) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -845,13 +848,13 @@ class DAGScheduler( } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => - abortStage(stage, "Task not serializable: " + e.toString) + abortStage(stage, "Task not serializable: " + e.toString, Some(e)) runningStages -= stage // Abort execution return case NonFatal(e) => - abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -878,7 +881,7 @@ class DAGScheduler( } } catch { case NonFatal(e) => - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -1098,7 +1101,8 @@ class DAGScheduler( } if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + abortStage(failedStage, "Fetch failure will not retry stage due to testing config", + None) } else if (failedStages.isEmpty) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. 
@@ -1126,7 +1130,7 @@ class DAGScheduler( case commitDenied: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => + case exceptionFailure: ExceptionFailure => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => @@ -1235,7 +1239,10 @@ class DAGScheduler( * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ - private[scheduler] def abortStage(failedStage: Stage, reason: String) { + private[scheduler] def abortStage( + failedStage: Stage, + reason: String, + exception: Option[Throwable]): Unit = { if (!stageIdToStage.contains(failedStage.id)) { // Skip all the actions if the stage has been removed. return @@ -1244,7 +1251,7 @@ class DAGScheduler( activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) for (job <- dependentJobs) { - failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) } if (dependentJobs.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -1252,8 +1259,11 @@ class DAGScheduler( } /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ - private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { - val error = new SparkException(failureReason) + private def failJobAndIndependentStages( + job: ActiveJob, + failureReason: String, + exception: Option[Throwable] = None): Unit = { + val error = new SparkException(failureReason, exception.getOrElse(null)) var ableToCancelStages = true val shouldInterruptThread = @@ -1462,8 +1472,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => dagScheduler.handleTaskCompletion(completion) - case TaskSetFailed(taskSet, reason) => - dagScheduler.handleTaskSetFailed(taskSet, reason) + case TaskSetFailed(taskSet, reason, exception) => + dagScheduler.handleTaskSetFailed(taskSet, reason, exception) case ResubmitFailedStages => dagScheduler.resubmitFailedStages() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index a213d419cf033..f72a52e85dc15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -73,6 +73,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[scheduler] -case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) + extends DAGSchedulerEvent private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 82455b0426a5d..818b95d67f6be 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -662,7 +662,7 @@ private[spark] class TaskSetManager( val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString - reason match { + val failureException: Option[Throwable] = reason match { case fetchFailed: FetchFailed => logWarning(failureReason) if (!successful(index)) { @@ -671,6 +671,7 @@ private[spark] class TaskSetManager( } // Not adding to failed executors for FetchFailed. isZombie = true + None case ef: ExceptionFailure => taskMetrics = ef.metrics.orNull @@ -706,12 +707,15 @@ private[spark] class TaskSetManager( s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + s"${ef.className} (${ef.description}) [duplicate $dupCount]") } + ef.exception case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) + None case e: TaskEndReason => logError("Unknown TaskEndReason: " + e) + None } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). @@ -728,16 +732,16 @@ private[spark] class TaskSetManager( logError("Task %d in stage %s failed %d times; aborting job".format( index, taskSet.id, maxTaskFailures)) abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" - .format(index, taskSet.id, maxTaskFailures, failureReason)) + .format(index, taskSet.id, maxTaskFailures, failureReason), failureException) return } } maybeFinishTaskSet() } - def abort(message: String): Unit = sched.synchronized { + def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized { // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.dagScheduler.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message, exception) isZombie = true maybeFinishTaskSet() } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c600319d9ddb4..cbc94fd6d54d9 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -790,7 +790,7 @@ private[spark] object JsonProtocol { val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). map(_.extract[String]).orNull val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) + ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index f374f97f87448..116f027a0f987 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -800,7 +800,7 @@ class ExecutorAllocationManagerSuite assert(maxNumExecutorsNeeded(manager) === 1) // If the task is failed, we expect it to be resubmitted later. 
- val taskEndReason = ExceptionFailure(null, null, null, null, null) + val taskEndReason = ExceptionFailure(null, null, null, null, null, None) sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) assert(maxNumExecutorsNeeded(manager) === 1) } diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 69cb4b44cf7ef..aa50a49c50232 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import org.apache.spark.util.NonSerializable -import java.io.NotSerializableException +import java.io.{IOException, NotSerializableException, ObjectInputStream} // Common state shared by FailureSuite-launched tasks. We use a global object // for this because any local variables used in the task closures will rightfully @@ -166,5 +166,69 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { assert(thrownDueToMemoryLeak.getMessage.contains("memory leak")) } + // Run a 3-task map job in which task 1 always fails with a exception message that + // depends on the failure number, and check that we get the last failure. + test("last failure cause is sent back to driver") { + sc = new SparkContext("local[1,2]", "test") + val data = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 3) { + FailureSuiteState.tasksFailed += 1 + throw new UserException("oops", + new IllegalArgumentException("failed=" + FailureSuiteState.tasksFailed)) + } + } + x * x + } + val thrown = intercept[SparkException] { + data.collect() + } + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause.getClass === classOf[UserException]) + assert(thrown.getCause.getMessage === "oops") + assert(thrown.getCause.getCause.getClass === classOf[IllegalArgumentException]) + assert(thrown.getCause.getCause.getMessage === "failed=2") + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not serializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonSerializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonSerializableUserException")) + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not deserializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonDeserializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonDeserializableUserException")) + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. 
} + +class UserException(message: String, cause: Throwable) + extends RuntimeException(message, cause) + +class NonSerializableUserException extends RuntimeException { + val nonSerializableInstanceVariable = new NonSerializable +} + +class NonDeserializableUserException extends RuntimeException { + private def readObject(in: ObjectInputStream): Unit = { + throw new IOException("Intentional exception during deserialization.") + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 86dff8fb577d5..b0ca49cbea4f7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -242,7 +242,7 @@ class DAGSchedulerSuite /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { - runEvent(TaskSetFailed(taskSet, message)) + runEvent(TaskSetFailed(taskSet, message, None)) } /** Sends JobCancelled to the DAG scheduler. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f7cc4bb61d574..edbdb485c5ea4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -48,7 +48,10 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) override def executorLost(execId: String) {} - override def taskSetFailed(taskSet: TaskSet, reason: String) { + override def taskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { taskScheduler.taskSetsFailed += taskSet.id } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 56f7b9cf1f358..b140387d309f3 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None), + ExceptionFailure("Exception", "description", null, null, None, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0"), diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index dde95f3778434..343a4139b0ca8 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -163,7 +163,8 @@ class JsonProtocolSuite extends SparkFunSuite { } test("ExceptionFailure backward compatibility") { - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, + None, None) val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) .removeField({ _._1 == "Full Stack Trace" }) assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) From be5d1912076c2ffd21ec88611e53d3b3c59b7ecc Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 09:24:50 -0700 Subject: [PATCH 292/340] [SPARK-9795] Dynamic 
allocation: avoid double counting when killing same executor twice This is based on KaiXinXiaoLei's changes in #7716. The issue is that when someone calls `sc.killExecutor("1")` on the same executor twice quickly, then the executor target will be adjusted downwards by 2 instead of 1 even though we're only actually killing one executor. In certain cases where we don't adjust the target back upwards quickly, we'll end up with jobs hanging. This is a common danger because there are many places where this is called: - `HeartbeatReceiver` kills an executor that has not been sending heartbeats - `ExecutorAllocationManager` kills an executor that has been idle - The user code might call this, which may interfere with the previous callers While it's not clear whether this fixes SPARK-9745, fixing this potential race condition seems like a strict improvement. I've added a regression test to illustrate the issue. Author: Andrew Or Closes #8078 from andrewor14/da-double-kill. --- .../CoarseGrainedSchedulerBackend.scala | 11 ++++++---- .../StandaloneDynamicAllocationSuite.scala | 20 +++++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6acf8a9a5e9b4..5730a87f960a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -422,16 +422,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logWarning(s"Executor to kill $id does not exist!") } + // If an executor is already pending to be removed, do not kill it again (SPARK-9795) + val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) } + executorsPendingToRemove ++= executorsToKill + // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. 
if (!replace) { - doRequestTotalExecutors(numExistingExecutors + numPendingExecutors - - executorsPendingToRemove.size - knownExecutors.size) + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) } - executorsPendingToRemove ++= knownExecutors - doKillExecutors(knownExecutors) + doKillExecutors(executorsToKill) } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 08c41a897a861..1f2a0f0d309ce 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -283,6 +283,26 @@ class StandaloneDynamicAllocationSuite assert(master.apps.head.getExecutorLimit === 1000) } + test("kill the same executor twice (SPARK-9795)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + assert(master.apps.size === 1) + assert(master.apps.head.id === appId) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === Int.MaxValue) + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about + syncExecutors(sc) + // kill the same executor twice + val executors = getExecutorIds(sc) + assert(executors.size === 2) + assert(sc.killExecutor(executors.head)) + assert(sc.killExecutor(executors.head)) + assert(master.apps.head.executors.size === 1) + // The limit should not be lowered twice + assert(master.apps.head.getExecutorLimit === 1) + } + // =============================== // | Utility methods for testing | // =============================== From 66d87c1d76bea2b81993156ac1fa7dad6c312ebf Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 12 Aug 2015 09:35:32 -0700 Subject: [PATCH 293/340] [SPARK-7583] [MLLIB] User guide update for RegexTokenizer jira: https://issues.apache.org/jira/browse/SPARK-7583 User guide update for RegexTokenizer Author: Yuhao Yang Closes #7828 from hhbyyh/regexTokenizerDoc. --- docs/ml-features.md | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index fa0ad1f00ab12..cec2cbe673407 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -217,21 +217,32 @@ for feature in result.select("result").take(3): [Tokenization](http://en.wikipedia.org/wiki/Lexical_analysis#Tokenization) is the process of taking text (such as a sentence) and breaking it into individual terms (usually words). A simple [Tokenizer](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) class provides this functionality. The example below shows how to split sentences into sequences of words. -Note: A more advanced tokenizer is provided via [RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer). +[RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer) allows more + advanced tokenization based on regular expression (regex) matching. + By default, the parameter "pattern" (regex, default: \\s+) is used as delimiters to split the input text. + Alternatively, users can set parameter "gaps" to false indicating the regex "pattern" denotes + "tokens" rather than splitting gaps, and find all matching occurrences as the tokenization result.
    {% highlight scala %} -import org.apache.spark.ml.feature.Tokenizer +import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} val sentenceDataFrame = sqlContext.createDataFrame(Seq( (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") )).toDF("label", "sentence") val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val wordsDataFrame = tokenizer.transform(sentenceDataFrame) -wordsDataFrame.select("words", "label").take(3).foreach(println) +val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + +val tokenized = tokenizer.transform(sentenceDataFrame) +tokenized.select("words", "label").take(3).foreach(println) +val regexTokenized = regexTokenizer.transform(sentenceDataFrame) +regexTokenized.select("words", "label").take(3).foreach(println) {% endhighlight %}
    @@ -240,6 +251,7 @@ wordsDataFrame.select("words", "label").take(3).foreach(println) import com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.DataFrame; @@ -252,8 +264,8 @@ import org.apache.spark.sql.types.StructType; JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), @@ -267,22 +279,29 @@ for (Row r : wordsDataFrame.select("words", "label").take(3)) { for (String word : words) System.out.print(word + " "); System.out.println(); } + +RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); {% endhighlight %}
    {% highlight python %} -from pyspark.ml.feature import Tokenizer +from pyspark.ml.feature import Tokenizer, RegexTokenizer sentenceDataFrame = sqlContext.createDataFrame([ (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") ], ["label", "sentence"]) tokenizer = Tokenizer(inputCol="sentence", outputCol="words") wordsDataFrame = tokenizer.transform(sentenceDataFrame) for words_label in wordsDataFrame.select("words", "label").take(3): print(words_label) +regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") +# alternatively, pattern="\\w+", gaps(False) {% endhighlight %}
    From e0110792ef71ebfd3727b970346a2e13695990a4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 10:08:35 -0700 Subject: [PATCH 294/340] [SPARK-9747] [SQL] Avoid starving an unsafe operator in aggregation This is the sister patch to #8011, but for aggregation. In a nutshell: create the `TungstenAggregationIterator` before computing the parent partition. Internally this creates a `BytesToBytesMap` which acquires a page in the constructor as of this patch. This ensures that the aggregation operator is not starved since we reserve at least 1 page in advance. rxin yhuai Author: Andrew Or Closes #8038 from andrewor14/unsafe-starve-memory-agg. --- .../spark/unsafe/map/BytesToBytesMap.java | 34 +++++-- .../unsafe/sort/UnsafeExternalSorter.java | 9 +- .../map/AbstractBytesToBytesMapSuite.java | 11 ++- .../UnsafeFixedWidthAggregationMap.java | 7 ++ .../aggregate/TungstenAggregate.scala | 72 +++++++++------ .../TungstenAggregationIterator.scala | 88 +++++++++++-------- .../TungstenAggregationIteratorSuite.scala | 56 ++++++++++++ 7 files changed, 201 insertions(+), 76 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 85b46ec8bfae3..87ed47e88c4ef 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -193,6 +193,11 @@ public BytesToBytesMap( TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } allocate(initialCapacity); + + // Acquire a new page as soon as we construct the map to ensure that we have at least + // one page to work with. Otherwise, other operators in the same task may starve this + // map (SPARK-9747). + acquireNewPage(); } public BytesToBytesMap( @@ -574,16 +579,9 @@ public boolean putNewKey( final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } - final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryGranted != pageSizeBytes) { - shuffleMemoryManager.release(memoryGranted); - logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + if (!acquireNewPage()) { return false; } - MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); - dataPages.add(newPage); - pageCursor = 0; - currentDataPage = newPage; dataPage = currentDataPage; dataPageBaseObject = currentDataPage.getBaseObject(); dataPageInsertOffset = currentDataPage.getBaseOffset(); @@ -642,6 +640,24 @@ public boolean putNewKey( } } + /** + * Acquire a new page from the {@link ShuffleMemoryManager}. + * @return whether there is enough space to allocate the new page. + */ + private boolean acquireNewPage() { + final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryGranted != pageSizeBytes) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + return false; + } + MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + return true; + } + /** * Allocate new data structures for this map. When calling this outside of the constructor, * make sure to keep references to the old data structures so that you can free them. 
@@ -748,7 +764,7 @@ public long getNumHashCollisions() { } @VisibleForTesting - int getNumDataPages() { + public int getNumDataPages() { return dataPages.size(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 9601aafe55464..fc364e0a895b1 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -132,16 +132,15 @@ private UnsafeExternalSorter( if (existingInMemorySorter == null) { initializeForWriting(); + // Acquire a new page as soon as we construct the sorter to ensure that we have at + // least one page to work with. Otherwise, other operators in the same task may starve + // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter. + acquireNewPage(); } else { this.isInMemSorterExternal = true; this.inMemSorter = existingInMemorySorter; } - // Acquire a new page as soon as we construct the sorter to ensure that we have at - // least one page to work with. Otherwise, other operators in the same task may starve - // this sorter (SPARK-9709). - acquireNewPage(); - // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 1a79c20c35246..ab480b60adaed 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -543,7 +543,7 @@ public void testPeakMemoryUsed() { Platform.LONG_ARRAY_OFFSET, 8); newPeakMemory = map.getPeakMemoryUsedBytes(); - if (i % numRecordsPerPage == 0) { + if (i % numRecordsPerPage == 0 && i > 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); } else { @@ -561,4 +561,13 @@ public void testPeakMemoryUsed() { map.free(); } } + + @Test + public void testAcquirePageInConstructor() { + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + assertEquals(1, map.getNumDataPages()); + map.free(); + } + } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 5cce41d5a7569..09511ff35f785 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -19,6 +19,8 @@ import java.io.IOException; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.SparkEnv; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; @@ -220,6 +222,11 @@ public long getPeakMemoryUsedBytes() { return map.getPeakMemoryUsedBytes(); } + @VisibleForTesting + public int getNumDataPages() { + return map.getNumDataPages(); + } + /** * Free the memory associated with this map. 
This is idempotent and can be called multiple times. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 6b5935a7ce296..c40ca973796a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.rdd.RDD +import org.apache.spark.TaskContext +import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -68,35 +69,56 @@ case class TungstenAggregate( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => - val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. - Iterator.empty.asInstanceOf[Iterator[UnsafeRow]] - } else { - val aggregationIterator = - new TungstenAggregationIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - completeAggregateExpressions, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - child.output, - iter, - testFallbackStartsAt, - numInputRows, - numOutputRows) - - if (!hasInput && groupingExpressions.isEmpty) { + + /** + * Set up the underlying unsafe data structures used before computing the parent partition. + * This makes sure our iterator is not starved by other operators in the same task. + */ + def preparePartition(): TungstenAggregationIterator = { + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + completeAggregateExpressions, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + testFallbackStartsAt, + numInputRows, + numOutputRows) + } + + /** Compute a partition using the iterator already set up previously. */ + def executePartition( + context: TaskContext, + partitionIndex: Int, + aggregationIterator: TungstenAggregationIterator, + parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = { + val hasInput = parentIterator.hasNext + if (!hasInput) { + // We're not using the underlying map, so we just can free it here + aggregationIterator.free() + if (groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) } else { - aggregationIterator + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. 
+ Iterator[UnsafeRow]() } + } else { + aggregationIterator.start(parentIterator) + aggregationIterator } } + + // Note: we need to set up the iterator in each partition before computing the + // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747). + val resultRdd = { + new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator]( + child.execute(), preparePartition, executePartition, preservesPartitioning = true) + } + resultRdd.asInstanceOf[RDD[InternalRow]] } override def simpleString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 1f383dd04482f..af7e0fcedbe4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -72,8 +72,6 @@ import org.apache.spark.sql.types.StructType * the function used to create mutable projections. * @param originalInputAttributes * attributes of representing input rows from `inputIter`. - * @param inputIter - * the iterator containing input [[UnsafeRow]]s. */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], @@ -83,12 +81,14 @@ class TungstenAggregationIterator( resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[Int], numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { + // The parent partition iterator, to be initialized later in `start` + private[this] var inputIter: Iterator[InternalRow] = null + /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// @@ -348,11 +348,15 @@ class TungstenAggregationIterator( false // disable tracking of performance metrics ) + // Exposed for testing + private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap + // The function used to read and process input rows. When processing input rows, // it first uses hash-based aggregation by putting groups and their buffers in // hashMap. If we could not allocate more memory for the map, we switch to // sort-based aggregation (by calling switchToSortBasedAggregation). private def processInputs(): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() numInputRows += 1 @@ -372,6 +376,7 @@ class TungstenAggregationIterator( // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have // been processed. private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") var i = 0 while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() @@ -412,6 +417,7 @@ class TungstenAggregationIterator( * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. 
*/ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. externalSorter = hashMap.destructAndCreateExternalSorter() @@ -431,6 +437,11 @@ class TungstenAggregationIterator( case _ => false } + // Note: Since we spill the sorter's contents immediately after creating it, we must insert + // something into the sorter here to ensure that we acquire at least a page of memory. + // This is done through `externalSorter.insertKV`, which will trigger the page allocation. + // Otherwise, children operators may steal the window of opportunity and starve our sorter. + if (needsProcess) { // First, we create a buffer. val buffer = createNewAggregationBuffer() @@ -588,27 +599,33 @@ class TungstenAggregationIterator( // have not switched to sort-based aggregation. /////////////////////////////////////////////////////////////////////////// - // Starts to process input rows. - testFallbackStartsAt match { - case None => - processInputs() - case Some(fallbackStartsAt) => - // This is the testing path. processInputsWithControlledFallback is same as processInputs - // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows - // have been processed. - processInputsWithControlledFallback(fallbackStartsAt) - } + /** + * Start processing input rows. + * Only after this method is called will this iterator be non-empty. + */ + def start(parentIter: Iterator[InternalRow]): Unit = { + inputIter = parentIter + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } - // If we did not switch to sort-based aggregation in processInputs, - // we pre-load the first key-value pair from the map (to make hasNext idempotent). - if (!sortBased) { - // First, set aggregationBufferMapIterator. - aggregationBufferMapIterator = hashMap.iterator() - // Pre-load the first key-value pair from the aggregationBufferMapIterator. - mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!mapIteratorHasNext) { - hashMap.free() + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } } } @@ -673,21 +690,20 @@ class TungstenAggregationIterator( } /////////////////////////////////////////////////////////////////////////// - // Part 8: A utility function used to generate a output row when there is no - // input and there is no grouping expression. + // Part 8: Utility functions /////////////////////////////////////////////////////////////////////////// + /** + * Generate a output row when there is no input and there is no grouping expression. 
+ */ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { - if (groupingExpressions.isEmpty) { - sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) - // We create a output row and copy it. So, we can free the map. - val resultCopy = - generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() - hashMap.free() - resultCopy - } else { - throw new IllegalStateException( - "This method should not be called when groupingExpressions is not empty.") - } + assert(groupingExpressions.isEmpty) + assert(inputIter == null) + generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer) + } + + /** Free memory used in the underlying map. */ + def free(): Unit = { + hashMap.free() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala new file mode 100644 index 0000000000000..ac22c2f3c0a58 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.unsafe.memory.TaskMemoryManager + +class TungstenAggregationIteratorSuite extends SparkFunSuite { + + test("memory acquired on construction") { + // set up environment + val ctx = TestSQLContext + + val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) + val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) + TaskContext.setTaskContext(taskContext) + + // Assert that a page is allocated before processing starts + var iter: TungstenAggregationIterator = null + try { + val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { + () => new InterpretedMutableProjection(expr, schema) + } + val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy") + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, + Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + val numPages = iter.getHashMap.getNumDataPages + assert(numPages === 1) + } finally { + // Clean up + if (iter != null) { + iter.free() + } + TaskContext.unset() + } + } +} From 57ec27dd7784ce15a2ece8a6c8ac7bd5fd25aea2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 12 Aug 2015 10:38:30 -0700 Subject: [PATCH 295/340] [SPARK-9804] [HIVE] Use correct value for isSrcLocal parameter. If the correct parameter is not provided, Hive will run into an error because it calls methods that are specific to the local filesystem to copy the data. Author: Marcelo Vanzin Closes #8086 from vanzin/SPARK-9804. 
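The value of isSrcLocal can be derived from the load path itself instead of being hard-coded: the source is local only if the path resolves to the local filesystem. A minimal standalone sketch of that check, using only the Hadoop FileSystem API, is shown here; the isSrcLocal helper added to HiveShim below does the same thing with a HiveConf (which is a Hadoop Configuration).

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

// A path is "local" only if its FileSystem URI matches the local
// filesystem's URI for this configuration.
def isSrcLocal(path: Path, conf: Configuration): Boolean = {
  val localFs = FileSystem.getLocal(conf)           // typically file:///
  val pathFs = FileSystem.get(path.toUri(), conf)   // e.g. hdfs://namenode:8020
  localFs.getUri() == pathFs.getUri()
}

// isSrcLocal(new Path("hdfs://nn:8020/warehouse/t"), new Configuration())  => false
// isSrcLocal(new Path("file:///tmp/staging"), new Configuration())         => true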
--- .../org/apache/spark/sql/hive/client/HiveShim.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 6e826ce552204..8fc8935b1dc3c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -25,7 +25,7 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} @@ -429,7 +429,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { isSkewedStoreAsSubdir: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - JBoolean.TRUE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE) } override def loadTable( @@ -439,7 +439,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { replace: Boolean, holdDDLTime: Boolean): Unit = { loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, - JBoolean.TRUE, JBoolean.FALSE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE, JBoolean.FALSE) } override def loadDynamicPartitions( @@ -461,6 +461,13 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, TimeUnit.MILLISECONDS).asInstanceOf[Long] } + + protected def isSrcLocal(path: Path, conf: HiveConf): Boolean = { + val localFs = FileSystem.getLocal(conf) + val pathFs = FileSystem.get(path.toUri(), conf) + localFs.getUri() == pathFs.getUri() + } + } private[client] class Shim_v1_0 extends Shim_v0_14 { From 70fe558867ccb4bcff6ec673438b03608bb02252 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 12 Aug 2015 10:48:52 -0700 Subject: [PATCH 296/340] [SPARK-9847] [ML] Modified copyValues to distinguish between default, explicit param values From JIRA: Currently, Params.copyValues copies default parameter values to the paramMap of the target instance, rather than the defaultParamMap. It should copy to the defaultParamMap because explicitly setting a parameter can change the semantics. This issue arose in SPARK-9789, where 2 params "threshold" and "thresholds" for LogisticRegression can have mutually exclusive values. If thresholds is set, then fit() will copy the default value of threshold as well, easily resulting in inconsistent settings for the 2 params. CC: mengxr Author: Joseph K. Bradley Closes #8115 from jkbradley/copyvalues-fix. 
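The behavioral difference is easiest to see with a small sketch that is deliberately not Spark code: one map for explicitly set values, a separate map for defaults, and isSet consulting only the former. Copying defaults into the explicit map (the old behavior) makes a default look user-set, which is exactly what breaks mutually exclusive params such as threshold/thresholds.

// Plain-Scala sketch, not Spark code.
case class P(explicit: Map[String, Any], default: Map[String, Any]) {
  def isSet(name: String): Boolean = explicit.contains(name)
  def get(name: String): Any = explicit.getOrElse(name, default(name))
}

val estimator = P(
  explicit = Map("thresholds" -> Seq(0.3, 0.7)),   // user-set
  default = Map("threshold" -> 0.5))               // library default

// Old copyValues: defaults leak into the copy's explicit map.
val before = P(estimator.default ++ estimator.explicit, Map.empty)
// New copyValues: defaults stay defaults, explicit values stay explicit.
val after = P(estimator.explicit, estimator.default)

assert(before.isSet("threshold"))    // the default now looks user-set: conflict
assert(!after.isSet("threshold"))    // still just a default
assert(after.isSet("thresholds"))

The Params.copyValues test added to ParamsSuite below checks the same property using TestParams and maxIter.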
--- .../org/apache/spark/ml/param/params.scala | 19 ++++++++++++++++--- .../apache/spark/ml/param/ParamsSuite.scala | 8 ++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d68f5ff0053c9..91c0a5631319d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -559,13 +559,26 @@ trait Params extends Identifiable with Serializable { /** * Copies param values from this instance to another instance for params shared by them. - * @param to the target instance - * @param extra extra params to be copied + * + * This handles default Params and explicitly set Params separately. + * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are + * copied from and to [[paramMap]]. + * Warning: This implicitly assumes that this [[Params]] instance and the target instance + * share the same set of default Params. + * + * @param to the target instance, which should work with the same set of default Params as this + * source instance + * @param extra extra params to be copied to the target's [[paramMap]] * @return the target instance with param values copied */ protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = { - val map = extractParamMap(extra) + val map = paramMap ++ extra params.foreach { param => + // copy default Params + if (defaultParamMap.contains(param) && to.hasParam(param.name)) { + to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param)) + } + // copy explicitly set Params if (map.contains(param) && to.hasParam(param.name)) { to.set(param.name, map(param)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 050d4170ea017..be95638d81686 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -200,6 +200,14 @@ class ParamsSuite extends SparkFunSuite { val inArray = ParamValidators.inArray[Int](Array(1, 2)) assert(inArray(1) && inArray(2) && !inArray(0)) } + + test("Params.copyValues") { + val t = new TestParams() + val t2 = t.copy(ParamMap.empty) + assert(!t2.isSet(t2.maxIter)) + val t3 = t.copy(ParamMap(t.maxIter -> 20)) + assert(t3.isSet(t3.maxIter)) + } } object ParamsSuite extends SparkFunSuite { From 60103ecd3d9c92709a5878be7ebd57012813ab48 Mon Sep 17 00:00:00 2001 From: Brennan Ashton Date: Wed, 12 Aug 2015 11:57:30 -0700 Subject: [PATCH 297/340] [SPARK-9726] [PYTHON] PySpark DF join no longer accepts on=None rxin First pull request for Spark so let me know if I am missing anything The contribution is my original work and I license the work to the project under the project's open source license. Author: Brennan Ashton Closes #8016 from btashton/patch-1. 
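For reference, the three branches of the fixed Python dispatch correspond to the JVM-side DataFrame.join overloads. In the sketch below, left and right are assumed to be existing DataFrames; it only illustrates which overload each form of the on argument reaches.

// Assuming `left` and `right` are existing DataFrames:
val noCondition = left.join(right)                  // on=None: plain join without a condition
val byColumns   = left.join(right, Seq("id"))       // on=["id"]: join on shared column names
val byExpr      = left.join(right, left("id") === right("id"))  // on=Column: explicit condition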
--- python/pyspark/sql/dataframe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 47d5a6a43a84d..09647ff6d0749 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -566,8 +566,7 @@ def join(self, other, on=None, how=None): if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) - - if isinstance(on[0], basestring): + elif isinstance(on[0], basestring): jdf = self._jdf.join(other._jdf, self._jseq(on)) else: assert isinstance(on[0], Column), "on should be Column or list of Column" From 762bacc16ac5e74c8b05a7c1e3e367d1d1633cef Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 12 Aug 2015 13:24:18 -0700 Subject: [PATCH 298/340] [SPARK-9766] [ML] [PySpark] check and add miss docs for PySpark ML Check and add miss docs for PySpark ML (this issue only check miss docs for o.a.s.ml not o.a.s.mllib). Author: Yanbo Liang Closes #8059 from yanboliang/SPARK-9766. --- python/pyspark/ml/classification.py | 12 ++++++++++-- python/pyspark/ml/clustering.py | 4 +++- python/pyspark/ml/evaluation.py | 3 ++- python/pyspark/ml/feature.py | 9 +++++---- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 5978d8f4d3a01..6702dce5545a9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -34,6 +34,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): """ Logistic regression. + Currently, this class only supports binary classification. >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors @@ -96,8 +97,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred # is an L2 penalty. For alpha = 1, it is an L1 penalty. self.elasticNetParam = \ Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + - "is an L2 penalty. For alpha = 1, it is an L1 penalty.") + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") #: param for whether to fit an intercept term. self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") #: param for threshold in binary classification prediction, in range [0, 1]. @@ -656,6 +657,13 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H HasRawPredictionCol): """ Naive Bayes Classifiers. + It supports both Multinomial and Bernoulli NB. Multinomial NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html`) + can handle finitely supported discrete data. For example, by converting documents into + TF-IDF vectors, it can be used for document classification. By making every vector a + binary (0/1) data, it can also be used as Bernoulli NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html`). + The input feature values must be nonnegative. 
>>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b5e9b6549d9f1..48338713a29ea 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -37,7 +37,9 @@ def clusterCenters(self): @inherit_doc class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): """ - K-means Clustering + K-means clustering with support for multiple parallel runs and a k-means++ like initialization + mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, + they are executed together with joint passes over the data for efficiency. >>> from pyspark.mllib.linalg import Vectors >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 06e809352225b..2734092575ad9 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -23,7 +23,8 @@ from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc -__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator'] +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', + 'MulticlassClassificationEvaluator'] @inherit_doc diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index cb4dfa21298ce..535d55326646c 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,10 +26,11 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', - 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', - 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', - 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] +__all__ = ['Binarizer', 'Bucketizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', + 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', + 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', + 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel', 'PCA', + 'PCAModel', 'RFormula', 'RFormulaModel'] @inherit_doc From 551def5d6972440365bd7436d484a67138d9a8f3 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 12 Aug 2015 14:27:13 -0700 Subject: [PATCH 299/340] [SPARK-9789] [ML] Added logreg threshold param back Reinstated LogisticRegression.threshold Param for binary compatibility. Param thresholds overrides threshold, if set. CC: mengxr dbtsai feynmanliang Author: Joseph K. Bradley Closes #8079 from jkbradley/logreg-reinstate-threshold. 
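In practice the two params stay consistent through the setters: setting one clears any user-set value of the other, and the getters translate between them. A short usage sketch, with values following the rules in this patch:

import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
lr.getThreshold                      // 0.5, the default

lr.setThreshold(0.8)                 // clears any user-set thresholds
lr.getThresholds                     // Array(0.2, 0.8), i.e. (1 - t, t)

lr.setThresholds(Array(0.3, 0.7))    // clears the user-set threshold
lr.getThreshold                      // 1.0 / (1.0 + 0.3 / 0.7) = 0.7

// Setting both to inconsistent values in one ParamMap fails validation:
// lr.fit(dataset, lr.thresholds -> Array(0.3, 0.7), lr.threshold -> 0.1)
//   => IllegalArgumentException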
--- .../classification/LogisticRegression.scala | 127 ++++++++++++++---- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 6 +- .../JavaLogisticRegressionSuite.java | 7 +- .../LogisticRegressionSuite.scala | 33 +++-- python/pyspark/ml/classification.py | 98 +++++++++----- 6 files changed, 199 insertions(+), 76 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f55134d258857..5bcd7117b668c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -34,8 +34,7 @@ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** @@ -43,44 +42,115 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasStandardization { + with HasStandardization with HasThreshold { /** - * Version of setThresholds() for binary classification, available for backwards - * compatibility. + * Set threshold in binary classification, in range [0, 1]. * - * Calling this with threshold p will effectively call `setThresholds(Array(1-p, p))`. + * If the estimated probability of class label 1 is > threshold, then predict 1, else 0. + * A high threshold encourages the model to predict 0 more often; + * a low threshold encourages the model to predict 1 more often. + * + * Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`. + * When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. + * + * Default is 0.5. + * @group setParam + */ + def setThreshold(value: Double): this.type = { + if (isSet(thresholds)) clear(thresholds) + set(threshold, value) + } + + /** + * Get threshold for binary classification. + * + * If [[threshold]] is set, returns that value. + * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification), + * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}. + * Otherwise, returns [[threshold]] default value. + * + * @group getParam + * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2. + */ + override def getThreshold: Double = { + checkThresholdConsistency() + if (isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2. thresholds: " + ts.mkString(",")) + 1.0 / (1.0 + ts(0) / ts(1)) + } else { + $(threshold) + } + } + + /** + * Set thresholds in multiclass (or binary) classification to adjust the probability of + * predicting each class. Array must have length equal to the number of classes, with values >= 0. 
+ * The class with largest value p/t is predicted, where p is the original probability of that + * class and t is the class' threshold. + * + * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. * - * Default is effectively 0.5. * @group setParam */ - def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - value, value)) + def setThresholds(value: Array[Double]): this.type = { + if (isSet(threshold)) clear(threshold) + set(thresholds, value) + } /** - * Version of [[getThresholds()]] for binary classification, available for backwards - * compatibility. + * Get thresholds for binary or multiclass classification. + * + * If [[thresholds]] is set, return its value. + * Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary + * classification: (1-threshold, threshold). + * If neither are set, throw an exception. * - * Param thresholds must have length 2 (or not be specified). - * This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}. * @group getParam */ - def getThreshold: Double = { - if (isDefined(thresholds)) { - val thresholdValues = $(thresholds) - assert(thresholdValues.length == 2, "Logistic Regression getThreshold only applies to" + - " binary classification, but thresholds has length != 2." + - s" thresholds: ${thresholdValues.mkString(",")}") - 1.0 / (1.0 + thresholdValues(0) / thresholdValues(1)) + override def getThresholds: Array[Double] = { + checkThresholdConsistency() + if (!isSet(thresholds) && isSet(threshold)) { + val t = $(threshold) + Array(1-t, t) } else { - 0.5 + $(thresholds) + } + } + + /** + * If [[threshold]] and [[thresholds]] are both set, ensures they are consistent. + * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent + */ + protected def checkThresholdConsistency(): Unit = { + if (isSet(threshold) && isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" + + s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" + + s" classification, but Param thresholds is set with length ${ts.length}." + + " Clear one Param value to fix this problem.") + val t = 1.0 / (1.0 + ts(0) / ts(1)) + require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" + + s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)") } } + + override def validateParams(): Unit = { + checkThresholdConsistency() + } } /** * :: Experimental :: * Logistic regression. - * Currently, this class only supports binary classification. + * Currently, this class only supports binary classification. It will support multiclass + * in the future. */ @Experimental class LogisticRegression(override val uid: String) @@ -128,7 +198,7 @@ class LogisticRegression(override val uid: String) * Whether to fit an intercept term. * Default is true. * @group setParam - * */ + */ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -140,7 +210,7 @@ class LogisticRegression(override val uid: String) * is applied. In R's GLMNET package, the default behavior is true as well. * Default is true. 
* @group setParam - * */ + */ def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) @@ -148,6 +218,10 @@ class LogisticRegression(override val uid: String) override def getThreshold: Double = super.getThreshold + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + override def getThresholds: Array[Double] = super.getThresholds + override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractLabeledPoints(dataset).map { @@ -314,6 +388,10 @@ class LogisticRegressionModel private[ml] ( override def getThreshold: Double = super.getThreshold + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + override def getThresholds: Array[Double] = super.getThresholds + /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { BLAS.dot(features, weights) + intercept @@ -364,6 +442,7 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[thresholds]]. */ override protected def predict(features: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (score(features) > getThreshold) 1 else 0 } @@ -393,6 +472,7 @@ class LogisticRegressionModel private[ml] ( } override protected def raw2prediction(rawPrediction: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. val t = getThreshold val rawThreshold = if (t == 0.0) { Double.NegativeInfinity @@ -405,6 +485,7 @@ class LogisticRegressionModel private[ml] ( } override protected def probability2prediction(probability: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (probability(1) > getThreshold) 1 else 0 } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index da4c076830391..9e12f1856a940 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -45,14 +45,14 @@ private[shared] object SharedParamsCodeGen { " These probabilities should be treated as confidences, not precise probabilities.", Some("\"probability\"")), ParamDesc[Double]("threshold", - "threshold in binary classification prediction, in range [0, 1]", + "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" + " to adjust the probability of predicting each class." + " Array must have length equal to the number of classes, with values >= 0." 
+ " The class with largest value p/t is predicted, where p is the original probability" + " of that class and t is the class' threshold.", - isValid = "(t: Array[Double]) => t.forall(_ >= 0)"), + isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 23e2b6cc43996..a17d4ea960a90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -139,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params { } /** - * Trait for shared param threshold. + * Trait for shared param threshold (default: 0.5). */ private[ml] trait HasThreshold extends Params { @@ -149,6 +149,8 @@ private[ml] trait HasThreshold extends Params { */ final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) + setDefault(threshold, 0.5) + /** @group getParam */ def getThreshold: Double = $(threshold) } @@ -165,7 +167,7 @@ private[ml] trait HasThresholds extends Params { final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0)) /** @group getParam */ - final def getThresholds: Array[Double] = $(thresholds) + def getThresholds: Array[Double] = $(thresholds) } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 7e9aa383728f0..618b95b9bd126 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -100,9 +100,7 @@ public void logisticRegressionWithSetters() { assert(r.getDouble(0) == 0.0); } // Call transform with params, and check that the params worked. - double[] thresholds = {1.0, 0.0}; - model.transform( - dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb")) + model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; @@ -112,9 +110,8 @@ public void logisticRegressionWithSetters() { assert(foundNonZero); // Call fit() with new params, and check as many params as we can. 
- double[] thresholds2 = {0.6, 0.4}; LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb")); + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); assert(parent2.getMaxIter() == 5); assert(parent2.getRegParam() == 0.1); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8c3d4590f5ae9..e354e161c6dee 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -94,12 +94,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("setThreshold, getThreshold") { val lr = new LogisticRegression // default - withClue("LogisticRegression should not have thresholds set by default") { - intercept[java.util.NoSuchElementException] { + assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5") + withClue("LogisticRegression should not have thresholds set by default.") { + intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future lr.getThresholds } } - // Set via thresholds. + // Set via threshold. // Intuition: Large threshold or large thresholds(1) makes class 0 more likely. lr.setThreshold(1.0) assert(lr.getThresholds === Array(0.0, 1.0)) @@ -107,10 +108,26 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lr.getThresholds === Array(1.0, 0.0)) lr.setThreshold(0.5) assert(lr.getThresholds === Array(0.5, 0.5)) - // Test getThreshold - lr.setThresholds(Array(0.3, 0.7)) + // Set via thresholds + val lr2 = new LogisticRegression + lr2.setThresholds(Array(0.3, 0.7)) val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7) - assert(lr.getThreshold ~== expectedThreshold relTol 1E-7) + assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7) + // thresholds and threshold must be consistent + lr2.setThresholds(Array(0.1, 0.2, 0.3)) + withClue("getThreshold should throw error if thresholds has length != 2.") { + intercept[IllegalArgumentException] { + lr2.getThreshold + } + } + // thresholds and threshold must be consistent: values + withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") { + intercept[IllegalArgumentException] { + val lr2model = lr2.fit(dataset, + lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0)) + lr2model.getThreshold + } + } } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -145,7 +162,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. val predNotAllZero = - model.transform(dataset, model.thresholds -> Array(1.0, 0.0), + model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") .select("prediction", "myProb") .collect() @@ -153,8 +170,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predNotAllZero.exists(_ !== 0.0)) // Call fit() with new params, and check as many params as we can. 
+ lr.setThresholds(Array(0.6, 0.4)) val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, - lr.thresholds -> Array(0.6, 0.4), lr.probabilityCol -> "theProb") val parent2 = model2.parent.asInstanceOf[LogisticRegression] assert(parent2.getMaxIter === 5) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 6702dce5545a9..83f808efc3bf0 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -76,19 +76,21 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti " Array must have length equal to the number of classes, with values >= 0." + " The class with largest value p/t is predicted, where p is the original" + " probability of that class and t is the class' threshold.") + threshold = Param(Params._dummy(), "threshold", + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=None, thresholds=None, + threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=None, thresholds=None, \ + threshold=0.5, thresholds=None, \ probabilityCol="probability", rawPredictionCol="rawPrediction") - Param thresholds overrides Param threshold; threshold is provided - for backwards compatibility and only applies to binary classification. + If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -101,7 +103,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") #: param for whether to fit an intercept term. self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") - #: param for threshold in binary classification prediction, in range [0, 1]. + #: param for threshold in binary classification, in range [0, 1]. + self.threshold = Param(self, "threshold", + "Threshold in binary classification prediction, in range [0, 1]." 
+ + " If threshold and thresholds are both set, they must match.") + #: param for thresholds or cutoffs in binary or multiclass classification self.thresholds = \ Param(self, "thresholds", "Thresholds in multi-class classification" + @@ -110,29 +116,28 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred " The class with largest value p/t is predicted, where p is the original" + " probability of that class and t is the class' threshold.") self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True) + fitIntercept=True, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) + self._checkThresholdConsistency() @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=None, thresholds=None, + threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=None, thresholds=None, \ + threshold=0.5, thresholds=None, \ probabilityCol="probability", rawPredictionCol="rawPrediction") Sets params for logistic regression. - Param thresholds overrides Param threshold; threshold is provided - for backwards compatibility and only applies to binary classification. + If the threshold and thresholds Params are both set, they must be equivalent. """ - # Under the hood we use thresholds so translate threshold to thresholds if applicable - if thresholds is None and threshold is not None: - kwargs[thresholds] = [1-threshold, threshold] kwargs = self.setParams._input_kwargs - return self._set(**kwargs) + self._set(**kwargs) + self._checkThresholdConsistency() + return self def _create_model(self, java_model): return LogisticRegressionModel(java_model) @@ -165,44 +170,65 @@ def getFitIntercept(self): def setThreshold(self, value): """ - Sets the value of :py:attr:`thresholds` using [1-value, value]. + Sets the value of :py:attr:`threshold`. + Clears value of :py:attr:`thresholds` if it has been set. + """ + self._paramMap[self.threshold] = value + if self.isSet(self.thresholds): + del self._paramMap[self.thresholds] + return self - >>> lr = LogisticRegression() - >>> lr.getThreshold() - 0.5 - >>> lr.setThreshold(0.6) - LogisticRegression_... - >>> abs(lr.getThreshold() - 0.6) < 1e-5 - True + def getThreshold(self): + """ + Gets the value of threshold or its default value. """ - return self.setThresholds([1-value, value]) + self._checkThresholdConsistency() + if self.isSet(self.thresholds): + ts = self.getOrDefault(self.thresholds) + if len(ts) != 2: + raise ValueError("Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2." + + " thresholds: " + ",".join(ts)) + return 1.0/(1.0 + ts[0]/ts[1]) + else: + return self.getOrDefault(self.threshold) def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. + Clears value of :py:attr:`threshold` if it has been set. """ self._paramMap[self.thresholds] = value + if self.isSet(self.threshold): + del self._paramMap[self.threshold] return self def getThresholds(self): """ - Gets the value of thresholds or its default value. + If :py:attr:`thresholds` is set, return its value. 
+ Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary + classification: (1-threshold, threshold). + If neither are set, throw an error. """ - return self.getOrDefault(self.thresholds) + self._checkThresholdConsistency() + if not self.isSet(self.thresholds) and self.isSet(self.threshold): + t = self.getOrDefault(self.threshold) + return [1.0-t, t] + else: + return self.getOrDefault(self.thresholds) - def getThreshold(self): - """ - Gets the value of threshold or its default value. - """ - if self.isDefined(self.thresholds): - thresholds = self.getOrDefault(self.thresholds) - if len(thresholds) != 2: + def _checkThresholdConsistency(self): + if self.isSet(self.threshold) and self.isSet(self.thresholds): + ts = self.getParam(self.thresholds) + if len(ts) != 2: raise ValueError("Logistic Regression getThreshold only applies to" + " binary classification, but thresholds has length != 2." + - " thresholds: " + ",".join(thresholds)) - return 1.0/(1.0+thresholds[0]/thresholds[1]) - else: - return 0.5 + " thresholds: " + ",".join(ts)) + t = 1.0/(1.0 + ts[0]/ts[1]) + t2 = self.getParam(self.threshold) + if abs(t2 - t) >= 1E-5: + raise ValueError("Logistic Regression getThreshold found inconsistent values for" + + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) class LogisticRegressionModel(JavaModel): From 6f60298b1d7aa97268a42eca1e3b4851a7e88cb5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 14:28:23 -0700 Subject: [PATCH 300/340] [SPARK-8967] [DOC] add Since annotation Add `Since` as a Scala annotation. The benefit is that we can use it without having explicit JavaDoc. This is useful for inherited methods. The limitation is that is doesn't show up in the generated Java API documentation. This might be fixed by modifying genjavadoc. I think we could leave it as a TODO. This is how the generated Scala doc looks: `since` JavaDoc tag: ![screen shot 2015-08-11 at 10 00 37 pm](https://cloud.githubusercontent.com/assets/829644/9230761/fa72865c-40d8-11e5-807e-0f3c815c5acd.png) `Since` annotation: ![screen shot 2015-08-11 at 10 00 28 pm](https://cloud.githubusercontent.com/assets/829644/9230764/0041d7f4-40d9-11e5-8124-c3f3e5d5b31f.png) rxin Author: Xiangrui Meng Closes #8131 from mengxr/SPARK-8967. --- .../org/apache/spark/annotation/Since.scala | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/annotation/Since.scala diff --git a/core/src/main/scala/org/apache/spark/annotation/Since.scala b/core/src/main/scala/org/apache/spark/annotation/Since.scala new file mode 100644 index 0000000000000..fa59393c22476 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/annotation/Since.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation + +import scala.annotation.StaticAnnotation + +/** + * A Scala annotation that specifies the Spark version when a definition was added. + * Different from the `@since` tag in JavaDoc, this annotation does not require explicit JavaDoc and + * hence works for overridden methods that inherit API documentation directly from parents. + * The limitation is that it does not show up in the generated Java API documentation. + */ +private[spark] class Since(version: String) extends StaticAnnotation From a17384fa343628cec44437da5b80b9403ecd5838 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 12 Aug 2015 15:27:52 -0700 Subject: [PATCH 301/340] [SPARK-9907] [SQL] Python crc32 is mistakenly calling md5 Author: Reynold Xin Closes #8138 from rxin/SPARK-9907. --- python/pyspark/sql/functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 95f46044d324a..e98979533f901 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -885,10 +885,10 @@ def crc32(col): returns the value as a bigint. >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() - [Row(crc32=u'902fbdd2b1df0c4f70b4a5d23525e932')] + [Row(crc32=2743272264)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.md5(_to_java_column(col))) + return Column(sc._jvm.functions.crc32(_to_java_column(col))) @ignore_unicode_prefix From 738f353988dbf02704bd63f5e35d94402c59ed79 Mon Sep 17 00:00:00 2001 From: Niranjan Padmanabhan Date: Wed, 12 Aug 2015 16:10:21 -0700 Subject: [PATCH 302/340] [SPARK-9092] Fixed incompatibility when both num-executors and dynamic... MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … allocation are set. Now, dynamic allocation is set to false when num-executors is explicitly specified as an argument. Consequently, executorAllocationManager in not initialized in the SparkContext. Author: Niranjan Padmanabhan Closes #7657 from neurons/SPARK-9092. 
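The end-to-end effect can be seen by constructing a SparkContext with both settings. This mirrors the new SparkContextSuite test below; note that executorAllocationManager is private[spark], so the assertion on it only compiles inside Spark's own sources.

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("test")
  .setMaster("local")
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.executor.instances", "6")

val sc = new SparkContext(conf)
// The explicit executor count wins: no ExecutorAllocationManager is created.
assert(sc.executorAllocationManager.isEmpty)
assert(sc.getConf.getInt("spark.executor.instances", 0) == 6)
sc.stop()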
--- .../scala/org/apache/spark/SparkConf.scala | 19 +++++++++++++++++++ .../scala/org/apache/spark/SparkContext.scala | 6 +++++- .../org/apache/spark/deploy/SparkSubmit.scala | 4 ++-- .../scala/org/apache/spark/util/Utils.scala | 11 +++++++++++ .../org/apache/spark/SparkContextSuite.scala | 8 ++++++++ .../spark/deploy/SparkSubmitSuite.scala | 1 - docs/running-on-yarn.md | 2 +- .../spark/deploy/yarn/ApplicationMaster.scala | 4 ++-- .../yarn/ApplicationMasterArguments.scala | 5 ----- .../org/apache/spark/deploy/yarn/Client.scala | 5 ++++- .../spark/deploy/yarn/ClientArguments.scala | 8 +------- .../spark/deploy/yarn/YarnAllocator.scala | 9 ++++++++- .../cluster/YarnClientSchedulerBackend.scala | 3 --- .../deploy/yarn/YarnAllocatorSuite.scala | 5 +++-- 14 files changed, 64 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 8ff154fb5e334..b344b5e173d67 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -389,6 +389,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { val driverOptsKey = "spark.driver.extraJavaOptions" val driverClassPathKey = "spark.driver.extraClassPath" val driverLibraryPathKey = "spark.driver.extraLibraryPath" + val sparkExecutorInstances = "spark.executor.instances" // Used by Yarn in 1.1 and before sys.props.get("spark.driver.libraryPath").foreach { value => @@ -476,6 +477,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } + + if (!contains(sparkExecutorInstances)) { + sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => + val warning = + s""" + |SPARK_WORKER_INSTANCES was detected (set to '$value'). + |This is deprecated in Spark 1.0+. + | + |Please instead use: + | - ./spark-submit with --num-executors to specify the number of executors + | - Or set SPARK_EXECUTOR_INSTANCES + | - spark.executor.instances to configure the number of instances in the spark config. + """.stripMargin + logWarning(warning) + + set("spark.executor.instances", value) + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6aafb4c5644d7..207a0c1bffeb3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -528,7 +528,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } // Optionally scale number of executors dynamically based on workload. Exposed for testing. 
- val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) + if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { + logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + _executorAllocationManager = if (dynamicAllocationEnabled) { Some(new ExecutorAllocationManager(this, listenerBus, _conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 7ac6cbce4cd1d..02fa3088eded0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -422,7 +422,8 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), - OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), + OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, + sysProp = "spark.executor.instances"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), @@ -433,7 +434,6 @@ object SparkSubmit { OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), - OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c4012d0e83f7d..a90d8541366f4 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2286,6 +2286,17 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } + /** + * Return whether dynamic allocation is enabled in the given conf + * Dynamic allocation and explicitly setting the number of executors are inherently + * incompatible. In environments where dynamic allocation is turned on by default, + * the latter should override the former (SPARK-9092). 
+ */ + def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { + conf.contains("spark.dynamicAllocation.enabled") && + conf.getInt("spark.executor.instances", 0) == 0 + } + } private [util] class SparkShutdownHookManager { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 5c57940fa5f77..d4f2ea87650a9 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -285,4 +285,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("No exception when both num-executors and dynamic allocation set.") { + noException should be thrownBy { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local") + .set("spark.dynamicAllocation.enabled", "true").set("spark.executor.instances", "6")) + assert(sc.executorAllocationManager.isEmpty) + assert(sc.getConf.getInt("spark.executor.instances", 0) === 6) + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 757e0ce3d278b..2456c5d0d49b0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -159,7 +159,6 @@ class SparkSubmitSuite childArgsStr should include ("--executor-cores 5") childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include ("--queue thequeue") - childArgsStr should include ("--num-executors 6") childArgsStr should include regex ("--jar .*thejar.jar") childArgsStr should include regex ("--addJars .*one.jar,.*two.jar,.*three.jar") childArgsStr should include regex ("--files .*file1.txt,.*file2.txt") diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cac08a91b97d9..ec32c419b7c51 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -199,7 +199,7 @@ If you need a reference to the proper location to put log files in the YARN so t spark.executor.instances 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. 
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 1d67b3ebb51b7..e19940d8d6642 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -64,7 +64,8 @@ private[spark] class ApplicationMaster( // Default to numExecutors * 2, with minimum of 3 private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + sparkConf.getInt("spark.yarn.max.worker.failures", + math.max(sparkConf.getInt("spark.executor.instances", 0) * 2, 3))) @volatile private var exitCode = 0 @volatile private var unregistered = false @@ -493,7 +494,6 @@ private[spark] class ApplicationMaster( */ private def startUserApplication(): Thread = { logInfo("Starting the user application in a separate Thread") - System.setProperty("spark.executor.instances", args.numExecutors.toString) val classpath = Client.getUserClasspath(sparkConf) val urls = classpath.map { entry => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 37f793763367e..b08412414aa1c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -29,7 +29,6 @@ class ApplicationMasterArguments(val args: Array[String]) { var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 - var numExecutors = DEFAULT_NUMBER_EXECUTORS var propertiesFile: String = null parseArgs(args.toList) @@ -63,10 +62,6 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgsBuffer += value args = tail - case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => - numExecutors = value - args = tail - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => executorMemory = value args = tail diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index b4ba3f0221600..6d63ddaf15852 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -751,7 +751,6 @@ private[spark] class Client( userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString, "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) @@ -960,6 +959,10 @@ object Client extends Logging { val sparkConf = new SparkConf val args = new ClientArguments(argStrings, sparkConf) + // to maintain backwards-compatibility + if (!Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.setIfMissing("spark.executor.instances", args.numExecutors.toString) + } new Client(args, sparkConf).run() } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 20d63d40cf605..4f42ffefa77f9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ 
b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -53,8 +53,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private val amMemOverheadKey = "spark.yarn.am.memoryOverhead" private val driverCoresKey = "spark.driver.cores" private val amCoresKey = "spark.yarn.am.cores" - private val isDynamicAllocationEnabled = - sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) + private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf) parseArgs(args.toList) loadEnvironmentArgs() @@ -196,11 +195,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) if (args(0) == "--num-workers") { println("--num-workers is deprecated. Use --num-executors instead.") } - // Dynamic allocation is not compatible with this option - if (isDynamicAllocationEnabled) { - throw new IllegalArgumentException("Explicitly setting the number " + - "of executors is not compatible with spark.dynamicAllocation.enabled!") - } numExecutors = value args = tail diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 59caa787b6e20..ccf753e69f4b6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,6 +21,8 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern +import org.apache.spark.util.Utils + import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -86,7 +88,12 @@ private[yarn] class YarnAllocator( private var executorIdCounter = 0 @volatile private var numExecutorsFailed = 0 - @volatile private var targetNumExecutors = args.numExecutors + @volatile private var targetNumExecutors = + if (Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.getInt("spark.dynamicAllocation.initialExecutors", 0) + } else { + sparkConf.getInt("spark.executor.instances", YarnSparkHadoopUtil.DEFAULT_NUMBER_EXECUTORS) + } // Keep track of which container is running which executor to remove the executors later // Visible for testing. 
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d225061fcd1b4..d06d95140438c 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -81,8 +81,6 @@ private[spark] class YarnClientSchedulerBackend( // List of (target Client argument, environment variable, Spark property) val optionTuples = List( - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), - ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), @@ -92,7 +90,6 @@ private[spark] class YarnClientSchedulerBackend( ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( - "SPARK_WORKER_INSTANCES" -> "SPARK_WORKER_INSTANCES or --num-executors through spark-submit", "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") optionTuples.foreach { case (optionName, envVar, sparkProp) => diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 58318bf9bcc08..5d05f514adde3 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -87,16 +87,17 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter def createAllocator(maxExecutors: Int = 5): YarnAllocator = { val args = Array( - "--num-executors", s"$maxExecutors", "--executor-cores", "5", "--executor-memory", "2048", "--jar", "somejar.jar", "--class", "SomeClass") + val sparkConfClone = sparkConf.clone() + sparkConfClone.set("spark.executor.instances", maxExecutors.toString) new YarnAllocator( "not used", mock(classOf[RpcEndpointRef]), conf, - sparkConf, + sparkConfClone, rmClient, appAttemptId, new ApplicationMasterArguments(args), From ab7e721cfec63155641e81e72b4ad43cf6a7d4c7 Mon Sep 17 00:00:00 2001 From: Michel Lemay Date: Wed, 12 Aug 2015 16:17:58 -0700 Subject: [PATCH 303/340] [SPARK-9826] [CORE] Fix cannot use custom classes in log4j.properties Refactor Utils class and create ShutdownHookManager. NOTE: Wasn't able to run /dev/run-tests on windows machine. 
Manual tests were conducted locally using custom log4j.properties file with Redis appender and logstash formatter (bundled in the fat-jar submitted to spark) ex: log4j.rootCategory=WARN,console,redis log4j.appender.console=org.apache.log4j.ConsoleAppender log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO log4j.logger.org.apache.spark.graphx.Pregel=INFO log4j.appender.redis=com.ryantenney.log4j.FailoverRedisAppender log4j.appender.redis.endpoints=hostname:port log4j.appender.redis.key=mykey log4j.appender.redis.alwaysBatch=false log4j.appender.redis.layout=net.logstash.log4j.JSONEventLayoutV1 Author: michellemay Closes #8109 from michellemay/SPARK-9826. --- .../scala/org/apache/spark/SparkContext.scala | 5 +- .../spark/deploy/history/HistoryServer.scala | 4 +- .../spark/deploy/worker/ExecutorRunner.scala | 7 +- .../org/apache/spark/rdd/HadoopRDD.scala | 4 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 4 +- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 4 +- .../spark/storage/DiskBlockManager.scala | 10 +- .../spark/storage/TachyonBlockManager.scala | 6 +- .../spark/util/ShutdownHookManager.scala | 266 ++++++++++++++++++ .../util/SparkUncaughtExceptionHandler.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 222 +-------------- .../hive/thriftserver/HiveThriftServer2.scala | 4 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 4 +- .../apache/spark/sql/hive/test/TestHive.scala | 4 +- .../spark/streaming/StreamingContext.scala | 8 +- .../spark/deploy/yarn/ApplicationMaster.scala | 5 +- 16 files changed, 307 insertions(+), 252 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 207a0c1bffeb3..2e01a9a18c784 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -563,7 +563,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Make sure the context is stopped if the user forgets about it. This avoids leaving // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM // is killed, though. 
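// Reviewer sketch, not part of this hunk: the replacement below keeps the curried
// addShutdownHook shape; only the owning object moves from Utils to the new ShutdownHookManager.
//   val ref = ShutdownHookManager.addShutdownHook(
//     ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => stop() }
//   ShutdownHookManager.removeShutdownHook(ref)   // later, to unregister the hook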
- _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + _shutdownHookRef = ShutdownHookManager.addShutdownHook( + ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => logInfo("Invoking stop() from shutdown hook") stop() } @@ -1671,7 +1672,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli return } if (_shutdownHookRef != null) { - Utils.removeShutdownHook(_shutdownHookRef) + ShutdownHookManager.removeShutdownHook(_shutdownHookRef) } Utils.tryLogNonFatalError { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index a076a9c3f984d..d4f327cc588fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -30,7 +30,7 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, Applica UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils} /** * A web server that renders SparkUIs of completed applications. @@ -238,7 +238,7 @@ object HistoryServer extends Logging { val server = new HistoryServer(conf, provider, securityManager, port) server.bind() - Utils.addShutdownHook { () => server.stop() } + ShutdownHookManager.addShutdownHook { () => server.stop() } // Wait until the end of the world... or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 29a5042285578..ab3fea475c2a5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -28,7 +28,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender /** @@ -70,7 +70,8 @@ private[deploy] class ExecutorRunner( } workerThread.start() // Shutdown hook that kills actors on shutdown. 
- shutdownHook = Utils.addShutdownHook { () => killProcess(Some("Worker shutting down")) } + shutdownHook = ShutdownHookManager.addShutdownHook { () => + killProcess(Some("Worker shutting down")) } } /** @@ -102,7 +103,7 @@ private[deploy] class ExecutorRunner( workerThread = null state = ExecutorState.KILLED try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: IllegalStateException => None } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index f1c17369cb48c..e1f8719eead02 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel @@ -274,7 +274,7 @@ class HadoopRDD[K, V]( } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f83a051f5da11..6a9c004d65cff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -186,7 +186,7 @@ class NewHadoopRDD[K, V]( } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 6a95e44c57fec..fa3fecc80cb63 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} private[spark] class SqlNewHadoopPartition( @@ -212,7 +212,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } } catch { case e: Exception => - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala 
b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 56a33d5ca7d60..3f8d26e1d4cab 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -22,7 +22,7 @@ import java.io.{IOException, File} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -144,7 +144,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook(): AnyRef = { - Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } @@ -154,7 +154,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: Exception => logError(s"Exception while removing shutdown hook.", e) @@ -168,7 +168,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(localDir)) { + Utils.deleteRecursively(localDir) + } } catch { case e: Exception => logError(s"Exception while deleting local spark dir: $localDir", e) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index ebad5bc5ab28d..22878783fca67 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -32,7 +32,7 @@ import tachyon.TachyonURI import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -80,7 +80,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log // in order to avoid having really large inodes at the top level in Tachyon. 
tachyonDirs = createTachyonDirs() subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) - tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) + tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir)) } override def toString: String = {"ExternalBlockStore-Tachyon"} @@ -240,7 +240,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log logDebug("Shutdown hook called") tachyonDirs.foreach { tachyonDir => try { - if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) { Utils.deleteRecursively(tachyonDir, client) } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala new file mode 100644 index 0000000000000..61ff9b89ec1c1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.File +import java.util.PriorityQueue + +import scala.util.{Failure, Success, Try} +import tachyon.client.TachyonFile + +import org.apache.hadoop.fs.FileSystem +import org.apache.spark.Logging + +/** + * Various utility methods used by Spark. + */ +private[spark] object ShutdownHookManager extends Logging { + val DEFAULT_SHUTDOWN_PRIORITY = 100 + + /** + * The shutdown priority of the SparkContext instance. This is lower than the default + * priority, so that by default hooks are run before the context is shut down. + */ + val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + + /** + * The shutdown priority of temp directory must be lower than the SparkContext shutdown + * priority. Otherwise cleaning the temp directories while Spark jobs are running can + * throw undesirable errors at the time of shutdown. 
+ */ + val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + + private lazy val shutdownHooks = { + val manager = new SparkShutdownHookManager() + manager.install() + manager + } + + private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() + private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() + + // Add a shutdown hook to delete the temp dirs when the JVM exits + addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => + logInfo("Shutdown hook called") + shutdownDeletePaths.foreach { dirPath => + try { + logInfo("Deleting directory " + dirPath) + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) + } + } + } + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Register the tachyon path to be deleted via shutdown hook + def registerShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths += absolutePath + } + } + + // Remove the path to be deleted via shutdown hook + def removeShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.remove(absolutePath) + } + } + + // Remove the tachyon path to be deleted via shutdown hook + def removeShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.remove(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in IOException and incomplete cleanup. + def hasRootAsShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in Exception and incomplete cleanup. 
+ def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + val retval = shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + /** + * Detect whether this thread might be executing a shutdown hook. Will always return true if + * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. + * if System.exit was just called by a concurrent thread). + * + * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing + * an IllegalStateException. + */ + def inShutdown(): Boolean = { + try { + val hook = new Thread { + override def run() {} + } + Runtime.getRuntime.addShutdownHook(hook) + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case ise: IllegalStateException => return true + } + false + } + + /** + * Adds a shutdown hook with default priority. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(hook: () => Unit): AnyRef = { + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) + } + + /** + * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * first. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { + shutdownHooks.add(priority, hook) + } + + /** + * Remove a previously installed shutdown hook. + * + * @param ref A handle returned by `addShutdownHook`. + * @return Whether the hook was removed. + */ + def removeShutdownHook(ref: AnyRef): Boolean = { + shutdownHooks.remove(ref) + } + +} + +private [util] class SparkShutdownHookManager { + + private val hooks = new PriorityQueue[SparkShutdownHook]() + private var shuttingDown = false + + /** + * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not + * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for + * the best. 
+ */ + def install(): Unit = { + val hookTask = new Runnable() { + override def run(): Unit = runAll() + } + Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { + case Success(shmClass) => + val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() + .asInstanceOf[Int] + val shm = shmClass.getMethod("get").invoke(null) + shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) + .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) + + case Failure(_) => + Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + } + } + + def runAll(): Unit = synchronized { + shuttingDown = true + while (!hooks.isEmpty()) { + Try(Utils.logUncaughtExceptions(hooks.poll().run())) + } + } + + def add(priority: Int, hook: () => Unit): AnyRef = synchronized { + checkState() + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } + + def remove(ref: AnyRef): Boolean = synchronized { + hooks.remove(ref) + } + + private def checkState(): Unit = { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + } + +} + +private class SparkShutdownHook(private val priority: Int, hook: () => Unit) + extends Comparable[SparkShutdownHook] { + + override def compareTo(other: SparkShutdownHook): Int = { + other.priority - priority + } + + def run(): Unit = hook() + +} diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index ad3db1fbb57ed..7248187247330 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -33,7 +33,7 @@ private[spark] object SparkUncaughtExceptionHandler // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { if (exception.isInstanceOf[OutOfMemoryError]) { System.exit(SparkExitCode.OOM) } else { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a90d8541366f4..f2abf227dc129 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,7 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.{PriorityQueue, Properties, Locale, Random, UUID} +import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection @@ -65,21 +65,6 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() - val DEFAULT_SHUTDOWN_PRIORITY = 100 - - /** - * The shutdown priority of the SparkContext instance. This is lower than the default - * priority, so that by default hooks are run before the context is shut down. - */ - val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 - - /** - * The shutdown priority of temp directory must be lower than the SparkContext shutdown - * priority. Otherwise cleaning the temp directories while Spark jobs are running can - * throw undesirable errors at the time of shutdown. 
- */ - val TEMP_DIR_SHUTDOWN_PRIORITY = 25 - /** * Define a default value for driver memory here since this value is referenced across the code * base and nearly all files already use Utils.scala @@ -90,9 +75,6 @@ private[spark] object Utils extends Logging { @volatile private var localRootDirs: Array[String] = null - private val shutdownHooks = new SparkShutdownHookManager() - shutdownHooks.install() - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -205,86 +187,6 @@ private[spark] object Utils extends Logging { } } - private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() - private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() - - // Add a shutdown hook to delete the temp dirs when the JVM exits - addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => - logInfo("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => - try { - logInfo("Deleting directory " + dirPath) - Utils.deleteRecursively(new File(dirPath)) - } catch { - case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) - } - } - } - - // Register the path to be deleted via shutdown hook - def registerShutdownDeleteDir(file: File) { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths += absolutePath - } - } - - // Register the tachyon path to be deleted via shutdown hook - def registerShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths += absolutePath - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.contains(absolutePath) - } - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in IOException and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - val retval = shutdownDeletePaths.synchronized { - shutdownDeletePaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in Exception and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - val retval = shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - /** * JDK equivalent of `chmod 700 file`. 
* @@ -333,7 +235,7 @@ private[spark] object Utils extends Logging { root: String = System.getProperty("java.io.tmpdir"), namePrefix: String = "spark"): File = { val dir = createDirectory(root, namePrefix) - registerShutdownDeleteDir(dir) + ShutdownHookManager.registerShutdownDeleteDir(dir) dir } @@ -973,9 +875,7 @@ private[spark] object Utils extends Logging { if (savedIOException != null) { throw savedIOException } - shutdownDeletePaths.synchronized { - shutdownDeletePaths.remove(file.getAbsolutePath) - } + ShutdownHookManager.removeShutdownDeleteDir(file) } } finally { if (!file.delete()) { @@ -1478,27 +1378,6 @@ private[spark] object Utils extends Logging { serializer.deserialize[T](serializer.serialize(value)) } - /** - * Detect whether this thread might be executing a shutdown hook. Will always return true if - * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. - * if System.exit was just called by a concurrent thread). - * - * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing - * an IllegalStateException. - */ - def inShutdown(): Boolean = { - try { - val hook = new Thread { - override def run() {} - } - Runtime.getRuntime.addShutdownHook(hook) - Runtime.getRuntime.removeShutdownHook(hook) - } catch { - case ise: IllegalStateException => return true - } - false - } - private def isSpace(c: Char): Boolean = { " \t\r\n".indexOf(c) != -1 } @@ -2221,37 +2100,6 @@ private[spark] object Utils extends Logging { msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX) } - /** - * Adds a shutdown hook with default priority. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(hook: () => Unit): AnyRef = { - addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) - } - - /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run - * first. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { - shutdownHooks.add(priority, hook) - } - - /** - * Remove a previously installed shutdown hook. - * - * @param ref A handle returned by `addShutdownHook`. - * @return Whether the hook was removed. - */ - def removeShutdownHook(ref: AnyRef): Boolean = { - shutdownHooks.remove(ref) - } - /** * To avoid calling `Utils.getCallSite` for every single RDD we create in the body, * set a dummy call site that RDDs use instead. This is for performance optimization. @@ -2299,70 +2147,6 @@ private[spark] object Utils extends Logging { } -private [util] class SparkShutdownHookManager { - - private val hooks = new PriorityQueue[SparkShutdownHook]() - private var shuttingDown = false - - /** - * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not - * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for - * the best. 
- */ - def install(): Unit = { - val hookTask = new Runnable() { - override def run(): Unit = runAll() - } - Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { - case Success(shmClass) => - val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() - .asInstanceOf[Int] - val shm = shmClass.getMethod("get").invoke(null) - shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) - .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) - - case Failure(_) => - Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); - } - } - - def runAll(): Unit = synchronized { - shuttingDown = true - while (!hooks.isEmpty()) { - Try(Utils.logUncaughtExceptions(hooks.poll().run())) - } - } - - def add(priority: Int, hook: () => Unit): AnyRef = synchronized { - checkState() - val hookRef = new SparkShutdownHook(priority, hook) - hooks.add(hookRef) - hookRef - } - - def remove(ref: AnyRef): Boolean = synchronized { - hooks.remove(ref) - } - - private def checkState(): Unit = { - if (shuttingDown) { - throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") - } - } - -} - -private class SparkShutdownHook(private val priority: Int, hook: () => Unit) - extends Comparable[SparkShutdownHook] { - - override def compareTo(other: SparkShutdownHook): Int = { - other.priority - priority - } - - def run(): Unit = hook() - -} - /** * A utility class to redirect the child process's stdout or stderr. */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 9c047347cb58d..2c9fa595b2dad 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{Logging, SparkContext} @@ -76,7 +76,7 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Utils.addShutdownHook { () => + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() uiTab.foreach(_.detach()) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index d3886142b388d..7799704c819d9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -39,7 +39,7 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver @@ -114,7 +114,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { SessionState.start(sessionState) // Clean up after we exit - Utils.addShutdownHook { () => 
SparkSQLEnv.stop() } + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() } val remoteMode = isRemoteMode(sessionState) // "-h" option has been passed, so connect to Hive thrift server. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 296cc5c5e0b04..4eae699ac3b51 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{SparkConf, SparkContext} /* Implicit conversions */ @@ -154,7 +154,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() - Utils.registerShutdownDeleteDir(hiveFilesTemp) + ShutdownHookManager.registerShutdownDeleteDir(hiveFilesTemp) val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 177e710ace54b..b496d1f341a0b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} /** * Main entry point for Spark Streaming functionality. 
It provides methods used to create @@ -604,7 +604,7 @@ class StreamingContext private[streaming] ( } StreamingContext.setActiveContext(this) } - shutdownHookRef = Utils.addShutdownHook( + shutdownHookRef = ShutdownHookManager.addShutdownHook( StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) // Registering Streaming Metrics at the start of the StreamingContext assert(env.metricsSystem != null) @@ -691,7 +691,7 @@ class StreamingContext private[streaming] ( StreamingContext.setActiveContext(null) waiter.notifyStop() if (shutdownHookRef != null) { - Utils.removeShutdownHook(shutdownHookRef) + ShutdownHookManager.removeShutdownHook(shutdownHookRef) } logInfo("StreamingContext stopped successfully") } @@ -725,7 +725,7 @@ object StreamingContext extends Logging { */ private val ACTIVATION_LOCK = new Object() - private val SHUTDOWN_HOOK_PRIORITY = Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 + private val SHUTDOWN_HOOK_PRIORITY = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 private val activeContext = new AtomicReference[StreamingContext](null) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e19940d8d6642..6a8ddb37b29e8 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -112,7 +112,8 @@ private[spark] class ApplicationMaster( val fs = FileSystem.get(yarnConf) // This shutdown hook should run *after* the SparkContext is shut down. - Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1) { () => + val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1 + ShutdownHookManager.addShutdownHook(priority) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts @@ -199,7 +200,7 @@ private[spark] class ApplicationMaster( final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { synchronized { if (!finished) { - val inShutdown = Utils.inShutdown() + val inShutdown = ShutdownHookManager.inShutdown() logInfo(s"Final app status: $status, exitCode: $code" + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) exitCode = code From 7035d880a0cf06910c19b4afd49645124c620f14 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 16:45:15 -0700 Subject: [PATCH 304/340] [SPARK-9894] [SQL] Json writer should handle MapData. https://issues.apache.org/jira/browse/SPARK-9894 Author: Yin Huai Closes #8137 from yhuai/jsonMapData. 
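A minimal repro sketch from the caller's side (sqlContext, sparkContext and the output path are illustrative assumptions, not taken from the patch): a MapType column is carried internally as MapData, so the writer's old pattern match on scala.collection.Map never fired for it; with the new case the column is emitted as a JSON object.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val schema = new StructType().add("map", MapType(StringType, LongType))
val rows = sparkContext.parallelize(Seq(Row(Map("a" -> 1L)), Row(Map("b" -> 2L))))
val df = sqlContext.createDataFrame(rows, schema)

// Before this patch this write missed the MapData case (the commit's own test exercises exactly
// this); now each row serializes as, e.g., {"map":{"a":1}}.
df.write.json("/tmp/spark-9894-map-demo")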
--- .../datasources/json/JacksonGenerator.scala | 10 +-- .../sources/JsonHadoopFsRelationSuite.scala | 78 +++++++++++++++++++ .../SimpleTextHadoopFsRelationSuite.scala | 30 ------- 3 files changed, 83 insertions(+), 35 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 37c2b5a296c15..99ac7730bd1c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -107,12 +107,12 @@ private[sql] object JacksonGenerator { v.foreach(ty, (_, value) => valWriter(ty, value)) gen.writeEndArray() - case (MapType(kv, vv, _), v: Map[_, _]) => + case (MapType(kt, vt, _), v: MapData) => gen.writeStartObject() - v.foreach { p => - gen.writeFieldName(p._1.toString) - valWriter(vv, p._2) - } + v.foreach(kt, vt, { (k, v) => + gen.writeFieldName(k.toString) + valWriter(vt, v) + }) gen.writeEndObject() case (StructType(ty), v: InternalRow) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala new file mode 100644 index 0000000000000..ed6d512ab36fe --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = "json" + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-9894: save complex types to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("array", ArrayType(LongType)) + .add("map", MapType(StringType, new StructType().add("innerField", LongType))) + + val data = + Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: + Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil + val df = createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. + checkAnswer( + read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 48c37a1fa1022..e8975e5f5cd08 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -50,33 +50,3 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } } } - -class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = - classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource].getCanonicalName - - import sqlContext._ - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") - .saveAsTextFile(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } -} From caa14d9dc9e2eb1102052b22445b63b0e004e3c7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 16:53:47 -0700 Subject: [PATCH 305/340] [SPARK-9913] [MLLIB] LDAUtils should be private feynmanliang Author: 
Xiangrui Meng Closes #8142 from mengxr/SPARK-9913. --- .../main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index f7e5ce1665fe6..a9ba7b60bad08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -22,7 +22,7 @@ import breeze.numerics._ /** * Utility methods for LDA. */ -object LDAUtils { +private[clustering] object LDAUtils { /** * Log Sum Exp with overflow protection using the identity: * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\} From 6e409bc1357f49de2efdfc4226d074b943fb1153 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 12 Aug 2015 16:54:45 -0700 Subject: [PATCH 306/340] [SPARK-9909] [ML] [TRIVIAL] move weightCol to shared params As per the TODO move weightCol to Shared Params. Author: Holden Karau Closes #8144 from holdenk/SPARK-9909-move-weightCol-toSharedParams. --- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +++- .../spark/ml/param/shared/sharedParams.scala | 15 +++++++++++++++ .../spark/ml/regression/IsotonicRegression.scala | 16 ++-------------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 9e12f1856a940..8c16c6149b40d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -70,7 +70,9 @@ private[shared] object SharedParamsCodeGen { " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), - ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization.")) + ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), + ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + + "all instance weights as 1.0.")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a17d4ea960a90..c26768953e3db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -342,4 +342,19 @@ private[ml] trait HasStepSize extends Params { /** @group getParam */ final def getStepSize: Double = $(stepSize) } + +/** + * Trait for shared param weightCol. + */ +private[ml] trait HasWeightCol extends Params { + + /** + * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.. + * @group param + */ + final val weightCol: Param[String] = new Param[String](this, "weightCol", "weight column name. 
If this is not set or empty, we treat all instance weights as 1.0.") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index f570590960a62..0f33bae30e622 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -21,7 +21,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} @@ -35,19 +35,7 @@ import org.apache.spark.storage.StorageLevel * Params for isotonic regression. */ private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol - with HasLabelCol with HasPredictionCol with Logging { - - /** - * Param for weight column name (default: none). - * @group param - */ - // TODO: Move weightCol to sharedParams. - final val weightCol: Param[String] = - new Param[String](this, "weightCol", - "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") - - /** @group getParam */ - final def getWeightCol: String = $(weightCol) + with HasLabelCol with HasPredictionCol with HasWeightCol with Logging { /** * Param for whether the output sequence should be isotonic/increasing (true) or From e6aef55766d0e2a48e0f9cb6eda0e31a71b962f3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 17:04:31 -0700 Subject: [PATCH 307/340] [SPARK-9912] [MLLIB] QRDecomposition should use QType and RType for type names instead of UType and VType hhbyyh Author: Xiangrui Meng Closes #8140 from mengxr/SPARK-9912. --- .../apache/spark/mllib/linalg/SingularValueDecomposition.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index b416d50a5631e..cff5dbeee3e57 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -31,5 +31,5 @@ case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VTyp * Represents QR factors. */ @Experimental -case class QRDecomposition[UType, VType](Q: UType, R: VType) +case class QRDecomposition[QType, RType](Q: QType, R: RType) From fc1c7fd66e64ccea53b31cd2fbb98bc6d307329c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 17:06:12 -0700 Subject: [PATCH 308/340] [SPARK-9915] [ML] stopWords should use StringArrayParam hhbyyh Author: Xiangrui Meng Closes #8141 from mengxr/SPARK-9915. 
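For context, a short usage sketch of the parameter being retyped (column names and the word list are illustrative assumptions); the transformer's behaviour is unchanged, the declaration in the diff below simply moves from a generic Param[Array[String]] to the dedicated StringArrayParam used for other array-valued params.

import org.apache.spark.ml.feature.StopWordsRemover

val remover = new StopWordsRemover()
  .setInputCol("rawTokens")
  .setOutputCol("filteredTokens")
  .setStopWords(Array("a", "an", "the")) // the param behind this setter is now a StringArrayParam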
--- .../org/apache/spark/ml/feature/StopWordsRemover.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 3cc41424460f2..5d77ea08db657 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -19,12 +19,12 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, BooleanParam, Param} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.{StringType, StructField, ArrayType, StructType} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} /** * stop words list @@ -100,7 +100,7 @@ class StopWordsRemover(override val uid: String) * the stop words set to be filtered out * @group param */ - val stopWords: Param[Array[String]] = new Param(this, "stopWords", "stop words") + val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words") /** @group setParam */ def setStopWords(value: Array[String]): this.type = set(stopWords, value) From 660e6dcff8125b83cc73dbe00c90cbe58744bc66 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 12 Aug 2015 17:07:29 -0700 Subject: [PATCH 309/340] [SPARK-9449] [SQL] Include MetastoreRelation's inputFiles Author: Michael Armbrust Closes #8119 from marmbrus/metastoreInputFiles. 
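A brief usage sketch of what the change enables (the table name and the HiveContext bound to sqlContext are illustrative assumptions): DataFrame.inputFiles now also covers Hive tables backed by MetastoreRelation, returning the partition locations of a partitioned table and the table's own location otherwise, instead of contributing nothing.

// Illustrative only; assumes an existing Hive table named "web_logs".
val logs = sqlContext.table("web_logs")
logs.inputFiles.foreach(println) // partition locations, or the table location if unpartitioned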
--- .../org/apache/spark/sql/DataFrame.scala | 10 ++++--- .../spark/sql/execution/FileRelation.scala | 28 +++++++++++++++++++ .../apache/spark/sql/sources/interfaces.scala | 6 ++-- .../org/apache/spark/sql/DataFrameSuite.scala | 26 +++++++++-------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 16 +++++++++-- 5 files changed, 66 insertions(+), 20 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 27b994f1f0caf..c466d9e6cb349 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,10 +34,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -1560,8 +1560,10 @@ class DataFrame private[sql]( */ def inputFiles: Array[String] = { val files: Seq[String] = logicalPlan.collect { - case LogicalRelation(fsBasedRelation: HadoopFsRelation) => - fsBasedRelation.paths.toSeq + case LogicalRelation(fsBasedRelation: FileRelation) => + fsBasedRelation.inputFiles + case fr: FileRelation => + fr.inputFiles }.flatten files.toSet.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala new file mode 100644 index 0000000000000..7a2a9eed5807d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +/** + * An interface for relations that are backed by files. When a class implements this interface, + * the list of paths that it returns will be returned to a user who calls `inputPaths` on any + * DataFrame that queries this relation. 
+ */ +private[sql] trait FileRelation { + /** Returns the list of files that will be read when scanning this relation. */ + def inputFiles: Array[String] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 2f8417a48d32e..b3b326fe612c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.execution.RDDConversions +import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} import org.apache.spark.sql.types.StructType import org.apache.spark.sql._ @@ -406,7 +406,7 @@ abstract class OutputWriter { */ @Experimental abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) - extends BaseRelation with Logging { + extends BaseRelation with FileRelation with Logging { override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") @@ -516,6 +516,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ def paths: Array[String] + override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray + /** * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically * discovered. Note that they should always be nullable. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index adbd95197d7ca..2feec29955bc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -485,21 +485,23 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("inputFiles") { - val fakeRelation1 = new ParquetRelation(Array("/my/path", "/my/other/path"), - Some(testData.schema), None, Map.empty)(sqlContext) - val df1 = DataFrame(sqlContext, LogicalRelation(fakeRelation1)) - assert(df1.inputFiles.toSet == fakeRelation1.paths.toSet) + withTempDir { dir => + val df = Seq((1, 22)).toDF("a", "b") - val fakeRelation2 = new JSONRelation( - None, 1, Some(testData.schema), None, None, Array("/json/path"))(sqlContext) - val df2 = DataFrame(sqlContext, LogicalRelation(fakeRelation2)) - assert(df2.inputFiles.toSet == fakeRelation2.paths.toSet) + val parquetDir = new File(dir, "parquet").getCanonicalPath + df.write.parquet(parquetDir) + val parquetDF = sqlContext.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) - val unionDF = df1.unionAll(df2) - assert(unionDF.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.paths) + val jsonDir = new File(dir, "json").getCanonicalPath + df.write.json(jsonDir) + val jsonDF = sqlContext.read.json(jsonDir) + assert(parquetDF.inputFiles.nonEmpty) - val filtered = df1.filter("false").unionAll(df2.intersect(df2)) - assert(filtered.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.paths) + val unioned = jsonDF.unionAll(parquetDF).inputFiles.sorted + val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).toSet.toArray.sorted + 
assert(unioned === allFiles) + } } ignore("show") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ac9aaed19d566..5e5497837a393 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} -import org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation @@ -739,7 +739,7 @@ private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: HiveTable) (@transient sqlContext: SQLContext) - extends LeafNode with MultiInstanceRelation { + extends LeafNode with MultiInstanceRelation with FileRelation { override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => @@ -888,6 +888,18 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + override def inputFiles: Array[String] = { + val partLocations = table.getPartitions(Nil).map(_.storage.location).toArray + if (partLocations.nonEmpty) { + partLocations + } else { + Array( + table.location.getOrElse( + sys.error(s"Could not get the location of ${table.qualifiedName}."))) + } + } + + override def newInstance(): MetastoreRelation = { MetastoreRelation(databaseName, tableName, alias)(table)(sqlContext) } From 8ce60963cb0928058ef7b6e29ff94eb69d1143af Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Wed, 12 Aug 2015 17:44:16 -0700 Subject: [PATCH 310/340] =?UTF-8?q?[SPARK-9780]=20[STREAMING]=20[KAFKA]=20?= =?UTF-8?q?prevent=20NPE=20if=20KafkaRDD=20instantiation=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …fails Author: cody koeninger Closes #8133 from koeninger/SPARK-9780 and squashes the following commits: 406259d [cody koeninger] [SPARK-9780][Streaming][Kafka] prevent NPE if KafkaRDD instantiation fails --- .../scala/org/apache/spark/streaming/kafka/KafkaRDD.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 1a9d78c0d4f59..ea5f842c6cafe 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -197,7 +197,11 @@ class KafkaRDD[ .dropWhile(_.offset < requestOffset) } - override def close(): Unit = consumer.close() + override def close(): Unit = { + if (consumer != null) { + consumer.close() + } + } override def getNext(): R = { if (iter == null || !iter.hasNext) { From 0d1d146c220f0d47d0e62b368d5b94d3bd9dd197 Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Wed, 12 Aug 2015 
17:48:43 -0700 Subject: [PATCH 311/340] [SPARK-9724] [WEB UI] Avoid unnecessary redirects in the Spark Web UI. Author: Rohit Agarwal Closes #8014 from mindprince/SPARK-9724 and squashes the following commits: a7af5ff [Rohit Agarwal] [SPARK-9724] [WEB UI] Inline attachPrefix and attachPrefixForRedirect. Fix logic of attachPrefix 8a977cd [Rohit Agarwal] [SPARK-9724] [WEB UI] Address review comments: Remove unneeded code, update scaladoc. b257844 [Rohit Agarwal] [SPARK-9724] [WEB UI] Avoid unnecessary redirects in the Spark Web UI. --- .../main/scala/org/apache/spark/ui/JettyUtils.scala | 13 ++++++------- .../main/scala/org/apache/spark/ui/SparkUI.scala | 4 ++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index c8356467fab87..779c0ba083596 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -106,7 +106,11 @@ private[spark] object JettyUtils extends Logging { path: String, servlet: HttpServlet, basePath: String): ServletContextHandler = { - val prefixedPath = attachPrefix(basePath, path) + val prefixedPath = if (basePath == "" && path == "/") { + path + } else { + (basePath + path).stripSuffix("/") + } val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) contextHandler.setContextPath(prefixedPath) @@ -121,7 +125,7 @@ private[spark] object JettyUtils extends Logging { beforeRedirect: HttpServletRequest => Unit = x => (), basePath: String = "", httpMethods: Set[String] = Set("GET")): ServletContextHandler = { - val prefixedDestPath = attachPrefix(basePath, destPath) + val prefixedDestPath = basePath + destPath val servlet = new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { if (httpMethods.contains("GET")) { @@ -246,11 +250,6 @@ private[spark] object JettyUtils extends Logging { val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } - - /** Attach a prefix to the given path, but avoid returning an empty path */ - private def attachPrefix(basePath: String, relativePath: String): String = { - if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/") - } } private[spark] case class ServerInfo( diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 3788916cf39bb..d8b90568b7b9a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -64,11 +64,11 @@ private[spark] class SparkUI private ( attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) + attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( - "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, + "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, httpMethods = Set("GET", "POST"))) } initialize() From f4bc01f1f33a93e6affe5c8a3e33ffbd92d03f38 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 12 Aug 2015 18:33:27 -0700 Subject: [PATCH 
312/340] [SPARK-9855] [SPARKR] Add expression functions into SparkR whose params are simple I added lots of expression functions for SparkR. This PR includes only functions whose params are only `(Column)` or `(Column, Column)`. And I think we need to improve how to test those functions. However, it would be better to work on another issue. ## Diff Summary - Add lots of functions in `functions.R` and their generic in `generic.R` - Add aliases for `ceiling` and `sign` - Move expression functions from `column.R` to `functions.R` - Modify `rdname` from `column` to `functions` I haven't supported `not` function, because the name has a collesion with `testthat` package. I didn't think of the way to define it. ## New Supported Functions ``` approxCountDistinct ascii base64 bin bitwiseNOT ceil (alias: ceiling) crc32 dayofmonth dayofyear explode factorial hex hour initcap isNaN last_day length log2 ltrim md5 minute month negate quarter reverse round rtrim second sha1 signum (alias: sign) size soundex to_date trim unbase64 unhex weekofyear year datediff levenshtein months_between nanvl pmod ``` ## JIRA [[SPARK-9855] Add expression functions into SparkR whose params are simple - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9855) Author: Yu ISHIKAWA Closes #8123 from yu-iskw/SPARK-9855. --- R/pkg/DESCRIPTION | 1 + R/pkg/R/column.R | 81 -------------- R/pkg/R/functions.R | 123 ++++++++++++++++++++ R/pkg/R/generics.R | 185 +++++++++++++++++++++++++++++-- R/pkg/inst/tests/test_sparkSQL.R | 21 ++-- 5 files changed, 309 insertions(+), 102 deletions(-) create mode 100644 R/pkg/R/functions.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 4949d86d20c91..83e64897216b1 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -29,6 +29,7 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'functions.R' 'mllib.R' 'serialize.R' 'sparkR.R' diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index eeaf9f193b728..328f595d0805f 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -60,12 +60,6 @@ operators <- list( ) column_functions1 <- c("asc", "desc", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") -functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", - "first", "last", "lower", "upper", "sumDistinct", - "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", - "expm1", "floor", "log", "log10", "log1p", "rint", "sign", - "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") -binary_mathfunctions <- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, @@ -111,33 +105,6 @@ createColumnFunction2 <- function(name) { }) } -createStaticFunction <- function(name) { - setMethod(name, - signature(x = "Column"), - function(x) { - if (name == "ceiling") { - name <- "ceil" - } - if (name == "sign") { - name <- "signum" - } - jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) - column(jc) - }) -} - -createBinaryMathfunctions <- function(name) { - setMethod(name, - signature(y = "Column"), - function(y, x) { - if (class(x) == "Column") { - x <- x@jc - } - jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) - column(jc) - }) -} - createMethods <- function() { for (op in names(operators)) { createOperator(op) @@ -148,12 +115,6 @@ createMethods <- function() { for (name in column_functions2) { createColumnFunction2(name) } - for (x in functions) { - createStaticFunction(x) - } - for (name in binary_mathfunctions) { - createBinaryMathfunctions(name) - } } 
createMethods() @@ -242,45 +203,3 @@ setMethod("%in%", jc <- callJMethod(x@jc, "in", table) return(column(jc)) }) - -#' Approx Count Distinct -#' -#' @rdname column -#' @return the approximate number of distinct items in a group. -setMethod("approxCountDistinct", - signature(x = "Column"), - function(x, rsd = 0.95) { - jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) - column(jc) - }) - -#' Count Distinct -#' -#' @rdname column -#' @return the number of distinct items in a group. -setMethod("countDistinct", - signature(x = "Column"), - function(x, ...) { - jcol <- lapply(list(...), function (x) { - x@jc - }) - jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - listToSeq(jcol)) - column(jc) - }) - -#' @rdname column -#' @aliases countDistinct -setMethod("n_distinct", - signature(x = "Column"), - function(x, ...) { - countDistinct(x, ...) - }) - -#' @rdname column -#' @aliases count -setMethod("n", - signature(x = "Column"), - function(x) { - count(x) - }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R new file mode 100644 index 0000000000000..a15d2d5da534e --- /dev/null +++ b/R/pkg/R/functions.R @@ -0,0 +1,123 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +#' @include generics.R column.R +NULL + +#' @title S4 expression functions for DataFrame column(s) +#' @description These are expression functions on DataFrame columns + +functions1 <- c( + "abs", "acos", "approxCountDistinct", "ascii", "asin", "atan", + "avg", "base64", "bin", "bitwiseNOT", "cbrt", "ceil", "cos", "cosh", "count", + "crc32", "dayofmonth", "dayofyear", "exp", "explode", "expm1", "factorial", + "first", "floor", "hex", "hour", "initcap", "isNaN", "last", "last_day", + "length", "log", "log10", "log1p", "log2", "lower", "ltrim", "max", "md5", + "mean", "min", "minute", "month", "negate", "quarter", "reverse", + "rint", "round", "rtrim", "second", "sha1", "signum", "sin", "sinh", "size", + "soundex", "sqrt", "sum", "sumDistinct", "tan", "tanh", "toDegrees", + "toRadians", "to_date", "trim", "unbase64", "unhex", "upper", "weekofyear", + "year") +functions2 <- c( + "atan2", "datediff", "hypot", "levenshtein", "months_between", "nanvl", "pmod") + +createFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) + column(jc) + }) +} + +createFunction2 <- function(name) { + setMethod(name, + signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) + column(jc) + }) +} + +createFunctions <- function() { + for (name in functions1) { + createFunction1(name) + } + for (name in functions2) { + createFunction2(name) + } +} + +createFunctions() + +#' Approx Count Distinct +#' +#' @rdname functions +#' @return the approximate number of distinct items in a group. +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' @rdname functions +#' @return the number of distinct items in a group. +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + +#' @rdname functions +#' @aliases ceil +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + +#' @rdname functions +#' @aliases signum +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' @rdname functions +#' @aliases countDistinct +setMethod("n_distinct", signature(x = "Column"), + function(x, ...) { + countDistinct(x, ...) + }) + +#' @rdname functions +#' @aliases count +setMethod("n", signature(x = "Column"), + function(x) { + count(x) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 379a78b1d833e..f11e7fcb6a07c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -575,10 +575,6 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @export setGeneric("asc", function(x) { standardGeneric("asc") }) -#' @rdname column -#' @export -setGeneric("avg", function(x, ...) 
{ standardGeneric("avg") }) - #' @rdname column #' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) @@ -587,13 +583,10 @@ setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) -#' @rdname column -#' @export -setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) - #' @rdname column #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) + #' @rdname column #' @export setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) @@ -658,22 +651,190 @@ setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @export setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) -#' @rdname column + +###################### Expression Function Methods ########################## + +#' @rdname functions +#' @export +setGeneric("ascii", function(x) { standardGeneric("ascii") }) + +#' @rdname functions +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname functions +#' @export +setGeneric("base64", function(x) { standardGeneric("base64") }) + +#' @rdname functions +#' @export +setGeneric("bin", function(x) { standardGeneric("bin") }) + +#' @rdname functions +#' @export +setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) + +#' @rdname functions +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + +#' @rdname functions +#' @export +setGeneric("ceil", function(x) { standardGeneric("ceil") }) + +#' @rdname functions +#' @export +setGeneric("crc32", function(x) { standardGeneric("crc32") }) + +#' @rdname functions +#' @export +setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) + +#' @rdname functions +#' @export +setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) + +#' @rdname functions +#' @export +setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) + +#' @rdname functions +#' @export +setGeneric("explode", function(x) { standardGeneric("explode") }) + +#' @rdname functions +#' @export +setGeneric("hex", function(x) { standardGeneric("hex") }) + +#' @rdname functions +#' @export +setGeneric("hour", function(x) { standardGeneric("hour") }) + +#' @rdname functions +#' @export +setGeneric("initcap", function(x) { standardGeneric("initcap") }) + +#' @rdname functions +#' @export +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) + +#' @rdname functions +#' @export +setGeneric("last_day", function(x) { standardGeneric("last_day") }) + +#' @rdname functions +#' @export +setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) + +#' @rdname functions +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname functions +#' @export +setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) + +#' @rdname functions +#' @export +setGeneric("md5", function(x) { standardGeneric("md5") }) + +#' @rdname functions +#' @export +setGeneric("minute", function(x) { standardGeneric("minute") }) + +#' @rdname functions +#' @export +setGeneric("month", function(x) { standardGeneric("month") }) + +#' @rdname functions +#' @export +setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) + +#' @rdname functions +#' @export +setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) + +#' @rdname functions +#' @export +setGeneric("negate", function(x) { standardGeneric("negate") }) + 
+#' @rdname functions +#' @export +setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) + +#' @rdname functions +#' @export +setGeneric("quarter", function(x) { standardGeneric("quarter") }) + +#' @rdname functions +#' @export +setGeneric("reverse", function(x) { standardGeneric("reverse") }) + +#' @rdname functions +#' @export +setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) + +#' @rdname functions +#' @export +setGeneric("second", function(x) { standardGeneric("second") }) + +#' @rdname functions +#' @export +setGeneric("sha1", function(x) { standardGeneric("sha1") }) + +#' @rdname functions +#' @export +setGeneric("signum", function(x) { standardGeneric("signum") }) + +#' @rdname functions +#' @export +setGeneric("size", function(x) { standardGeneric("size") }) + +#' @rdname functions +#' @export +setGeneric("soundex", function(x) { standardGeneric("soundex") }) + +#' @rdname functions #' @export setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) -#' @rdname column +#' @rdname functions #' @export setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname column +#' @rdname functions #' @export setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) -#' @rdname column +#' @rdname functions +#' @export +setGeneric("to_date", function(x) { standardGeneric("to_date") }) + +#' @rdname functions +#' @export +setGeneric("trim", function(x) { standardGeneric("trim") }) + +#' @rdname functions +#' @export +setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) + +#' @rdname functions +#' @export +setGeneric("unhex", function(x) { standardGeneric("unhex") }) + +#' @rdname functions #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) +#' @rdname functions +#' @export +setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) + +#' @rdname functions +#' @export +setGeneric("year", function(x) { standardGeneric("year") }) + + #' @rdname glm #' @export setGeneric("glm") diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 7377fc8f1ca9c..e6d3b21ff825b 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -640,15 +640,18 @@ test_that("column operators", { test_that("column functions", { c <- SparkR:::col("a") - c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) - c3 <- lower(c) + upper(c) + first(c) + last(c) - c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") - c5 <- n(c) + n_distinct(c) - c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) - c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) - c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) - c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) - c9 <- toDegrees(c) + toRadians(c) + c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) + c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) + c3 <- cosh(c) + count(c) + crc32(c) + dayofmonth(c) + dayofyear(c) + exp(c) + c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) + c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) + c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) + c7 <- mean(c) + min(c) + minute(c) + month(c) + negate(c) + quarter(c) + c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + second(c) + sha1(c) + c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) + c10 <- sumDistinct(c) + tan(c) + tanh(c) + 
toDegrees(c) + toRadians(c) + c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + weekofyear(c) + c12 <- year(c) df <- jsonFile(sqlContext, jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) From 7b13ed27c1296cf76d0946e400f3449c335c8471 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 12 Aug 2015 18:52:11 -0700 Subject: [PATCH 313/340] [SPARK-9870] Disable driver UI and Master REST server in SparkSubmitSuite I think that we should pass additional configuration flags to disable the driver UI and Master REST server in SparkSubmitSuite and HiveSparkSubmitSuite. This might cut down on port-contention-related flakiness in Jenkins. Author: Josh Rosen Closes #8124 from JoshRosen/disable-ui-in-sparksubmitsuite. --- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 7 +++++++ .../apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 10 +++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2456c5d0d49b0..1110ca6051a40 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -324,6 +324,8 @@ class SparkSubmitSuite "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -337,6 +339,8 @@ class SparkSubmitSuite "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -355,6 +359,7 @@ class SparkSubmitSuite "--packages", Seq(main, dep).mkString(","), "--repositories", repo, "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -500,6 +505,8 @@ class SparkSubmitSuite "--master", "local", "--conf", "spark.driver.extraClassPath=" + systemJar, "--conf", "spark.driver.userClassPathFirst=true", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", userJar.toString) runSparkSubmit(args) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index b8d41065d3f02..1e1972d1ac353 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -57,6 +57,8 @@ class HiveSparkSubmitSuite "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), "--name", "SparkSubmitClassLoaderTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -68,6 +70,8 @@ class HiveSparkSubmitSuite "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) 
runSparkSubmit(args) } @@ -79,7 +83,11 @@ class HiveSparkSubmitSuite // the HiveContext code mistakenly overrides the class loader that contains user classes. // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" - val args = Seq("--class", "Main", testJar) + val args = Seq( + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--class", "Main", + testJar) runSparkSubmit(args) } From 7c35746c916cf0019367850e75a080d7e739dba0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 12 Aug 2015 20:02:55 -0700 Subject: [PATCH 314/340] [SPARK-9827] [SQL] fix fd leak in UnsafeRowSerializer Currently, UnsafeRowSerializer does not close the InputStream, will cause fd leak if the InputStream has an open fd in it. TODO: the fd could still be leaked, if any items in the stream is not consumed. Currently it replies on GC to close the fd in this case. cc JoshRosen Author: Davies Liu Closes #8116 from davies/fd_leak. --- .../sql/execution/UnsafeRowSerializer.scala | 2 ++ .../execution/UnsafeRowSerializerSuite.scala | 31 +++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 3860c4bba9a99..5c18558f9bde7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -108,6 +108,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = { new Iterator[(Int, UnsafeRow)] { private[this] var rowSize: Int = dIn.readInt() + if (rowSize == EOF) dIn.close() override def hasNext: Boolean = rowSize != EOF @@ -119,6 +120,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream + dIn.close() val _rowTuple = rowTuple // Null these out so that the byte array can be garbage collected once the entire // iterator has been consumed diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 40b47ae18d648..bd02c73a26ace 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row @@ -25,6 +25,18 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ + +/** + * used to test close InputStream in UnsafeRowSerializer + */ +class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) { + var closed: Boolean = false + override def close(): Unit = { + closed = true + super.close() + } +} + 
class UnsafeRowSerializerSuite extends SparkFunSuite { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { @@ -52,8 +64,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { serializerStream.writeValue(unsafeRow) } serializerStream.close() - val deserializerIter = serializer.deserializeStream( - new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator for (expectedRow <- unsafeRows) { val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2 assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes) @@ -61,5 +73,18 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { assert(expectedRow.getInt(1) === actualRow.getInt(1)) } assert(!deserializerIter.hasNext) + assert(input.closed) + } + + test("close empty input stream") { + val baos = new ByteArrayOutputStream() + val dout = new DataOutputStream(baos) + dout.writeInt(-1) // EOF + dout.flush() + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator + assert(!deserializerIter.hasNext) + assert(input.closed) } } From 4413d0855aaba5cb00f737dc6934a0b92d9bc05d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 20:03:55 -0700 Subject: [PATCH 315/340] [SPARK-9908] [SQL] When spark.sql.tungsten.enabled is false, broadcast join does not work https://issues.apache.org/jira/browse/SPARK-9908 Author: Yin Huai Closes #8149 from yhuai/SPARK-9908. --- .../apache/spark/sql/execution/joins/HashedRelation.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 076afe6e4e960..bb333b4d5ed18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -66,7 +66,8 @@ private[joins] final class GeneralHashedRelation( private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { - private def this() = this(null) // Needed for serialization + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key) @@ -88,7 +89,8 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) extends HashedRelation with Externalizable { - private def this() = this(null) // Needed for serialization + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) override def get(key: InternalRow): Seq[InternalRow] = { val v = hashTable.get(key) From d2d5e7fe2df582e1c866334b3014d7cb351f5b70 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 12 Aug 2015 20:43:36 -0700 Subject: [PATCH 316/340] [SPARK-9704] [ML] Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs Made ProbabilisticClassifier, Identifiable, VectorUDT public. All are annotated as DeveloperApi. CC: mengxr EronWright Author: Joseph K. Bradley Closes #8004 from jkbradley/ml-api-public-items and squashes the following commits: 7ebefda [Joseph K. 
Bradley] update per code review 7ff0768 [Joseph K. Bradley] attepting to add mima fix 756d84c [Joseph K. Bradley] VectorUDT annotated as AlphaComponent ae7767d [Joseph K. Bradley] added another warning 94fd553 [Joseph K. Bradley] Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs --- .../classification/ProbabilisticClassifier.scala | 4 ++-- .../org/apache/spark/ml/util/Identifiable.scala | 16 ++++++++++++++-- .../org/apache/spark/mllib/linalg/Vectors.scala | 10 ++++------ project/MimaExcludes.scala | 4 ++++ 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 1e50a895a9a05..fdd1851ae5508 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -50,7 +50,7 @@ private[classification] trait ProbabilisticClassifierParams * @tparam M Concrete Model type */ @DeveloperApi -private[spark] abstract class ProbabilisticClassifier[ +abstract class ProbabilisticClassifier[ FeaturesType, E <: ProbabilisticClassifier[FeaturesType, E, M], M <: ProbabilisticClassificationModel[FeaturesType, M]] @@ -74,7 +74,7 @@ private[spark] abstract class ProbabilisticClassifier[ * @tparam M Concrete Model type */ @DeveloperApi -private[spark] abstract class ProbabilisticClassificationModel[ +abstract class ProbabilisticClassificationModel[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]] extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index ddd34a54503a6..bd213e7362e94 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -19,11 +19,19 @@ package org.apache.spark.ml.util import java.util.UUID +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: + * * Trait for an object with an immutable unique ID that identifies itself and its derivatives. + * + * WARNING: There have not yet been final discussions on this API, so it may be broken in future + * releases. */ -private[spark] trait Identifiable { +@DeveloperApi +trait Identifiable { /** * An immutable unique ID for the object and its derivatives. @@ -33,7 +41,11 @@ private[spark] trait Identifiable { override def toString: String = uid } -private[spark] object Identifiable { +/** + * :: DeveloperApi :: + */ +@DeveloperApi +object Identifiable { /** * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 86c461fa91633..df15d985c814c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.AlphaComponent import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow @@ -159,15 +159,13 @@ sealed trait Vector extends Serializable { } /** - * :: DeveloperApi :: + * :: AlphaComponent :: * * User-defined type for [[Vector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.DataFrame]]. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@DeveloperApi -private[spark] class VectorUDT extends UserDefinedType[Vector] { +@AlphaComponent +class VectorUDT extends UserDefinedType[Vector] { override def sqlType: StructType = { // type: 0 = sparse, 1 = dense diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 90261ca3d61aa..784f83c10e023 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -178,6 +178,10 @@ object MimaExcludes { // SPARK-4751 Dynamic allocation for standalone mode ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.mllib.linalg.VectorUDT.serialize") ) case v if v.startsWith("1.4") => From d7053bea985679c514b3add029631ea23e1730ce Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 20:44:40 -0700 Subject: [PATCH 317/340] [SPARK-9903] [MLLIB] skip local processing in PrefixSpan if there are no small prefixes There exists a chance that the prefixes keep growing to the maximum pattern length. Then the final local processing step becomes unnecessary. feynmanliang Author: Xiangrui Meng Closes #8136 from mengxr/SPARK-9903. --- .../apache/spark/mllib/fpm/PrefixSpan.scala | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index ad6715b52f337..dc4ae1d0b69ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -282,25 +282,30 @@ object PrefixSpan extends Logging { largePrefixes = newLargePrefixes } - // Switch to local processing. - val bcSmallPrefixes = sc.broadcast(smallPrefixes) - val distributedFreqPattern = postfixes.flatMap { postfix => - bcSmallPrefixes.value.values.map { prefix => - (prefix.id, postfix.project(prefix).compressed) - }.filter(_._2.nonEmpty) - }.groupByKey().flatMap { case (id, projPostfixes) => - val prefix = bcSmallPrefixes.value(id) - val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) - // TODO: We collect projected postfixes into memory. We should also compare the performance - // TODO: of keeping them on shuffle files. 
- localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => - (prefix.items ++ pattern, count) + var freqPatterns = sc.parallelize(localFreqPatterns, 1) + + val numSmallPrefixes = smallPrefixes.size + logInfo(s"number of small prefixes for local processing: $numSmallPrefixes") + if (numSmallPrefixes > 0) { + // Switch to local processing. + val bcSmallPrefixes = sc.broadcast(smallPrefixes) + val distributedFreqPattern = postfixes.flatMap { postfix => + bcSmallPrefixes.value.values.map { prefix => + (prefix.id, postfix.project(prefix).compressed) + }.filter(_._2.nonEmpty) + }.groupByKey().flatMap { case (id, projPostfixes) => + val prefix = bcSmallPrefixes.value(id) + val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) + // TODO: We collect projected postfixes into memory. We should also compare the performance + // TODO: of keeping them on shuffle files. + localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => + (prefix.items ++ pattern, count) + } } + // Union local frequent patterns and distributed ones. + freqPatterns = freqPatterns ++ distributedFreqPattern } - // Union local frequent patterns and distributed ones. - val freqPatterns = (sc.parallelize(localFreqPatterns, 1) ++ distributedFreqPattern) - .persist(StorageLevel.MEMORY_AND_DISK) freqPatterns } From 2fb4901b71cee65d40a43e61e3f4411c30cdefc3 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 12 Aug 2015 20:59:38 -0700 Subject: [PATCH 318/340] [SPARK-9916] [BUILD] [SPARKR] removed left-over sparkr.zip copy/create commands from codebase sparkr.zip is now built by SparkSubmit on a need-to-build basis. cc shivaram Author: Burak Yavuz Closes #8147 from brkyvz/make-dist-fix. --- R/install-dev.bat | 5 ----- make-distribution.sh | 1 - 2 files changed, 6 deletions(-) diff --git a/R/install-dev.bat b/R/install-dev.bat index f32670b67de96..008a5c668bc45 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -25,8 +25,3 @@ set SPARK_HOME=%~dp0.. 
MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ - -rem Zip the SparkR package so that it can be distributed to worker nodes on YARN -pushd %SPARK_HOME%\R\lib -%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR -popd diff --git a/make-distribution.sh b/make-distribution.sh index 4789b0e09cc8a..247a81341e4a4 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -219,7 +219,6 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then mkdir -p "$DISTDIR"/R/lib cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib - cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib fi # Download and copy in tachyon, if requested From 2278219054314f1d31ffc358a59aa5067f9f5de9 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 21:24:15 -0700 Subject: [PATCH 319/340] [SPARK-9920] [SQL] The simpleString of TungstenAggregate does not show its output https://issues.apache.org/jira/browse/SPARK-9920 Taking `sqlContext.sql("select i, sum(j1) as sum from testAgg group by i").explain()` as an example, the output of our current master is ``` == Physical Plan == TungstenAggregate(key=[i#0], value=[(sum(cast(j1#1 as bigint)),mode=Final,isDistinct=false)] TungstenExchange hashpartitioning(i#0) TungstenAggregate(key=[i#0], value=[(sum(cast(j1#1 as bigint)),mode=Partial,isDistinct=false)] Scan ParquetRelation[file:/user/hive/warehouse/testagg][i#0,j1#1] ``` With this PR, the output will be ``` == Physical Plan == TungstenAggregate(key=[i#0], functions=[(sum(cast(j1#1 as bigint)),mode=Final,isDistinct=false)], output=[i#0,sum#18L]) TungstenExchange hashpartitioning(i#0) TungstenAggregate(key=[i#0], functions=[(sum(cast(j1#1 as bigint)),mode=Partial,isDistinct=false)], output=[i#0,currentSum#22L]) Scan ParquetRelation[file:/user/hive/warehouse/testagg][i#0,j1#1] ``` Author: Yin Huai Closes #8150 from yhuai/SPARK-9920. 
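The following is a hypothetical, self-contained sketch (invented class name, plain strings in place of Catalyst expressions) of how such an improved one-line description can be assembled from a node's grouping keys, aggregate functions, and output attributes:

```scala
// Hypothetical sketch, not the actual TungstenAggregate operator: shows how a
// plan node can fold keys, functions, and output into the line explain() prints.
case class FakeAggregateNode(
    groupingExpressions: Seq[String],
    aggregateExpressions: Seq[String],
    output: Seq[String]) {

  def simpleString: String = {
    val keyString = groupingExpressions.mkString("[", ",", "]")
    val functionString = aggregateExpressions.mkString("[", ",", "]")
    val outputString = output.mkString("[", ",", "]")
    s"FakeAggregateNode(key=$keyString, functions=$functionString, output=$outputString)"
  }
}

object SimpleStringDemo extends App {
  val node = FakeAggregateNode(Seq("i#0"), Seq("sum(j1#1)"), Seq("i#0", "sum#18L"))
  println(node.simpleString)
  // prints: FakeAggregateNode(key=[i#0], functions=[sum(j1#1)], output=[i#0,sum#18L])
}
```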
--- .../spark/sql/execution/aggregate/SortBasedAggregate.scala | 6 +++++- .../spark/sql/execution/aggregate/TungstenAggregate.scala | 7 ++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ab26f9c58aa2e..f4c14a9b3556f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -108,6 +108,10 @@ case class SortBasedAggregate( override def simpleString: String = { val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions - s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}""" + + val keyString = groupingExpressions.mkString("[", ",", "]") + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"SortBasedAggregate(key=$keyString, functions=$functionString, output=$outputString)" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index c40ca973796a6..99f51ba5b6935 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -127,11 +127,12 @@ case class TungstenAggregate( testFallbackStartsAt match { case None => val keyString = groupingExpressions.mkString("[", ",", "]") - val valueString = allAggregateExpressions.mkString("[", ",", "]") - s"TungstenAggregate(key=$keyString, value=$valueString" + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"TungstenAggregate(key=$keyString, functions=$functionString, output=$outputString)" case Some(fallbackStartsAt) => s"TungstenAggregateWithControlledFallback $groupingExpressions " + - s"$allAggregateExpressions fallbackStartsAt=$fallbackStartsAt" + s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt" } } } From a8ab2634c1eee143a4deaf309204df8add727f9e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 12 Aug 2015 21:26:00 -0700 Subject: [PATCH 320/340] [SPARK-9832] [SQL] add a thread-safe lookup for BytesToBytseMap This patch add a thread-safe lookup for BytesToBytseMap, and use that in broadcasted HashedRelation. Author: Davies Liu Closes #8151 from davies/safeLookup. 
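A simplified, self-contained Scala analogue of that change (invented names, with an ordinary in-memory `Map` standing in for the off-heap `BytesToBytesMap`): a `lookup` that reuses one shared mutable result holder is only safe for a single thread, while a `safeLookup` that writes into a caller-owned holder lets many threads probe the same read-only map, which is what a broadcasted hash relation needs.

```scala
// Hypothetical sketch of the pattern behind safeLookup; not the real map code.
class SharedResultMap[K, V](data: Map[K, V]) {
  // Result holder. The real map keeps one of these per instance for speed.
  final class Location {
    var found: Boolean = false
    var value: Option[V] = None
  }
  private val sharedLoc = new Location

  // NOT thread-safe: every caller gets the same mutable holder back, so two
  // threads probing the shared map can overwrite each other's result.
  def lookup(key: K): Location = {
    data.get(key) match {
      case Some(v) => sharedLoc.found = true; sharedLoc.value = Some(v)
      case None    => sharedLoc.found = false; sharedLoc.value = None
    }
    sharedLoc
  }

  // Thread-safe for a read-only map: the result goes into the caller's holder.
  def safeLookup(key: K, out: Location): Unit = {
    data.get(key) match {
      case Some(v) => out.found = true; out.value = Some(v)
      case None    => out.found = false; out.value = None
    }
  }
}

object SafeLookupDemo extends App {
  val m = new SharedResultMap(Map("a" -> 1, "b" -> 2))
  val myLoc = new m.Location   // per-caller holder
  m.safeLookup("a", myLoc)
  println((myLoc.found, myLoc.value))   // (true,Some(1))
}
```

Each probing task allocates its own `Location` and passes it in, mirroring the `new map.Location` plus `safeLookup` usage this patch adds to `UnsafeHashedRelation`.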
--- .../spark/unsafe/map/BytesToBytesMap.java | 30 ++++++++++++++----- .../sql/execution/joins/HashedRelation.scala | 6 ++-- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 87ed47e88c4ef..5f3a4fcf4d585 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -17,25 +17,24 @@ package org.apache.spark.unsafe.map; -import java.lang.Override; -import java.lang.UnsupportedOperationException; +import javax.annotation.Nullable; import java.util.Iterator; import java.util.LinkedList; import java.util.List; -import javax.annotation.Nullable; - import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.unsafe.*; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.bitset.BitSet; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.*; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; /** * An append-only hash map where keys and values are contiguous regions of bytes. @@ -328,6 +327,20 @@ public Location lookup( Object keyBaseObject, long keyBaseOffset, int keyRowLengthBytes) { + safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc); + return loc; + } + + /** + * Looks up a key, and saves the result in provided `loc`. + * + * This is a thread-safe version of `lookup`, could be used by multiple threads. + */ + public void safeLookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes, + Location loc) { assert(bitset != null); assert(longArray != null); @@ -343,7 +356,8 @@ public Location lookup( } if (!bitset.isSet(pos)) { // This is a new key. 
- return loc.with(pos, hashcode, false); + loc.with(pos, hashcode, false); + return; } else { long stored = longArray.get(pos * 2 + 1); if ((int) (stored) == hashcode) { @@ -361,7 +375,7 @@ public Location lookup( keyRowLengthBytes ); if (areEqual) { - return loc; + return; } else { if (enablePerfMetrics) { numHashCollisions++; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index bb333b4d5ed18..ea02076b41a6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -215,8 +215,10 @@ private[joins] final class UnsafeHashedRelation( if (binaryMap != null) { // Used in Broadcast join - val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, - unsafeKey.getSizeInBytes) + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc) if (loc.isDefined) { val buffer = CompactBuffer[UnsafeRow]() From 5fc058a1fc5d83ad53feec936475484aef3800b3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 21:33:38 -0700 Subject: [PATCH 321/340] [SPARK-9917] [ML] add getMin/getMax and doc for originalMin/origianlMax in MinMaxScaler hhbyyh Author: Xiangrui Meng Closes #8145 from mengxr/SPARK-9917. --- .../org/apache/spark/ml/feature/MinMaxScaler.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index b30adf3df48d2..9a473dd23772d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -41,6 +41,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H val min: DoubleParam = new DoubleParam(this, "min", "lower bound of the output feature range") + /** @group getParam */ + def getMin: Double = $(min) + /** * upper bound after transformation, shared by all features * Default: 1.0 @@ -49,6 +52,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H val max: DoubleParam = new DoubleParam(this, "max", "upper bound of the output feature range") + /** @group getParam */ + def getMax: Double = $(max) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType @@ -115,6 +121,9 @@ class MinMaxScaler(override val uid: String) * :: Experimental :: * Model fitted by [[MinMaxScaler]]. * + * @param originalMin min value for each original column during fitting + * @param originalMax max value for each original column during fitting + * * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). 
*/ @Experimental @@ -136,7 +145,6 @@ class MinMaxScalerModel private[ml] ( /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def transform(dataset: DataFrame): DataFrame = { val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray val minArray = originalMin.toArray From df543892122342b97e5137b266959ba97589b3ef Mon Sep 17 00:00:00 2001 From: "shikai.tang" Date: Wed, 12 Aug 2015 21:53:15 -0700 Subject: [PATCH 322/340] [SPARK-8922] [DOCUMENTATION, MLLIB] Add @since tags to mllib.evaluation Author: shikai.tang Closes #7429 from mosessky/master. --- .../BinaryClassificationMetrics.scala | 32 ++++++++++++++++--- .../mllib/evaluation/MulticlassMetrics.scala | 9 ++++++ .../mllib/evaluation/MultilabelMetrics.scala | 4 +++ .../mllib/evaluation/RankingMetrics.scala | 4 +++ .../mllib/evaluation/RegressionMetrics.scala | 6 ++++ 5 files changed, 50 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index c1d1a224817e8..486741edd6f5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.DataFrame * of bins may not exactly equal numBins. The last bin in each partition may * be smaller as a result, meaning there may be an extra sample at * partition boundaries. + * @since 1.3.0 */ @Experimental class BinaryClassificationMetrics( @@ -51,6 +52,7 @@ class BinaryClassificationMetrics( /** * Defaults `numBins` to 0. + * @since 1.0.0 */ def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0) @@ -61,12 +63,18 @@ class BinaryClassificationMetrics( private[mllib] def this(scoreAndLabels: DataFrame) = this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) - /** Unpersist intermediate RDDs used in the computation. */ + /** + * Unpersist intermediate RDDs used in the computation. + * @since 1.0.0 + */ def unpersist() { cumulativeCounts.unpersist() } - /** Returns thresholds in descending order. */ + /** + * Returns thresholds in descending order. + * @since 1.0.0 + */ def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) /** @@ -74,6 +82,7 @@ class BinaryClassificationMetrics( * which is an RDD of (false positive rate, true positive rate) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic + * @since 1.0.0 */ def roc(): RDD[(Double, Double)] = { val rocCurve = createCurve(FalsePositiveRate, Recall) @@ -85,6 +94,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the receiver operating characteristic (ROC) curve. + * @since 1.0.0 */ def areaUnderROC(): Double = AreaUnderCurve.of(roc()) @@ -92,6 +102,7 @@ class BinaryClassificationMetrics( * Returns the precision-recall curve, which is an RDD of (recall, precision), * NOT (precision, recall), with (0.0, 1.0) prepended to it. * @see http://en.wikipedia.org/wiki/Precision_and_recall + * @since 1.0.0 */ def pr(): RDD[(Double, Double)] = { val prCurve = createCurve(Recall, Precision) @@ -102,6 +113,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the precision-recall curve. 
+ * @since 1.0.0 */ def areaUnderPR(): Double = AreaUnderCurve.of(pr()) @@ -110,16 +122,26 @@ class BinaryClassificationMetrics( * @param beta the beta factor in F-Measure computation. * @return an RDD of (threshold, F-Measure) pairs. * @see http://en.wikipedia.org/wiki/F1_score + * @since 1.0.0 */ def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) - /** Returns the (threshold, F-Measure) curve with beta = 1.0. */ + /** + * Returns the (threshold, F-Measure) curve with beta = 1.0. + * @since 1.0.0 + */ def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) - /** Returns the (threshold, precision) curve. */ + /** + * Returns the (threshold, precision) curve. + * @since 1.0.0 + */ def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision) - /** Returns the (threshold, recall) curve. */ + /** + * Returns the (threshold, recall) curve. + * @since 1.0.0 + */ def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall) private lazy val ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 4628dc5690913..dddfa3ea5b800 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.DataFrame * Evaluator for multiclass classification. * * @param predictionAndLabels an RDD of (prediction, label) pairs. + * @since 1.1.0 */ @Experimental class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { @@ -64,6 +65,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * predicted classes are in columns, * they are ordered by class label ascending, * as in "labels" + * @since 1.1.0 */ def confusionMatrix: Matrix = { val n = labels.size @@ -83,12 +85,14 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns true positive rate for a given label (category) * @param label the label. + * @since 1.1.0 */ def truePositiveRate(label: Double): Double = recall(label) /** * Returns false positive rate for a given label (category) * @param label the label. + * @since 1.1.0 */ def falsePositiveRate(label: Double): Double = { val fp = fpByClass.getOrElse(label, 0) @@ -98,6 +102,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns precision for a given label (category) * @param label the label. + * @since 1.1.0 */ def precision(label: Double): Double = { val tp = tpByClass(label) @@ -108,6 +113,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns recall for a given label (category) * @param label the label. + * @since 1.1.0 */ def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) @@ -115,6 +121,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns f-measure for a given label (category) * @param label the label. * @param beta the beta parameter. + * @since 1.1.0 */ def fMeasure(label: Double, beta: Double): Double = { val p = precision(label) @@ -126,6 +133,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns f1-measure for a given label (category) * @param label the label. 
+ * @since 1.1.0 */ def fMeasure(label: Double): Double = fMeasure(label, 1.0) @@ -179,6 +187,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged f-measure * @param beta the beta parameter. + * @since 1.1.0 */ def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) => fMeasure(category, beta) * count.toDouble / labelCount diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index bf6eb1d5bd2ab..77cb1e09bdbb5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.DataFrame * Evaluator for multilabel classification. * @param predictionAndLabels an RDD of (predictions, labels) pairs, * both are non-null Arrays, each with unique elements. + * @since 1.2.0 */ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { @@ -103,6 +104,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns precision for a given label (category) * @param label the label. + * @since 1.2.0 */ def precision(label: Double): Double = { val tp = tpPerClass(label) @@ -113,6 +115,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns recall for a given label (category) * @param label the label. + * @since 1.2.0 */ def recall(label: Double): Double = { val tp = tpPerClass(label) @@ -123,6 +126,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns f1-measure for a given label (category) * @param label the label. + * @since 1.2.0 */ def f1Measure(label: Double): Double = { val p = precision(label) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 5b5a2a1450f7f..063fbed8cdeea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD * Java users should use [[RankingMetrics$.of]] to create a [[RankingMetrics]] instance. * * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. + * @since 1.2.0 */ @Experimental class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) @@ -55,6 +56,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * * @param k the position to compute the truncated precision, must be positive * @return the average precision at the first k ranking positions + * @since 1.2.0 */ def precisionAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") @@ -124,6 +126,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * * @param k the position to compute the truncated ndcg, must be positive * @return the average ndcg at the first k ranking positions + * @since 1.2.0 */ def ndcgAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") @@ -162,6 +165,7 @@ object RankingMetrics { /** * Creates a [[RankingMetrics]] instance (for Java users). 
* @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs + * @since 1.4.0 */ def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = { implicit val tag = JavaSparkContext.fakeClassTag[E] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 408847afa800d..54dfd8c099494 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.DataFrame * Evaluator for regression. * * @param predictionAndObservations an RDD of (prediction, observation) pairs. + * @since 1.2.0 */ @Experimental class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { @@ -66,6 +67,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the variance explained by regression. * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] + * @since 1.2.0 */ def explainedVariance: Double = { SSreg / summary.count @@ -74,6 +76,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. + * @since 1.2.0 */ def meanAbsoluteError: Double = { summary.normL1(1) / summary.count @@ -82,6 +85,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. + * @since 1.2.0 */ def meanSquaredError: Double = { SSerr / summary.count @@ -90,6 +94,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. + * @since 1.2.0 */ def rootMeanSquaredError: Double = { math.sqrt(this.meanSquaredError) @@ -98,6 +103,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns R^2^, the unadjusted coefficient of determination. * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * @since 1.2.0 */ def r2: Double = { 1 - SSerr / SStot From d7eb371eb6369a34e58a09179efe058c4101de9e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 22:30:33 -0700 Subject: [PATCH 323/340] [SPARK-9914] [ML] define setters explicitly for Java and use setParam group in RFormula The problem with defining setters in the base class is that it doesn't return the correct type in Java. 
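For illustration only, a small self-contained Scala sketch of the pattern this patch moves to (the names SimpleFormula/HasLabelCol are hypothetical, not the Spark classes touched below): when the setters are declared on the concrete class, this.type resolves to that class, so chained calls keep a usable return type from Java; a setter inherited only from a shared base trait is exposed to Java with the trait's type instead.

    // Hypothetical sketch, not Spark code: setters on the concrete class keep chainable types.
    trait HasLabelCol {
      protected var labelCol: String = "label"
    }

    class SimpleFormula extends HasLabelCol {
      private var formula: String = ""
      // Declared here (not on the trait), this.type resolves to SimpleFormula,
      // so Java callers can chain setFormula(...).setLabelCol(...) without casts.
      def setFormula(value: String): this.type = { formula = value; this }
      def setLabelCol(value: String): this.type = { labelCol = value; this }
      override def toString: String = s"SimpleFormula(formula=$formula, labelCol=$labelCol)"
    }

    object SimpleFormulaExample {
      def main(args: Array[String]): Unit = {
        val f = new SimpleFormula().setFormula("y ~ x").setLabelCol("y")
        println(f)  // SimpleFormula(formula=y ~ x, labelCol=y)
      }
    }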
ericl Author: Xiangrui Meng Closes #8143 from mengxr/SPARK-9914 and squashes the following commits: d36c887 [Xiangrui Meng] remove setters from model a49021b [Xiangrui Meng] define setters explicitly for Java and use setParam group --- .../scala/org/apache/spark/ml/feature/RFormula.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d5360c9217ea9..a752dacd72d95 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -33,11 +33,6 @@ import org.apache.spark.sql.types._ * Base trait for [[RFormula]] and [[RFormulaModel]]. */ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { - /** @group getParam */ - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - - /** @group getParam */ - def setLabelCol(value: String): this.type = set(labelCol, value) protected def hasLabelCol(schema: StructType): Boolean = { schema.map(_.name).contains($(labelCol)) @@ -71,6 +66,12 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R /** @group getParam */ def getFormula: String = $(formula) + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + /** Whether the formula specifies fitting an intercept. */ private[ml] def hasIntercept: Boolean = { require(isDefined(formula), "Formula must be defined first.")
From d0b18919d16e6a2f19159516bd2767b60b595279 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 13 Aug 2015 13:33:39 +0800 Subject: [PATCH 324/340] [SPARK-9927] [SQL] Revert 8049 since it's pushing wrong filter down I made a mistake in #8049 by casting the literal value to the attribute's data type, which would simply truncate the literal value and push a wrong filter down. JIRA: https://issues.apache.org/jira/browse/SPARK-9927 Author: Yijie Shen Closes #8157 from yjshen/rever8049.
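As a hedged, self-contained illustration of the truncation (plain Scala with a made-up helper, not Spark's Cast): for an INT column `a`, the predicate `CAST(a AS DOUBLE) = 15.5` can never be true, but casting the literal down to the column's type turns it into `a = 15`, which wrongly matches rows.

    // Hypothetical sketch of the truncation problem; literalCastToInt stands in for the bad cast.
    object TruncatedPushdownExample {
      // Mimics "cast the literal to the attribute's data type" for an INT column.
      def literalCastToInt(literal: Double): Int = literal.toInt  // 15.5 becomes 15

      def main(args: Array[String]): Unit = {
        val columnValues = Seq(14, 15, 16)
        val literal = 15.5

        // Correct semantics: widen the column to DOUBLE and compare against the literal.
        val correct = columnValues.filter(v => v.toDouble == literal)          // List()
        // Buggy pushdown: compare the raw column against the truncated literal.
        val pushed = columnValues.filter(v => v == literalCastToInt(literal))  // List(15)

        println(s"correct = $correct, pushed-down = $pushed")
      }
    }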
--- .../datasources/DataSourceStrategy.scala | 30 ++-------------- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 35 ------------------- 3 files changed, 3 insertions(+), 64 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 9eea2b0382535..2a4c40db8bb66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} +import org.apache.spark.sql.catalyst.{InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{TimestampType, DateType, StringType, StructType} +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -343,17 +343,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { * and convert them. */ protected[sql] def selectFilters(filters: Seq[Expression]) = { - import CatalystTypeConverters._ - def translate(predicate: Expression): Option[Filter] = predicate match { case expressions.EqualTo(a: Attribute, Literal(v, _)) => Some(sources.EqualTo(a.name, v)) case expressions.EqualTo(Literal(v, _), a: Attribute) => Some(sources.EqualTo(a.name, v)) - case expressions.EqualTo(Cast(a: Attribute, _), l: Literal) => - Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.EqualTo(l: Literal, Cast(a: Attribute, _)) => - Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => Some(sources.EqualNullSafe(a.name, v)) @@ -364,41 +358,21 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { Some(sources.GreaterThan(a.name, v)) case expressions.GreaterThan(Literal(v, _), a: Attribute) => Some(sources.LessThan(a.name, v)) - case expressions.GreaterThan(Cast(a: Attribute, _), l: Literal) => - Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.GreaterThan(l: Literal, Cast(a: Attribute, _)) => - Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThan(a: Attribute, Literal(v, _)) => Some(sources.LessThan(a.name, v)) case expressions.LessThan(Literal(v, _), a: Attribute) => Some(sources.GreaterThan(a.name, v)) - case expressions.LessThan(Cast(a: Attribute, _), l: Literal) => - Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.LessThan(l: Literal, Cast(a: Attribute, _)) => - Some(sources.GreaterThan(a.name, convertToScala(Cast(l, 
a.dataType).eval(), a.dataType))) case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.GreaterThanOrEqual(a.name, v)) case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.LessThanOrEqual(a.name, v)) - case expressions.GreaterThanOrEqual(Cast(a: Attribute, _), l: Literal) => - Some(sources.GreaterThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.GreaterThanOrEqual(l: Literal, Cast(a: Attribute, _)) => - Some(sources.LessThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.LessThanOrEqual(a.name, v)) case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.GreaterThanOrEqual(a.name, v)) - case expressions.LessThanOrEqual(Cast(a: Attribute, _), l: Literal) => - Some(sources.LessThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.LessThanOrEqual(l: Literal, Cast(a: Attribute, _)) => - Some(sources.GreaterThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.InSet(a: Attribute, set) => Some(sources.In(a.name, set.toArray)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 281943e23fcff..8eab6a0adccc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -284,7 +284,7 @@ private[sql] class JDBCRDD( /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - val filterWhereClause: String = { + private val filterWhereClause: String = { val filterStrings = filters map compileFilter filter (_ != null) if (filterStrings.size > 0) { val sb = new StringBuilder("WHERE ") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index b9cfae51e809c..e4dcf4c75d208 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,8 +25,6 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD -import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -150,18 +148,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("create table test.decimals (a DECIMAL(7, 2), b DECIMAL(4, 0))"). 
- executeUpdate() - conn.prepareStatement("insert into test.decimals values (12345.67, 1234)").executeUpdate() - conn.prepareStatement("insert into test.decimals values (34567.89, 1428)").executeUpdate() - conn.commit() - sql( - s""" - |CREATE TEMPORARY TABLE decimals - |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.DECIMALS', user 'testUser', password 'testPass') - """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement( s""" |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), @@ -458,25 +444,4 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } - - test("SPARK-9182: filters are not passed through to jdbc source") { - def checkPushedFilter(query: String, filterStr: String): Unit = { - val rddOpt = sql(query).queryExecution.executedPlan.collectFirst { - case PhysicalRDD(_, rdd: JDBCRDD, _) => rdd - } - assert(rddOpt.isDefined) - val pushedFilterStr = rddOpt.get.filterWhereClause - assert(pushedFilterStr.contains(filterStr), - s"Expected to push [$filterStr], actually we pushed [$pushedFilterStr]") - } - - checkPushedFilter("select * from foobar where NAME = 'fred'", "NAME = 'fred'") - checkPushedFilter("select * from inttypes where A > '15'", "A > 15") - checkPushedFilter("select * from inttypes where C <= 20", "C <= 20") - checkPushedFilter("select * from decimals where A > 1000", "A > 1000.00") - checkPushedFilter("select * from decimals where A > 1000 AND A < 2000", - "A > 1000.00 AND A < 2000.00") - checkPushedFilter("select * from decimals where A = 2000 AND B > 20", "A = 2000.00 AND B > 20") - checkPushedFilter("select * from timetypes where B > '1998-09-10'", "B > 1998-09-10") - } } From 68f99571492f67596b3656e9f076deeb96616f4a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 23:04:59 -0700 Subject: [PATCH 325/340] [SPARK-9918] [MLLIB] remove runs from k-means and rename epsilon to tol This requires some discussion. I'm not sure whether `runs` is a useful parameter. It certainly complicates the implementation. We might want to optimize the k-means implementation with block matrix operations. In this case, having `runs` may not be worth the trade-off. Also it increases the communication cost in a single job, which might cause other issues. This PR also renames `epsilon` to `tol` to have consistent naming among algorithms. The Python constructor is updated to include all parameters. 
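For context, a hedged sketch of how the spark.ml estimator is configured after this change; the data mirrors the doctest below, while the local-mode setup and object name are illustrative additions, not part of the patch.

    // Sketch of the updated spark.ml K-means API: tol replaces epsilon, runs is gone.
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    object KMeansTolExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("KMeansTolExample").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val dataset = Seq(
          Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0),
          Vectors.dense(9.0, 8.0), Vectors.dense(8.0, 9.0)
        ).map(Tuple1.apply).toDF("features")

        val kmeans = new KMeans()
          .setK(2)
          .setMaxIter(20)
          .setTol(1e-4)  // formerly setEpsilon; setRuns no longer exists
          .setSeed(1L)
        val model = kmeans.fit(dataset)
        model.clusterCenters.foreach(println)
        sc.stop()
      }
    }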
jkbradley yu-iskw Author: Xiangrui Meng Closes #8148 from mengxr/SPARK-9918 and squashes the following commits: 149b9e5 [Xiangrui Meng] fix constructor in Python and rename epsilon to tol 3cc15b3 [Xiangrui Meng] fix test and change initStep to initSteps in python a0a0274 [Xiangrui Meng] remove runs from k-means in the pipeline API --- .../apache/spark/ml/clustering/KMeans.scala | 51 +++------------ .../spark/ml/clustering/KMeansSuite.scala | 12 +--- python/pyspark/ml/clustering.py | 63 ++++--------------- 3 files changed, 26 insertions(+), 100 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index dc192add6ca13..47a18cdb31b53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -18,8 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap} -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed} +import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} @@ -27,14 +27,13 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.util.Utils /** * Common params for KMeans and KMeansModel */ -private[clustering] trait KMeansParams - extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { +private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasTol { /** * Set the number of clusters to create (k). Must be > 1. Default: 2. @@ -45,31 +44,6 @@ private[clustering] trait KMeansParams /** @group getParam */ def getK: Int = $(k) - /** - * Param the number of runs of the algorithm to execute in parallel. We initialize the algorithm - * this many times with random starting conditions (configured by the initialization mode), then - * return the best clustering found over any run. Must be >= 1. Default: 1. - * @group param - */ - final val runs = new IntParam(this, "runs", - "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1) - - /** @group getParam */ - def getRuns: Int = $(runs) - - /** - * Param the distance threshold within which we've consider centers to have converged. - * If all centers move less than this Euclidean distance, we stop iterating one run. - * Must be >= 0.0. Default: 1e-4 - * @group param - */ - final val epsilon = new DoubleParam(this, "epsilon", - "distance threshold within which we've consider centers to have converge", - (value: Double) => value >= 0.0) - - /** @group getParam */ - def getEpsilon: Double = $(epsilon) - /** * Param for the initialization algorithm. 
This can be either "random" to choose random points as * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ @@ -136,9 +110,9 @@ class KMeansModel private[ml] ( /** * :: Experimental :: - * K-means clustering with support for multiple parallel runs and a k-means++ like initialization - * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, - * they are executed together with joint passes over the data for efficiency. + * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. + * + * @see [[http://dx.doi.org/10.14778/2180912.2180915 Bahmani et al., Scalable k-means++.]] */ @Experimental class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams { @@ -146,10 +120,9 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean setDefault( k -> 2, maxIter -> 20, - runs -> 1, initMode -> MLlibKMeans.K_MEANS_PARALLEL, initSteps -> 5, - epsilon -> 1e-4) + tol -> 1e-4) override def copy(extra: ParamMap): KMeans = defaultCopy(extra) @@ -174,10 +147,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ - def setRuns(value: Int): this.type = set(runs, value) - - /** @group setParam */ - def setEpsilon(value: Double): this.type = set(epsilon, value) + def setTol(value: Double): this.type = set(tol, value) /** @group setParam */ def setSeed(value: Long): this.type = set(seed, value) @@ -191,8 +161,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean .setInitializationSteps($(initSteps)) .setMaxIterations($(maxIter)) .setSeed($(seed)) - .setEpsilon($(epsilon)) - .setRuns($(runs)) + .setEpsilon($(tol)) val parentModel = algo.run(rdd) val model = new KMeansModel(uid, parentModel) copyValues(model) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 1f15ac02f4008..688b0e31f91dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -52,10 +52,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(kmeans.getFeaturesCol === "features") assert(kmeans.getPredictionCol === "prediction") assert(kmeans.getMaxIter === 20) - assert(kmeans.getRuns === 1) assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 5) - assert(kmeans.getEpsilon === 1e-4) + assert(kmeans.getTol === 1e-4) } test("set parameters") { @@ -64,21 +63,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setFeaturesCol("test_feature") .setPredictionCol("test_prediction") .setMaxIter(33) - .setRuns(7) .setInitMode(MLlibKMeans.RANDOM) .setInitSteps(3) .setSeed(123) - .setEpsilon(1e-3) + .setTol(1e-3) assert(kmeans.getK === 9) assert(kmeans.getFeaturesCol === "test_feature") assert(kmeans.getPredictionCol === "test_prediction") assert(kmeans.getMaxIter === 33) - assert(kmeans.getRuns === 7) assert(kmeans.getInitMode === MLlibKMeans.RANDOM) assert(kmeans.getInitSteps === 3) assert(kmeans.getSeed === 123) - assert(kmeans.getEpsilon === 1e-3) + assert(kmeans.getTol === 1e-3) } test("parameters validation") { @@ -91,9 +88,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[IllegalArgumentException] { new KMeans().setInitSteps(0) } - 
intercept[IllegalArgumentException] { - new KMeans().setRuns(0) - } } test("fit & transform") { diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 48338713a29ea..cb4c16e25a7a3 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -19,7 +19,6 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -from pyspark.mllib.linalg import _convert_to_vector __all__ = ['KMeans', 'KMeansModel'] @@ -35,7 +34,7 @@ def clusterCenters(self): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): """ K-means clustering with support for multiple parallel runs and a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, @@ -45,7 +44,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] >>> df = sqlContext.createDataFrame(data, ["features"]) - >>> kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol("features") + >>> kmeans = KMeans(k=2, seed=1) >>> model = kmeans.fit(df) >>> centers = model.clusterCenters() >>> len(centers) @@ -60,10 +59,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): # a placeholder to make it appear in the generated doc k = Param(Params._dummy(), "k", "number of clusters to create") - epsilon = Param(Params._dummy(), "epsilon", - "distance threshold within which " + - "we've consider centers to have converged") - runs = Param(Params._dummy(), "runs", "number of runs of the algorithm to execute in parallel") initMode = Param(Params._dummy(), "initMode", "the initialization algorithm. This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + @@ -71,21 +66,21 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode") @keyword_only - def __init__(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initStep=5): + def __init__(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): + """ + __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) + """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) self.k = Param(self, "k", "number of clusters to create") - self.epsilon = Param(self, "epsilon", - "distance threshold within which " + - "we've consider centers to have converged") - self.runs = Param(self, "runs", "number of runs of the algorithm to execute in parallel") - self.seed = Param(self, "seed", "random seed") self.initMode = Param(self, "initMode", "the initialization algorithm. 
This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + "to use a parallel variant of k-means++") self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") - self._setDefault(k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5) + self._setDefault(k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -93,9 +88,11 @@ def _create_model(self, java_model): return KMeansModel(java_model) @keyword_only - def setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + def setParams(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): """ - setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) Sets params for KMeans. """ @@ -119,40 +116,6 @@ def getK(self): """ return self.getOrDefault(self.k) - def setEpsilon(self, value): - """ - Sets the value of :py:attr:`epsilon`. - - >>> algo = KMeans().setEpsilon(1e-5) - >>> abs(algo.getEpsilon() - 1e-5) < 1e-5 - True - """ - self._paramMap[self.epsilon] = value - return self - - def getEpsilon(self): - """ - Gets the value of `epsilon` - """ - return self.getOrDefault(self.epsilon) - - def setRuns(self, value): - """ - Sets the value of :py:attr:`runs`. - - >>> algo = KMeans().setRuns(10) - >>> algo.getRuns() - 10 - """ - self._paramMap[self.runs] = value - return self - - def getRuns(self): - """ - Gets the value of `runs` - """ - return self.getOrDefault(self.runs) - def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. From 84a27916a62980c8fcb0977c3a7fdb73c0bd5812 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 13 Aug 2015 15:08:57 +0800 Subject: [PATCH 326/340] [SPARK-9885] [SQL] Also pass barrierPrefixes and sharedPrefixes to IsolatedClientLoader when hiveMetastoreJars is set to maven. https://issues.apache.org/jira/browse/SPARK-9885 cc marmbrus liancheng Author: Yin Huai Closes #8158 from yhuai/classloaderMaven. --- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 6 +++++- .../spark/sql/hive/client/IsolatedClientLoader.scala | 11 +++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index f17177a771c3b..17762649fd70d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -231,7 +231,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // TODO: Support for loading the jars from an already downloaded location. logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig) + IsolatedClientLoader.forVersion( + version = hiveMetastoreVersion, + config = allConfig, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } else { // Convert to files and expand any directories. 
val jars = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index a7d5a991948d9..7856037508412 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -42,11 +42,18 @@ private[hive] object IsolatedClientLoader { def forVersion( version: String, config: Map[String, String] = Map.empty, - ivyPath: Option[String] = None): IsolatedClientLoader = synchronized { + ivyPath: Option[String] = None, + sharedPrefixes: Seq[String] = Seq.empty, + barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(version) val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion, ivyPath)) - new IsolatedClientLoader(hiveVersion(version), files, config) + new IsolatedClientLoader( + version = hiveVersion(version), + execJars = files, + config = config, + sharedPrefixes = sharedPrefixes, + barrierPrefixes = barrierPrefixes) } def hiveVersion(version: String): HiveVersion = version match { From 69930310115501f0de094fe6f5c6c60dade342bd Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 13 Aug 2015 16:16:50 +0800 Subject: [PATCH 327/340] [SPARK-9757] [SQL] Fixes persistence of Parquet relation with decimal column PR #7967 enables us to save data source relations to metastore in Hive compatible format when possible. But it fails to persist Parquet relations with decimal column(s) to Hive metastore of versions lower than 1.2.0. This is because `ParquetHiveSerDe` in Hive versions prior to 1.2.0 doesn't support decimal. This PR checks for this case and falls back to Spark SQL specific metastore table format. Author: Yin Huai Author: Cheng Lian Closes #8130 from liancheng/spark-9757/old-hive-parquet-decimal. 
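As an illustrative, hedged sketch of the decimal detection that drives the fallback: the helper below mirrors the private[spark] existsRecursively method added to DataType in this patch (re-implemented here only so the example compiles outside Spark), and the schema is made up.

    // Hypothetical schema check mirroring the decision in HiveMetastoreCatalog.
    import org.apache.spark.sql.types._

    object DecimalSchemaCheckExample {
      // Stand-in for DataType.existsRecursively, which is private[spark].
      def existsRecursively(dt: DataType)(f: DataType => Boolean): Boolean = dt match {
        case s: StructType => f(s) || s.fields.exists(field => existsRecursively(field.dataType)(f))
        case a: ArrayType  => f(a) || existsRecursively(a.elementType)(f)
        case m: MapType    => f(m) || existsRecursively(m.keyType)(f) || existsRecursively(m.valueType)(f)
        case other         => f(other)
      }

      def main(args: Array[String]): Unit = {
        val schema = StructType(Seq(
          StructField("id", LongType),
          StructField("price", DecimalType(10, 3)),
          StructField("tags", ArrayType(StringType))))

        // If any field (at any nesting level) is a DecimalType and the Hive client is older
        // than 1.2.0, the relation is persisted in the Spark SQL specific format instead.
        val hasDecimalFields = existsRecursively(schema)(_.isInstanceOf[DecimalType])
        println(s"hasDecimalFields = $hasDecimalFields")  // true
      }
    }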
--- .../apache/spark/sql/types/ArrayType.scala | 6 +- .../org/apache/spark/sql/types/DataType.scala | 5 ++ .../org/apache/spark/sql/types/MapType.scala | 6 +- .../apache/spark/sql/types/StructType.scala | 8 ++- .../spark/sql/types/DataTypeSuite.scala | 24 +++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 39 ++++++++--- .../sql/hive/client/ClientInterface.scala | 3 + .../spark/sql/hive/client/ClientWrapper.scala | 2 +- .../spark/sql/hive/client/package.scala | 2 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 17 +++-- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 68 +++++++++++++++++-- 11 files changed, 150 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 5094058164b2f..5770f59b53077 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -75,6 +75,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" - private[spark] override def asNullable: ArrayType = + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || elementType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index f4428c2e8b202..7bcd623b3f33e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -77,6 +77,11 @@ abstract class DataType extends AbstractDataType { */ private[spark] def asNullable: DataType + /** + * Returns true if any `DataType` of this DataType tree satisfies the given function `f`. 
+ */ + private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this) + override private[sql] def defaultConcreteType: DataType = this override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index ac34b642827ca..00461e529ca0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,8 +62,12 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" - private[spark] override def asNullable: MapType = + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 9cbc207538d4f..d8968ef806390 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} /** @@ -292,7 +292,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private[sql] def merge(that: StructType): StructType = StructType.merge(this, that).asInstanceOf[StructType] - private[spark] override def asNullable: StructType = { + override private[spark] def asNullable: StructType = { val newFields = fields.map { case StructField(name, dataType, nullable, metadata) => StructField(name, dataType.asNullable, nullable = true, metadata) @@ -301,6 +301,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(newFields) } + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || fields.exists(field => field.dataType.existsRecursively(f)) + } + private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 88b221cd81d74..706ecd29d1355 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -170,6 +170,30 @@ class DataTypeSuite extends SparkFunSuite { } } + test("existsRecursively") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + assert(struct.existsRecursively(_.isInstanceOf[LongType])) + assert(struct.existsRecursively(_.isInstanceOf[StructType])) + assert(!struct.existsRecursively(_.isInstanceOf[IntegerType])) + + val mapType = MapType(struct, StringType) + assert(mapType.existsRecursively(_.isInstanceOf[LongType])) + 
assert(mapType.existsRecursively(_.isInstanceOf[StructType])) + assert(mapType.existsRecursively(_.isInstanceOf[StringType])) + assert(mapType.existsRecursively(_.isInstanceOf[MapType])) + assert(!mapType.existsRecursively(_.isInstanceOf[IntegerType])) + + val arrayType = ArrayType(mapType) + assert(arrayType.existsRecursively(_.isInstanceOf[LongType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StructType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StringType])) + assert(arrayType.existsRecursively(_.isInstanceOf[MapType])) + assert(arrayType.existsRecursively(_.isInstanceOf[ArrayType])) + assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType])) + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5e5497837a393..6770462bb0ad3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -33,15 +33,14 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} -import org.apache.spark.sql.execution.{FileRelation, datasources} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} @@ -86,9 +85,9 @@ private[hive] object HiveSerDe { serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) val key = source.toLowerCase match { - case _ if source.startsWith("org.apache.spark.sql.parquet") => "parquet" - case _ if source.startsWith("org.apache.spark.sql.orc") => "orc" - case _ => source.toLowerCase + case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" + case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s => s } serdeMap.get(key) @@ -309,11 +308,31 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val hiveTable = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) if relation.paths.length == 1 && relation.partitionColumns.isEmpty => - logInfo { - "Persisting data source relation with a single input path into Hive metastore in Hive " + - s"compatible format. Input path: ${relation.paths.head}" + // Hive ParquetSerDe doesn't support decimal type until 1.2.0. 
+ val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) + val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) + + val hiveParquetSupportsDecimal = client.version match { + case org.apache.spark.sql.hive.client.hive.v1_2 => true + case _ => false + } + + if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { + // If Hive version is below 1.2.0, we cannot save Hive compatible schema to + // metastore when the file format is Parquet and the schema has DecimalType. + logWarning { + "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + + "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + + s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." + } + newSparkSQLSpecificMetastoreTable() + } else { + logInfo { + "Persisting data source relation with a single input path into Hive metastore in " + + s"Hive compatible format. Input path: ${relation.paths.head}" + } + newHiveCompatibleMetastoreTable(relation, serde) } - newHiveCompatibleMetastoreTable(relation, serde) case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => logWarning { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index a82e152dcda2c..3811c152a7ae6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -88,6 +88,9 @@ private[hive] case class HiveTable( */ private[hive] trait ClientInterface { + /** Returns the Hive Version of this client. */ + def version: HiveVersion + /** Returns the configuration for the given key in the current session. */ def getConf(key: String, defaultValue: String): String diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 3d05b583cf9e0..f49c97de8ff4e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -58,7 +58,7 @@ import org.apache.spark.util.{CircularBuffer, Utils} * this ClientWrapper. 
*/ private[hive] class ClientWrapper( - version: HiveVersion, + override val version: HiveVersion, config: Map[String, String], initClassLoader: ClassLoader) extends ClientInterface diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 0503691a44249..b1b8439efa011 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -25,7 +25,7 @@ package object client { val exclusions: Seq[String] = Nil) // scalastyle:off - private[client] object hive { + private[hive] object hive { case object v12 extends HiveVersion("0.12.0") case object v13 extends HiveVersion("0.13.1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 332c3ec0c28b8..59e65ff97b8e0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, ManagedTable} +import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.sources.DataSourceTest import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.{Logging, SparkFunSuite} @@ -55,7 +55,10 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { override val sqlContext = TestHive - private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1) + private val testDF = range(1, 3).select( + ('id + 0.1) cast DecimalType(10, 3) as 'd1, + 'id cast StringType as 'd2 + ).coalesce(1) Seq( "parquet" -> ( @@ -88,10 +91,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -117,10 +120,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 1e1972d1ac353..0c29646114465 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -20,16 +20,18 @@ package org.apache.spark.sql.hive import java.io.File import scala.collection.mutable.ArrayBuffer -import scala.sys.process.{ProcessLogger, Process} +import scala.sys.process.{Process, ProcessLogger} +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.types.DecimalType import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.Matchers -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ /** * This suite tests spark-submit with applications using HiveContext. @@ -50,8 +52,8 @@ class HiveSparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), @@ -91,6 +93,16 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-9757 Persist Parquet relation with decimal column") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_9757.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -213,7 +225,7 @@ object SparkSQLConfTest extends Logging { // before spark.sql.hive.metastore.jars get set, we will see the following exception: // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only // be used when hive execution version == hive metastore version. - // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. 
val conf = new SparkConf() { override def getAll: Array[(String, String)] = { @@ -239,3 +251,45 @@ object SparkSQLConfTest extends Logging { sc.stop() } } + +object SPARK_9757 extends QueryTest with Logging { + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.sql.hive.metastore.version", "0.13.1") + .set("spark.sql.hive.metastore.jars", "maven")) + + val hiveContext = new TestHiveContext(sparkContext) + import hiveContext.implicits._ + import org.apache.spark.sql.functions._ + + val dir = Utils.createTempDir() + dir.delete() + + try { + { + val df = + hiveContext + .range(10) + .select(('id + 0.1) cast DecimalType(10, 3) as 'dec) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + + { + val df = + hiveContext + .range(10) + .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + } finally { + dir.delete() + hiveContext.sql("DROP TABLE t") + sparkContext.stop() + } + } +} From 2932e25da4532de9e86b01d08bce0cb680874e70 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 13 Aug 2015 09:17:19 -0700 Subject: [PATCH 328/340] [SPARK-9073] [ML] spark.ml Models copy() should call setParent when there is a parent Copied ML models must have the same parent of original ones Author: lewuathe Author: Lewuathe Closes #7447 from Lewuathe/SPARK-9073. --- .../examples/ml/JavaDeveloperApiExample.java | 3 +- .../examples/ml/DeveloperApiExample.scala | 2 +- .../scala/org/apache/spark/ml/Pipeline.scala | 2 +- .../DecisionTreeClassifier.scala | 1 + .../ml/classification/GBTClassifier.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../spark/ml/classification/OneVsRest.scala | 2 +- .../RandomForestClassifier.scala | 1 + .../apache/spark/ml/feature/Bucketizer.scala | 4 ++- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../org/apache/spark/ml/feature/PCA.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- .../apache/spark/ml/recommendation/ALS.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../spark/ml/tuning/CrossValidator.scala | 2 +- .../org/apache/spark/ml/PipelineSuite.scala | 3 ++ .../DecisionTreeClassifierSuite.scala | 4 +++ .../classification/GBTClassifierSuite.scala | 4 +++ .../LogisticRegressionSuite.scala | 4 +++ .../ml/classification/OneVsRestSuite.scala | 6 +++- .../RandomForestClassifierSuite.scala | 4 +++ .../spark/ml/feature/BucketizerSuite.scala | 1 + .../spark/ml/feature/MinMaxScalerSuite.scala | 4 +++ .../apache/spark/ml/feature/PCASuite.scala | 4 +++ .../spark/ml/feature/StringIndexerSuite.scala | 5 ++++ .../spark/ml/feature/VectorIndexerSuite.scala | 5 ++++ .../spark/ml/feature/Word2VecSuite.scala | 4 +++ .../spark/ml/recommendation/ALSSuite.scala | 4 +++ .../DecisionTreeRegressorSuite.scala | 11 +++++++ .../ml/regression/GBTRegressorSuite.scala | 5 ++++ .../ml/regression/LinearRegressionSuite.scala | 5 ++++ .../RandomForestRegressorSuite.scala | 7 ++++- 
.../spark/ml/tuning/CrossValidatorSuite.scala | 5 ++++ .../apache/spark/ml/util/MLTestingUtils.scala | 30 +++++++++++++++++++ 41 files changed, 138 insertions(+), 22 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 9df26ffca5775..3f1fe900b0008 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -230,6 +230,7 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra); + return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + .setParent(parent()); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 78f31b4ffe56a..340c3559b15ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -179,7 +179,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. */ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra) + copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) } } // scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index aef2c019d2871..a3e59401c5cfb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -198,6 +198,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(uid, stages.map(_.copy(extra))) + new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 29598f3f05c2d..6f70b96b17ec6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -141,6 +141,7 @@ final class DecisionTreeClassificationModel private[ml] ( override def copy(extra: ParamMap): DecisionTreeClassificationModel = { copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) + .setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c3891a9599262..3073a2a61ce83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -196,7 +196,7 @@ final class GBTClassificationModel( } override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra) + copyValues(new 
GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 5bcd7117b668c..21fbe38ca8233 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -468,7 +468,7 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1741f19dc911c..1132d8046df67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -138,7 +138,7 @@ final class OneVsRestModel private[ml] ( override def copy(extra: ParamMap): OneVsRestModel = { val copied = new OneVsRestModel( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 156050aaf7a45..11a6d72468333 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -189,6 +189,7 @@ final class RandomForestClassificationModel private[ml] ( override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) + .setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 67e4785bc3553..cfca494dcf468 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -90,7 +90,9 @@ final class Bucketizer(override val uid: String) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } - override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra) + override def copy(extra: ParamMap): Bucketizer = { + defaultCopy[Bucketizer](extra).setParent(parent) + } } private[feature] object Bucketizer { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index ecde80810580c..938447447a0a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -114,6 +114,6 @@ class IDFModel private[ml] ( override def copy(extra: ParamMap): IDFModel = { val copied = new IDFModel(uid, idfModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 9a473dd23772d..1b494ec8b1727 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -173,6 +173,6 @@ class MinMaxScalerModel private[ml] ( override def copy(extra: ParamMap): MinMaxScalerModel = { val copied = new MinMaxScalerModel(uid, originalMin, originalMax) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 2d3bb680cf309..539084704b653 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -125,6 +125,6 @@ class PCAModel private[ml] ( override def copy(extra: ParamMap): PCAModel = { val copied = new PCAModel(uid, pcaModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72b545e5db3e4..f6d0b0c0e9e75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -136,6 +136,6 @@ class StandardScalerModel private[ml] ( override def copy(extra: ParamMap): StandardScalerModel = { val copied = new StandardScalerModel(uid, scaler) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index e4485eb038409..9e4b0f0add612 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -168,7 +168,7 @@ class StringIndexerModel private[ml] ( override def copy(extra: ParamMap): StringIndexerModel = { val copied = new StringIndexerModel(uid, labels) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index c73bdccdef5fa..6875aefe065bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -405,6 +405,6 @@ class VectorIndexerModel private[ml] ( override def copy(extra: ParamMap): VectorIndexerModel = { val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 29acc3eb5865f..5af775a4159ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -221,6 +221,6 @@ class Word2VecModel private[ml] ( override def copy(extra: ParamMap): Word2VecModel = { val copied = new Word2VecModel(uid, wordVectors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 
2e44cd4cc6a22..7db8ad8d27918 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -219,7 +219,7 @@ class ALSModel private[ml] ( override def copy(extra: ParamMap): ALSModel = { val copied = new ALSModel(uid, rank, userFactors, itemFactors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index dc94a14014542..a2bcd67401d08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -114,7 +114,7 @@ final class DecisionTreeRegressionModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { - copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra) + copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 5633bc320273a..b66e61f37dd5e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -185,7 +185,7 @@ final class GBTRegressionModel( } override def copy(extra: ParamMap): GBTRegressionModel = { - copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra) + copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 92d819bad8654..884003eb38524 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -312,7 +312,7 @@ class LinearRegressionModel private[ml] ( override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel + newModel.setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index db75c0d26392f..2f36da371f577 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -151,7 +151,7 @@ final class RandomForestRegressionModel private[ml] ( } override def copy(extra: ParamMap): RandomForestRegressionModel = { - copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra) + copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index f979319cc4b58..4792eb0f0a288 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,6 +160,6 @@ class CrossValidatorModel private[ml] ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], avgMetrics.clone()) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 63d2fa31c7499..1f2c9b75b617b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.HashingTF import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.sql.DataFrame class PipelineSuite extends SparkFunSuite { @@ -65,6 +66,8 @@ class PipelineSuite extends SparkFunSuite { .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) + MLTestingUtils.checkCopy(pipelineModel) + assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) assert(pipelineModel.stages(1).eq(transformer1)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index c7bbf1ce07a23..4b7c5d3f23d2c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -244,6 +245,9 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val newTree = dt.fit(newData) + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(newTree) + val predictions = newTree.transform(newData) .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index d4b5896c12c06..e3909bccaa5ca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -92,6 +93,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { .setCheckpointInterval(2) val model = gbt.fit(df) + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + sc.checkpointDir = None Utils.deleteRecursively(tempDir) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e354e161c6dee..cce39f382f738 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -135,6 +136,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { lr.setFitIntercept(false) val model = lr.fit(dataset) assert(model.intercept === 0.0) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) } test("logistic regression with setters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index bd8e819f6926c..977f0e0b70c1a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils} import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.evaluation.MulticlassMetrics @@ -70,6 +70,10 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") val ovaModel = ova.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(ovaModel) + assert(ovaModel.models.size === numClasses) val transformedDataset = ovaModel.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 6ca4b5aa5fde8..b4403ec30049a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -135,6 +136,9 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val model = rf.fit(df) + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + val predictions = model.transform(df) .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index ec85e0d151e07..0eba34fda6228 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c452054bec92f..c04dda41eea34 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} @@ -51,6 +52,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { .foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1.equals(vector2), "Transformed vector is different with expected.") } + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("MinMaxScaler arguments max must be larger than min") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index d0ae36b28c7a9..30c500f87a769 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -56,6 +57,9 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { .setK(3) .fit(df) + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(pca) + pca.transform(df).select("pca_features", "expected").collect().foreach { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index b111036087e6a..2d24914cb91f6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.util.MLlibTestSparkContext class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -38,6 +39,10 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(indexer) + val transformed = indexer.transform(df) val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 03120c828ca96..8cb0a2cf14d37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -109,6 +110,10 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L test("Throws error when given RDDs with different size vectors") { val vectorIndexer = getIndexer val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work intercept[SparkException] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index adcda0e623b25..a2e46f2029956 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -62,6 +63,9 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(42L) .fit(docDF) + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + model.transform(docDF).select("result", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 2e5cfe7027eb6..eadc80e0e62b1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -28,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -374,6 +375,9 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { } logInfo(s"Test RMSE is $rmse.") assert(rmse < targetRMSE) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("exact rank-1 matrix") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 33aa9d0d62343..b092bcd6a7e86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -61,6 +62,16 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } + test("copied model must have the same parent") { + val categoricalFeatures = Map(0 -> 2, 1-> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8).fit(df) + MLTestingUtils.checkCopy(model) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dbdce0c9dea54..a68197b59193d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -82,6 +83,9 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { 
.setMaxDepth(2) .setMaxIter(2) val model = gbt.fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) val preds = model.transform(df) val predictions = preds.select("prediction").map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) @@ -104,6 +108,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.checkpointDir = None Utils.deleteRecursively(tempDir) + } // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 21ad8225bd9f7..2aaee71ecc734 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -72,6 +73,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lir.getFitIntercept) assert(lir.getStandardization) val model = lir.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(dataset) .select("label", "prediction") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 992ce9562434e..7b1b3f11481de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -91,7 +92,11 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex val categoricalFeatures = Map.empty[Int, Int] val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) - val importances = rf.fit(df).featureImportances + val model = rf.fit(df) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + val importances = model.featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index db64511a76055..aaca08bb61a45 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} @@ -53,6 +54,10 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { .setEvaluator(eval) .setNumFolds(3) val cvModel = cv.fit(dataset) + + // copied model must have the same paren. + MLTestingUtils.checkCopy(cvModel) + val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala new file mode 100644 index 0000000000000..d290cc9b06e73 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Model +import org.apache.spark.ml.param.ParamMap + +object MLTestingUtils { + def checkCopy(model: Model[_]): Unit = { + val copied = model.copy(ParamMap.empty) + .asInstanceOf[Model[_]] + assert(copied.parent.uid == model.parent.uid) + assert(copied.parent == model.parent) + } +} From 7a539ef3b1792764f866fa88c84c78ad59903f21 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Thu, 13 Aug 2015 09:18:39 -0700 Subject: [PATCH 329/340] [SPARK-8965] [DOCS] Add ml-guide Python Example: Estimator, Transformer, and Param Added ml-guide Python Example: Estimator, Transformer, and Param /docs/_site/ml-guide.html Author: Rosstin Closes #8081 from Rosstin/SPARK-8965. --- docs/ml-guide.md | 68 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index b6ca50e98db02..a03ab4356a413 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -355,6 +355,74 @@ jsc.stop(); {% endhighlight %} +
    +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.param import Param, Params +from pyspark.sql import Row, SQLContext + +sc = SparkContext(appName="SimpleParamsExample") +sqlContext = SQLContext(sc) + +# Prepare training data. +# We use LabeledPoint. +# Spark SQL can convert RDDs of LabeledPoints into DataFrames. +training = sc.parallelize([LabeledPoint(1.0, [0.0, 1.1, 0.1]), + LabeledPoint(0.0, [2.0, 1.0, -1.0]), + LabeledPoint(0.0, [2.0, 1.3, 1.0]), + LabeledPoint(1.0, [0.0, 1.2, -0.5])]) + +# Create a LogisticRegression instance. This instance is an Estimator. +lr = LogisticRegression(maxIter=10, regParam=0.01) +# Print out the parameters, documentation, and any default values. +print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + +# Learn a LogisticRegression model. This uses the parameters stored in lr. +model1 = lr.fit(training.toDF()) + +# Since model1 is a Model (i.e., a transformer produced by an Estimator), +# we can view the parameters it used during fit(). +# This prints the parameter (name: value) pairs, where names are unique IDs for this +# LogisticRegression instance. +print "Model 1 was fit using parameters: " +print model1.extractParamMap() + +# We may alternatively specify parameters using a Python dictionary as a paramMap +paramMap = {lr.maxIter: 20} +paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. +paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + +# You can combine paramMaps, which are python dictionaries. +paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name +paramMapCombined = paramMap.copy() +paramMapCombined.update(paramMap2) + +# Now learn a new model using the paramMapCombined parameters. +# paramMapCombined overrides all parameters set earlier via lr.set* methods. +model2 = lr.fit(training.toDF(), paramMapCombined) +print "Model 2 was fit using parameters: " +print model2.extractParamMap() + +# Prepare test data +test = sc.parallelize([LabeledPoint(1.0, [-1.0, 1.5, 1.3]), + LabeledPoint(0.0, [ 3.0, 2.0, -0.1]), + LabeledPoint(1.0, [ 0.0, 2.2, -1.5])]) + +# Make predictions on test data using the Transformer.transform() method. +# LogisticRegression.transform will only use the 'features' column. +# Note that model2.transform() outputs a "myProbability" column instead of the usual +# 'probability' column since we renamed the lr.probabilityCol parameter previously. +prediction = model2.transform(test.toDF()) +selected = prediction.select("features", "label", "myProbability", "prediction") +for row in selected.collect(): + print row + +sc.stop() +{% endhighlight %} +
    + ## Example: Pipeline From 4b70798c96b0a784b85fda461426ec60f609be12 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 13 Aug 2015 09:31:14 -0700 Subject: [PATCH 330/340] [MINOR] [ML] change MultilayerPerceptronClassifierModel to MultilayerPerceptronClassificationModel To follow the naming rule of ML, change `MultilayerPerceptronClassifierModel` to `MultilayerPerceptronClassificationModel` like `DecisionTreeClassificationModel`, `GBTClassificationModel` and so on. Author: Yanbo Liang Closes #8164 from yanboliang/mlp-name. --- .../MultilayerPerceptronClassifier.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 8cd2103d7d5e6..c154561886585 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -131,7 +131,7 @@ private object LabelConverter { */ @Experimental class MultilayerPerceptronClassifier(override val uid: String) - extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] with MultilayerPerceptronParams { def this() = this(Identifiable.randomUID("mlpc")) @@ -146,7 +146,7 @@ class MultilayerPerceptronClassifier(override val uid: String) * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { + override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = { val myLayers = $(layers) val labels = myLayers.last val lpData = extractLabeledPoints(dataset) @@ -156,13 +156,13 @@ class MultilayerPerceptronClassifier(override val uid: String) FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) FeedForwardTrainer.setStackSize($(blockSize)) val mlpModel = FeedForwardTrainer.train(data) - new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights()) } } /** * :: Experimental :: - * Classifier model based on the Multilayer Perceptron. + * Classification model based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. 
* @param uid uid * @param layers array of layer sizes including input and output layers @@ -170,11 +170,11 @@ class MultilayerPerceptronClassifier(override val uid: String) * @return prediction model */ @Experimental -class MultilayerPerceptronClassifierModel private[ml] ( +class MultilayerPerceptronClassificationModel private[ml] ( override val uid: String, layers: Array[Int], weights: Vector) - extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] + extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] with Serializable { private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) @@ -187,7 +187,7 @@ class MultilayerPerceptronClassifierModel private[ml] ( LabelConverter.decodeLabel(mlpModel.predict(features)) } - override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { - copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) + override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { + copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) } } From 65fec798ce52ca6b8b0fe14b78a16712778ad04c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 13 Aug 2015 10:16:40 -0700 Subject: [PATCH 331/340] [MINOR] [DOC] fix mllib pydoc warnings Switch to correct Sphinx syntax. MechCoder Author: Xiangrui Meng Closes #8169 from mengxr/mllib-pydoc-fix. --- python/pyspark/mllib/regression.py | 14 ++++++++++---- python/pyspark/mllib/util.py | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 5b7afc15ddfba..41946e3674fbe 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -207,8 +207,10 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, Train a linear regression model using Stochastic Gradient Descent (SGD). This solves the least squares regression formulation - f(weights) = 1/n ||A weights-y||^2^ - (which is the mean squared error). + + f(weights) = 1/(2n) ||A weights - y||^2, + + which is the mean squared error. Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -334,7 +336,9 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, Stochastic Gradient Descent. This solves the l1-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -451,7 +455,9 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, Stochastic Gradient Descent. This solves the l2-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. 
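For reference, the three objectives that the corrected docstrings above describe are, restated together (this only collects the docstring text and is not part of the patch):

    linear:  f(weights) = 1/(2n) ||A weights - y||^2
    lasso:   f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1
    ridge:   f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2

where A is the n-row data matrix and y the vector of right hand side labels.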
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 916de2d6fcdbd..10a1e4b3eb0fc 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -300,6 +300,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, :param: seed Random Seed :param: eps Used to scale the noise. If eps is set high, the amount of gaussian noise added is more. + Returns a list of LabeledPoints of length nPoints """ weights = [float(weight) for weight in weights] From 8815ba2f674dbb18eb499216df9942b411e10daa Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 13 Aug 2015 11:31:10 -0700 Subject: [PATCH 332/340] [SPARK-9649] Fix MasterSuite, third time's a charm This particular test did not load the default configurations so it continued to start the REST server, which causes port bind exceptions. --- .../test/scala/org/apache/spark/deploy/master/MasterSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 20d0201a364ab..242bf4b5566eb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -40,6 +40,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva conf.set("spark.deploy.recoveryMode", "CUSTOM") conf.set("spark.deploy.recoveryMode.factory", classOf[CustomRecoveryModeFactory].getCanonicalName) + conf.set("spark.master.rest.enabled", "false") val instantiationAttempts = CustomRecoveryModeFactory.instantiationAttempts From 864de8eaf4b6ad5c9099f6f29e251c56b029f631 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 13 Aug 2015 13:42:35 -0700 Subject: [PATCH 333/340] [SPARK-9661] [MLLIB] [ML] Java compatibility I skimmed through the docs for various instance of Object and replaced them with Java compaible versions of the same. 1. Some methods in LDAModel. 2. runMiniBatchSGD 3. kolmogorovSmirnovTest Author: MechCoder Closes #8126 from MechCoder/java_incop. 
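The changes below all follow the same shape: keep the existing RDD-based Scala method and add an overload that accepts the org.apache.spark.api.java types and delegates to it. A minimal Scala sketch of that pattern (the ExampleModel class and its placeholder body are hypothetical, included only to show the shape of the wrappers added in this patch):

    import org.apache.spark.api.java.JavaPairRDD
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD

    class ExampleModel extends Serializable {
      // Existing Scala-facing API, keyed by primitive Long.
      def logLikelihood(documents: RDD[(Long, Vector)]): Double = {
        // Placeholder body so the sketch compiles; the real models compute a variational bound here.
        documents.count().toDouble
      }

      // Java-friendly overload: accept boxed java.lang.Long keys and delegate to the Scala method.
      def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = {
        logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
      }
    }
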
--- .../spark/mllib/clustering/LDAModel.scala | 27 +++++++++++++++++-- .../apache/spark/mllib/stat/Statistics.scala | 16 ++++++++++- .../spark/mllib/clustering/JavaLDASuite.java | 24 +++++++++++++++++ .../spark/mllib/stat/JavaStatisticsSuite.java | 22 +++++++++++++++ .../spark/mllib/clustering/LDASuite.scala | 13 +++++++++ 5 files changed, 99 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 5dc637ebdc133..f31949f13a4cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -26,7 +26,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -228,6 +228,11 @@ class LocalLDAModel private[clustering] ( docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) + /** Java-friendly version of [[logLikelihood]] */ + def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + /** * Calculate an upper bound bound on perplexity. (Lower is better.) * See Equation (16) in original Online LDA paper. @@ -242,6 +247,11 @@ class LocalLDAModel private[clustering] ( -logLikelihood(documents) / corpusTokenCount } + /** Java-friendly version of [[logPerplexity]] */ + def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + /** * Estimate the variational likelihood bound of from `documents`: * log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)] @@ -341,8 +351,14 @@ class LocalLDAModel private[clustering] ( } } -} + /** Java-friendly version of [[topicDistributions]] */ + def topicDistributions( + documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = { + val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + JavaPairRDD.fromRDD(distributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) + } +} @Experimental object LocalLDAModel extends Loader[LocalLDAModel] { @@ -657,6 +673,13 @@ class DistributedLDAModel private[clustering] ( } } + /** Java-friendly version of [[topTopicsPerDocument]] */ + def javaTopTopicsPerDocument( + k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[java.lang.Double])] = { + val topics = topTopicsPerDocument(k) + topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[java.lang.Double])]].toJavaRDD() + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? 
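For reference, the relation between the two quantities exposed above: logLikelihood returns the variational lower bound L(docs) <= log p(docs), and logPerplexity divides its negation by the total token count N, so

    logPerplexity(docs) = -L(docs) / N  >=  -log p(docs) / N  =  log(perplexity)

which is why the scaladoc describes the result as an upper bound on perplexity (lower is better). This restates the code above and is not part of the patch.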
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f84502919e381..24fe48cb8f71f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint @@ -178,6 +178,9 @@ object Statistics { ChiSqTest.chiSquaredFeatures(data) } + /** Java-friendly version of [[chiSqTest()]] */ + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = chiSqTest(data.rdd) + /** * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a * continuous distribution. By comparing the largest difference between the empirical cumulative @@ -212,4 +215,15 @@ object Statistics { : KolmogorovSmirnovTestResult = { KolmogorovSmirnovTest.testOneSample(data, distName, params: _*) } + + /** Java-friendly version of [[kolmogorovSmirnovTest()]] */ + @varargs + def kolmogorovSmirnovTest( + data: JavaDoubleRDD, + distName: String, + params: java.lang.Double*): KolmogorovSmirnovTestResult = { + val javaParams = params.asInstanceOf[Seq[Double]] + KolmogorovSmirnovTest.testOneSample(data.rdd.asInstanceOf[RDD[Double]], + distName, javaParams: _*) + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index d272a42c8576f..427be9430d820 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -124,6 +124,10 @@ public Boolean call(Tuple2 tuple2) { } }); assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); + + // Check: javaTopTopicsPerDocuments + JavaRDD> topTopics = + model.javaTopTopicsPerDocument(3); } @Test @@ -160,11 +164,31 @@ public void OnlineOptimizerCompatibility() { assertEquals(roundedLocalTopicSummary.length, k); } + @Test + public void localLdaMethods() { + JavaRDD> docs = sc.parallelize(toyData, 2); + JavaPairRDD pairedDocs = JavaPairRDD.fromJavaRDD(docs); + + // check: topicDistributions + assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count()); + + // check: logPerplexity + double logPerplexity = toyModel.logPerplexity(pairedDocs); + + // check: logLikelihood. 
+ ArrayList> docsSingleWord = new ArrayList>(); + docsSingleWord.add(new Tuple2(Long.valueOf(0), Vectors.dense(1.0, 0.0, 0.0))); + JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); + double logLikelihood = toyModel.logLikelihood(single); + } + private static int tinyK = LDASuite$.MODULE$.tinyK(); private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); private static Tuple2[] tinyTopicDescription = LDASuite$.MODULE$.tinyTopicDescription(); private JavaPairRDD corpus; + private LocalLDAModel toyModel = LDASuite$.MODULE$.toyModel(); + private ArrayList> toyData = LDASuite$.MODULE$.javaToyData(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 62f7f26b7c98f..eb4e3698624bc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -27,7 +27,12 @@ import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.ChiSqTestResult; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; public class JavaStatisticsSuite implements Serializable { private transient JavaSparkContext sc; @@ -53,4 +58,21 @@ public void testCorr() { // Check default method assertEquals(corr1, corr2); } + + @Test + public void kolmogorovSmirnovTest() { + JavaDoubleRDD data = sc.parallelizeDoubles(Lists.newArrayList(0.2, 1.0, -1.0, 2.0)); + KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); + KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( + data, "norm", 0.0, 1.0); + } + + @Test + public void chiSqTest() { + JavaRDD data = sc.parallelize(Lists.newArrayList( + new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), + new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), + new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); + ChiSqTestResult[] testResults = Statistics.chiSqTest(data); + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index ce6a8eb8e8c46..926185e90bcf9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import java.util.{ArrayList => JArrayList} + import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} import org.apache.spark.SparkFunSuite @@ -575,6 +577,17 @@ private[clustering] object LDASuite { Vectors.sparse(6, Array(4, 5), Array(1, 1)) ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + /** Used in the Java Test Suite */ + def javaToyData: JArrayList[(java.lang.Long, Vector)] = { + val javaData = new JArrayList[(java.lang.Long, Vector)] + var i = 0 + while (i < toyData.size) { + javaData.add((toyData(i)._1, toyData(i)._2)) + i += 1 + } + javaData + } + def toyModel: LocalLDAModel = { val k = 2 val vocabSize = 6 From a8d2f4c5f92210a09c846711bd7cc89a43e2edd2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 13 Aug 2015 14:03:55 -0700 
Subject: [PATCH 334/340] [SPARK-9942] [PYSPARK] [SQL] ignore exceptions while try to import pandas If pandas is broken (can't be imported, raise other exceptions other than ImportError), pyspark can't be imported, we should ignore all the exceptions. Author: Davies Liu Closes #8173 from davies/fix_pandas. --- python/pyspark/sql/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 917de24f3536b..0ef46c44644ab 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -39,7 +39,7 @@ try: import pandas has_pandas = True -except ImportError: +except Exception: has_pandas = False __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] From c2520f501a200cf794bbe5dc9c385100f518d020 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 13 Aug 2015 16:07:03 -0700 Subject: [PATCH 335/340] [SPARK-9935] [SQL] EqualNotNull not processed in ORC https://issues.apache.org/jira/browse/SPARK-9935 Author: hyukjinkwon Closes #8163 from HyukjinKwon/master. --- .../scala/org/apache/spark/sql/hive/orc/OrcFilters.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 86142e5d66f37..b3d9f7f71a27d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -107,6 +107,11 @@ private[orc] object OrcFilters extends Logging { .filter(isSearchableLiteral) .map(builder.equals(attribute, _)) + case EqualNullSafe(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.nullSafeEquals(attribute, _)) + case LessThan(attribute, value) => Option(value) .filter(isSearchableLiteral) From 6c5858bc65c8a8602422b46bfa9cf0a1fb296b88 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 13 Aug 2015 16:52:17 -0700 Subject: [PATCH 336/340] [SPARK-9922] [ML] rename StringIndexerReverse to IndexToString What `StringIndexerInverse` does is not strictly associated with `StringIndexer`, and the name is not clearly describing the transformation. Renaming to `IndexToString` might be better. ~~I also changed `invert` to `inverse` without arguments. `inputCol` and `outputCol` could be set after.~~ I also removed `invert`. jkbradley holdenk Author: Xiangrui Meng Closes #8152 from mengxr/SPARK-9922. 
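For context, a minimal Scala usage sketch of the renamed transformer (column names and label values here are illustrative; as the updated suite below shows, the labels can also be picked up from the ML attributes of the input column instead of being set explicitly):

    import org.apache.spark.ml.feature.IndexToString

    // Assumes an existing sqlContext, as in the test suites below.
    val indexed = sqlContext.createDataFrame(Seq(
      (0, 0.0), (1, 2.0), (2, 1.0)
    )).toDF("id", "labelIndex")

    // Map the index column back to the original string labels; original columns are kept.
    val converter = new IndexToString()
      .setInputCol("labelIndex")
      .setOutputCol("originalLabel")
      .setLabels(Array("a", "b", "c"))

    converter.transform(indexed).show()
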
--- .../spark/ml/feature/StringIndexer.scala | 34 +++++-------- .../spark/ml/feature/StringIndexerSuite.scala | 50 +++++++++++++------ 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 9e4b0f0add612..9f6e7b6b6b274 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} @@ -59,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. + * + * @see [[IndexToString]] for the inverse transformation */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] @@ -170,34 +172,24 @@ class StringIndexerModel private[ml] ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) } - - /** - * Return a model to perform the inverse transformation. - * Note: By default we keep the original columns during this transformation, so the inverse - * should only be used on new columns such as predicted labels. - */ - def invert(inputCol: String, outputCol: String): StringIndexerInverse = { - new StringIndexerInverse() - .setInputCol(inputCol) - .setOutputCol(outputCol) - .setLabels(labels) - } } /** * :: Experimental :: - * Transform a provided column back to the original input types using either the metadata - * on the input column, or if provided using the labels supplied by the user. - * Note: By default we keep the original columns during this transformation, - * so the inverse should only be used on new columns such as predicted labels. + * A [[Transformer]] that maps a column of string indices back to a new column of corresponding + * string values using either the ML attributes of the input column, or if provided using the labels + * supplied by the user. + * All original columns are kept during transformation. 
+ * + * @see [[StringIndexer]] for converting strings into indices */ @Experimental -class StringIndexerInverse private[ml] ( +class IndexToString private[ml] ( override val uid: String) extends Transformer with HasInputCol with HasOutputCol { def this() = - this(Identifiable.randomUID("strIdxInv")) + this(Identifiable.randomUID("idxToStr")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -257,7 +249,7 @@ class StringIndexerInverse private[ml] ( } val indexer = udf { index: Double => val idx = index.toInt - if (0 <= idx && idx < values.size) { + if (0 <= idx && idx < values.length) { values(idx) } else { throw new SparkException(s"Unseen index: $index ??") @@ -268,7 +260,7 @@ class StringIndexerInverse private[ml] ( indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) } - override def copy(extra: ParamMap): StringIndexerInverse = { + override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2d24914cb91f6..fa918ce64877c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -53,19 +54,6 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { // a -> 0, b -> 2, c -> 1 val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) - // convert reverse our transform - val reversed = indexer.invert("labelIndex", "label2") - .transform(transformed) - .select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) - // Check invert using only metadata - val inverse2 = new StringIndexerInverse() - .setInputCol("labelIndex") - .setOutputCol("label2") - val reversed2 = inverse2.transform(transformed).select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } test("StringIndexerUnseen") { @@ -125,4 +113,36 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val df = sqlContext.range(0L, 10L) assert(indexerModel.transform(df).eq(df)) } + + test("IndexToString params") { + val idxToStr = new IndexToString() + ParamsSuite.checkParams(idxToStr) + } + + test("IndexToString.transform") { + val labels = Array("a", "b", "c") + val df0 = sqlContext.createDataFrame(Seq( + (0, "a"), (1, "b"), (2, "c"), (0, "a") + )).toDF("index", "expected") + + val idxToStr0 = new IndexToString() + .setInputCol("index") + .setOutputCol("actual") + .setLabels(labels) + idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + + val attr = 
NominalAttribute.defaultAttr.withValues(labels) + val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected")) + + val idxToStr1 = new IndexToString() + .setInputCol("indexWithAttr") + .setOutputCol("actual") + idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + } } From 17c3f3db1d6761c49a5f2b6eb6163aa35abdc119 Mon Sep 17 00:00:00 2001 From: mcheah Date: Mon, 17 Aug 2015 15:16:23 -0700 Subject: [PATCH 337/340] Using ExternalList[_] in KryoSerializer. Clean up SpillableCollection.next --- .../spark/serializer/KryoSerializer.scala | 2 +- .../util/collection/SpillableCollection.scala | 17 +++++------------ 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 1acd994cd9d86..2eab6aff045eb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -103,7 +103,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) - kryo.register(classOf[ExternalList[Any]], new ExternalList.ExternalListSerializer[Any]()) + kryo.register(classOf[ExternalList[_]], new ExternalListSerializer[Any]()) kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) diff --git a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala index 19ef8cc888666..3d2b1487cdba9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala @@ -199,7 +199,7 @@ private[spark] trait SpillableCollection[C, T <: Iterable[C]] extends Spillable[ } } - override def hasNext: Boolean = { + override def hasNext(): Boolean = { if (!nextItem.isDefined) { if (deserializeStream == null) { return false @@ -210,19 +210,12 @@ private[spark] trait SpillableCollection[C, T <: Iterable[C]] extends Spillable[ } override def next(): C = { - val item = nextItem match { - case None => readNextItem() - case Some(theItem) => nextItem - } - if (!item.isDefined) { - throw new NoSuchElementException + if (!hasNext()) { + throw new NoSuchElementException() } + val nextValue = nextItem.get nextItem = None - item match { - case Some(value) => value - // Should never get here because of the throwing above - case None => null.asInstanceOf[C] - } + nextValue } protected def readNextItemFromStream(deserializeStream: DeserializationStream): C From 3d066fc9dd36893b09cd9733ef29aa3076ab6626 Mon Sep 17 00:00:00 2001 From: mcheah Date: Mon, 17 Aug 2015 17:44:39 -0700 Subject: [PATCH 338/340] Fixing unit test --- .../util/collection/ExternalListSuite.scala | 86 ++++++++++--------- 1 file changed, 46 insertions(+), 40 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala index d0334a30945d9..86dfdb1c6c798 100644 --- 
a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala @@ -22,51 +22,36 @@ import java.lang.ref.WeakReference import scala.language.existentials import scala.reflect.ClassTag -import org.apache.spark.{SparkEnv, SparkContext, SparkConf, SparkFunSuite} -import org.apache.spark.util.collection.ExternalListSuite._ +import org.apache.spark.{SharedSparkContext, SparkEnv, SparkFunSuite, TaskContextImpl, TaskContext} import org.apache.spark.serializer.{KryoSerializer, JavaSerializer, SerializerInstance} +import org.apache.spark.util.collection.ExternalListSuite._ +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.junit.Assert.{assertEquals, assertTrue, assertFalse} +import org.mockito.Mockito.mock import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -class ExternalListSuite extends SparkFunSuite with Serializable { - - val conf = new SparkConf(false) - conf.set("spark.kryoserializer.buffer.max", "2046m") - conf.set("spark.shuffle.spill.initialMemoryThreshold", "1") - conf.set("spark.shuffle.spill.batchSize", "10") - conf.set("spark.shuffle.memoryFraction", "0.035") - conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") - conf.set("spark.task.maxFailures", "1") - conf.setMaster("local[8]") - conf.setAppName("test") +class ExternalListSuite extends SparkFunSuite with SharedSparkContext { - val sparkContext = new SparkContext(conf) + override def beforeAll() { + conf.set("spark.kryoserializer.buffer.max", "2046m") + conf.set("spark.shuffle.spill.initialMemoryThreshold", "1") + conf.set("spark.shuffle.spill.batchSize", "10") + conf.set("spark.shuffle.memoryFraction", "0.035") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + conf.set("spark.task.maxFailures", "1") + conf.setAppName("test") + super.beforeAll() + } test("Serializing and deserializing a spilled list should produce the same values") { - var serializer = new KryoSerializer(conf).newInstance() - var list = new ExternalList[Int] - // Test big list for Kryo because it's fast enough to handle it - // and we want to test the case where the list would spill to disk - for (i <- 0 to 5000000) { - list += i - } - testSerialization(serializer, list) - serializer = new JavaSerializer(conf).newInstance() - list = new ExternalList[Int] - // Test smaller list for Java serialization since serializing with Java is - // really slow, and we already test serialization causing spilling in the Kryo case - for (i <- 0 to 1000000) { - list += i - } - testSerialization(serializer, list) + testSerialization(new KryoSerializer(conf).newInstance(), 4500000) + testSerialization(new JavaSerializer(conf).newInstance(), 3000) } - val totalRddSize = 7200000 - val numBuckets = 5 - val rawLargeRdd = sparkContext.parallelize(1 to totalRddSize) test("Lists that are cached should be accessible twice, but when unpersisted are cleaned up.") { + val rawLargeRdd = sc.parallelize(1 to totalRddSize) val groupedRdd = rawLargeRdd.map(x => (x % numBuckets, x)).groupByKey val cachedRdd = groupedRdd.cache() cachedRdd.foreach(validateList(totalRddSize, numBuckets, _)) @@ -84,6 +69,7 @@ class ExternalListSuite extends SparkFunSuite with Serializable { } test("List that is created in a task and released immediately should eventually clean up") { + val rawLargeRdd = sc.parallelize(1 to totalRddSize) val filePaths = rawLargeRdd .map(x => (x % numBuckets, x)) .groupByKey @@ -93,7 +79,7 @@ 
class ExternalListSuite extends SparkFunSuite with Serializable { } private def checkFilesEventuallyRemoved(filePaths: Array[Iterable[String]]) { - eventually(timeout(15000 millis), interval(100 millis)) { + eventually(timeout(30000 millis), interval(100 millis)) { filePaths.foreach(paths => { paths.foreach(f => assertFalse(new File(f).exists())) }) @@ -114,28 +100,45 @@ class ExternalListSuite extends SparkFunSuite with Serializable { private def testSerialization[T: ClassTag]( serializer: SerializerInstance, - list: ExternalList[T]): Unit = { + numItems: Int): Unit = { + val list = new ExternalList[Int] + // Test big list for Kryo because it's fast enough to handle it + // and we want to test the case where the list would spill to disk + for (i <- 0 to numItems) { + list += i + } + createAndSetFakeTaskContext() val bytes = serializer.serialize(list) var readList = serializer.deserialize(bytes).asInstanceOf[ExternalList[Int]] val originalIt = list.iterator var readIt = readList.iterator while (originalIt.hasNext) { - assert (originalIt.next == readIt.next) + assertTrue(originalIt.next == readIt.next) } - assert (!readIt.hasNext) + assertFalse (readIt.hasNext) val filePaths = readList.getBackingFileLocations() readList = null readIt = null + taskContext.markTaskCompleted() runGC() - eventually(timeout(15000 millis), interval(100 millis)) { + eventually(timeout(30000 millis), interval(100 millis)) { filePaths.foreach(path => assertFalse(new File(path).exists())) } + TaskContext.unset() } - } object ExternalListSuite { - def validateList(totalRddSize: Int, numBuckets: Int, kv: (Int, Iterable[Int])): Unit = { + var taskContext: TaskContextImpl = null + val totalRddSize = 2000000 + val numBuckets = 5 + + private def createAndSetFakeTaskContext(): Unit = { + taskContext = new TaskContextImpl(0, 0, 0L, 0, mock(classOf[TaskMemoryManager]), SparkEnv.get.metricsSystem) + TaskContext.setTaskContext(taskContext) + } + + private def validateList(totalRddSize: Int, numBuckets: Int, kv: (Int, Iterable[Int])): Unit = { var numItems = 0 for (valsInBucket <- kv._2) { numItems += 1 @@ -145,3 +148,6 @@ object ExternalListSuite { assertEquals(s"Number of items in bucket ${kv._1} is incorrect.", totalRddSize / numBuckets, numItems) } } + + + From 4c051103b42fa6bb6904f5940aa4679c763cfad8 Mon Sep 17 00:00:00 2001 From: mcheah Date: Tue, 25 Aug 2015 15:04:02 -0700 Subject: [PATCH 339/340] Fix a whole ton of Scalastyle errors --- .../org/apache/spark/ExecutorCleaner.scala | 5 +++-- .../apache/spark/rdd/PairRDDFunctions.scala | 3 ++- .../collection/ExternalAppendOnlyMap.scala | 16 +++++++++++---- .../spark/util/collection/ExternalList.scala | 14 +++++++++---- .../SizeTrackingCompactBuffer.scala | 20 +++++++++++++++++-- .../util/collection/SpillableCollection.scala | 4 +++- .../util/collection/ExternalListSuite.scala | 19 ++++++++++++------ 7 files changed, 61 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala b/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala index 29641572b31dc..6fe705582c4d6 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala @@ -41,9 +41,10 @@ private[spark] class ExecutorCleaner extends WeakReferenceCleaner { override protected def handleCleanupForSpecificTask(task: CleanupTask): Unit = { task match { case CleanExternalList(paths) => doCleanExternalList(paths) - case unknown => logWarning(s"Got cleanup task that cannot be handled by ExecutorCleaner: 
$unknown") + case unknown => logWarning(s"Got cleanup task that cannot be" + + s" handled by ExecutorCleaner: $unknown") } } override protected def cleanupThreadName(): String = "Executor Cleaner" -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 6927f2aecec34..5e89cbd1eaefb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -469,7 +469,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) c2.foreach(c => c1 += c) c1 } - val aggregator = new Aggregator[K, V, ExternalList[V]](createCombiner, mergeValue, mergeCombiners) + val aggregator = new Aggregator[K, V, ExternalList[V]](createCombiner, + mergeValue, mergeCombiners) val shuffledRdd = if (self.partitioner != partitioner) { self.partitionBy(partitioner) } else { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 003040640ca59..3284113809dca 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -295,7 +295,8 @@ class ExternalAppendOnlyMap[K, V, C]( private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) extends DiskIterator(file, blockId, batchSizes) { - override protected def readNextItemFromStream(deserializeStream: DeserializationStream): (K, C) = { + override protected def readNextItemFromStream( + deserializeStream: DeserializationStream): (K, C) = { val k = deserializeStream.readKey().asInstanceOf[K] val v = deserializeStream.readValue().asInstanceOf[C] (k, v) @@ -308,13 +309,20 @@ class ExternalAppendOnlyMap[K, V, C]( /** Convenience function to hash the given (K, C) pair by the key. 
*/ private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) - override protected def getIteratorForCurrentSpillable(): Iterator[(K, C)] = currentMap.destructiveSortedIterator(keyComparator) + override protected def getIteratorForCurrentSpillable(): Iterator[(K, C)] = { + currentMap.destructiveSortedIterator(keyComparator) + } - override protected def writeNextObject(c: (K, C), writer: DiskBlockObjectWriter): Unit = { + override protected def writeNextObject( + c: (K, C), + writer: DiskBlockObjectWriter): Unit = { writer.write(c._1, c._2) } - override protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]): Unit = { + override protected def recordNextSpilledPart( + file: File, + blockId: BlockId, + batchSizes: ArrayBuffer[Long]): Unit = { spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala index 28467b3d87d62..4bdbb454bd5fd 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala @@ -57,9 +57,9 @@ private[spark] class ExternalList[T](implicit var tag: ClassTag[T]) context.addTaskCompletionListener(new ScheduleCleanExternalList(this)) } - override def size() = numItems + override def size(): Int = numItems - override def +=(value: T) = { + override def +=(value: T): ExternalList[T] = { currentInMemoryList += value if (maybeSpill(currentInMemoryList, currentInMemoryList.estimateSize())) { currentInMemoryList = new SizeTrackingCompactBuffer @@ -148,8 +148,14 @@ private[spark] class ExternalList[T](implicit var tag: ClassTag[T]) } } - override protected def getIteratorForCurrentSpillable(): Iterator[T] = currentInMemoryList.iterator - override protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]): Unit = { + override protected def getIteratorForCurrentSpillable(): Iterator[T] = { + currentInMemoryList.iterator + } + + override protected def recordNextSpilledPart( + file: File, + blockId: BlockId, + batchSizes: ArrayBuffer[Long]): Unit = { spilledLists += new DiskListIterable(file, blockId, batchSizes) } override protected def writeNextObject(c: T, writer: DiskBlockObjectWriter): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala index 00de7913f1491..d923e9a9e0bd1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.util.collection import scala.reflect.ClassTag @@ -8,13 +24,13 @@ import scala.reflect.ClassTag private[spark] class SizeTrackingCompactBuffer[T: ClassTag] extends CompactBuffer[T] with SizeTracker { - override def +=(t: T) = { + override def +=(t: T): SizeTrackingCompactBuffer[T] = { super.+=(t) super.afterUpdate() this } - override def ++=(t: TraversableOnce[T]) = { + override def ++=(t: TraversableOnce[T]): SizeTrackingCompactBuffer[T] = { super.++=(t) super.afterUpdate() this diff --git a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala index 9a7c34af283de..c4d0f46bb0bd7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala @@ -240,6 +240,8 @@ private object SpillableCollection { * NOTE: Setting this too low can cause excessive copying when serializing, since some serializers * grow internal data structures by growing + copying every time the number of objects doubles. */ - private def serializerBatchSize(): Long = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) + private def serializerBatchSize(): Long = + sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) + private def serializer(): Serializer = SparkEnv.get.serializer } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala index 86dfdb1c6c798..612421564ee71 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala @@ -22,7 +22,7 @@ import java.lang.ref.WeakReference import scala.language.existentials import scala.reflect.ClassTag -import org.apache.spark.{SharedSparkContext, SparkEnv, SparkFunSuite, TaskContextImpl, TaskContext} +import org.apache.spark._ import org.apache.spark.serializer.{KryoSerializer, JavaSerializer, SerializerInstance} import org.apache.spark.util.collection.ExternalListSuite._ import org.apache.spark.unsafe.memory.TaskMemoryManager @@ -58,7 +58,10 @@ class ExternalListSuite extends SparkFunSuite with SharedSparkContext { runGC() // GC on the Cached RDD shouldn't trigger the cleanup cachedRdd.foreach(validateList(totalRddSize, numBuckets, _)) - val filePaths = cachedRdd.map(_._2.asInstanceOf[ExternalList[Int]].getBackingFileLocations()).collect + def fileLocationsFromIterable(pair: (_, Iterable[Int])): Iterable[String] = { + pair._2.asInstanceOf[ExternalList[Int]].getBackingFileLocations() + } + val filePaths = cachedRdd.map(fileLocationsFromIterable).collect filePaths.foreach(paths => { paths.foreach(f => assertTrue(new File(f).exists())) }) @@ -134,7 +137,8 @@ object ExternalListSuite { val numBuckets = 5 private def createAndSetFakeTaskContext(): Unit = { - taskContext = new TaskContextImpl(0, 0, 0L, 0, mock(classOf[TaskMemoryManager]), SparkEnv.get.metricsSystem) + taskContext = new TaskContextImpl(0, 0, 0L, 0, mock(classOf[TaskMemoryManager]), + SparkEnv.get.metricsSystem, Seq.empty[Accumulator[Long]]) TaskContext.setTaskContext(taskContext) } @@ -142,10 +146,13 @@ object ExternalListSuite { var numItems = 0 for (valsInBucket <- kv._2) { numItems += 1 - // Can't use scala assertions because including 
assert statements makes closures not serializable. - assertEquals(s"Value $valsInBucket should not be in bucket ${kv._1}", kv._1, valsInBucket % numBuckets) + // Can't use scala assertions because including assert statements makes closures + // not serializable. + assertEquals(s"Value $valsInBucket should not be" + + s" in bucket ${kv._1}", kv._1, valsInBucket % numBuckets) } - assertEquals(s"Number of items in bucket ${kv._1} is incorrect.", totalRddSize / numBuckets, numItems) + assertEquals(s"Number of items in bucket ${kv._1} is incorrect.", + totalRddSize / numBuckets, numItems) } } From 8f5d5e38325bf3cec03e8d109af492a65e327a90 Mon Sep 17 00:00:00 2001 From: mcheah Date: Tue, 25 Aug 2015 16:03:50 -0700 Subject: [PATCH 340/340] Continuing to sanitize unit tests --- .../org/apache/spark/ContextCleaner.scala | 8 ++++---- .../org/apache/spark/ExecutorCleaner.scala | 7 ++++++- .../spark/util/collection/ExternalList.scala | 2 +- .../util/collection/ExternalListSuite.scala | 20 +++++-------------- 4 files changed, 16 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index b4c658b29f2af..a14a55ec352d3 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -17,13 +17,13 @@ package org.apache.spark -import org.apache.spark.util.Utils - -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} - import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} import org.apache.spark.util.Utils +import org.apache.spark.util.cleanup.{ CleanAccum, CleanBroadcast, CleanCheckpoint } +import org.apache.spark.util.cleanup.{ CleanRDD, CleanShuffle, CleanupTask } + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} /** * An asynchronous cleaner for RDD, shuffle, and broadcast state. 
diff --git a/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala b/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala index 6fe705582c4d6..716f0906e9fc3 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala @@ -34,7 +34,12 @@ private[spark] class ExecutorCleaner extends WeakReferenceCleaner { def doCleanExternalList(paths: Iterable[String]): Unit = { paths.map(path => new File(path)).foreach(f => { - if (f.exists()) f.delete() + if (f.exists()) { + val isDeleted = f.delete() + if (!isDeleted) { + logWarning(s"Failed to delete ${f.getAbsolutePath} backing ExternalList") + } + } }) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala index 4bdbb454bd5fd..f0e4fcff81420 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala @@ -59,7 +59,7 @@ private[spark] class ExternalList[T](implicit var tag: ClassTag[T]) override def size(): Int = numItems - override def +=(value: T): ExternalList[T] = { + override def +=(value: T): this.type = { currentInMemoryList += value if (maybeSpill(currentInMemoryList, currentInMemoryList.estimateSize())) { currentInMemoryList = new SizeTrackingCompactBuffer diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala index 612421564ee71..9a7cccacb234b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala @@ -37,8 +37,8 @@ class ExternalListSuite extends SparkFunSuite with SharedSparkContext { override def beforeAll() { conf.set("spark.kryoserializer.buffer.max", "2046m") conf.set("spark.shuffle.spill.initialMemoryThreshold", "1") - conf.set("spark.shuffle.spill.batchSize", "10") - conf.set("spark.shuffle.memoryFraction", "0.035") + conf.set("spark.shuffle.spill.batchSize", "500") + conf.set("spark.shuffle.memoryFraction", "0.04") conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") conf.set("spark.task.maxFailures", "1") conf.setAppName("test") @@ -71,18 +71,8 @@ class ExternalListSuite extends SparkFunSuite with SharedSparkContext { cachedRdd.foreach(validateList(totalRddSize, numBuckets, _)) } - test("List that is created in a task and released immediately should eventually clean up") { - val rawLargeRdd = sc.parallelize(1 to totalRddSize) - val filePaths = rawLargeRdd - .map(x => (x % numBuckets, x)) - .groupByKey - .map(x => x._2.asInstanceOf[ExternalList[Int]].getBackingFileLocations()).collect - runGC() - checkFilesEventuallyRemoved(filePaths) - } - private def checkFilesEventuallyRemoved(filePaths: Array[Iterable[String]]) { - eventually(timeout(30000 millis), interval(100 millis)) { + eventually(timeout(40000 millis), interval(100 millis)) { filePaths.foreach(paths => { paths.foreach(f => assertFalse(new File(f).exists())) }) @@ -124,7 +114,7 @@ class ExternalListSuite extends SparkFunSuite with SharedSparkContext { readIt = null taskContext.markTaskCompleted() runGC() - eventually(timeout(30000 millis), interval(100 millis)) { + eventually(timeout(40000 millis), interval(100 millis)) { filePaths.foreach(path => assertFalse(new File(path).exists())) } TaskContext.unset() @@ -149,7 +139,7 @@ object 
ExternalListSuite { // Can't use scala assertions because including assert statements makes closures // not serializable. assertEquals(s"Value $valsInBucket should not be" + - s" in bucket ${kv._1}", kv._1, valsInBucket % numBuckets) + s" in bucket ${kv._1}", valsInBucket % numBuckets, kv._1) } assertEquals(s"Number of items in bucket ${kv._1} is incorrect.", totalRddSize / numBuckets, numItems)
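(Editor's sketch of the file-cleanup behaviour patch 340 settles on for ExecutorCleaner, assuming only java.io.File semantics; the standalone signature and the logWarning parameter, standing in for Spark's Logging trait, are illustrative.)

import java.io.File

// Delete the spill files backing an ExternalList. File.delete() signals failure by
// returning false rather than throwing, so the result is checked and logged.
def doCleanExternalList(paths: Iterable[String], logWarning: String => Unit): Unit = {
  paths.map(new File(_)).foreach { f =>
    if (f.exists()) {
      val isDeleted = f.delete()
      if (!isDeleted) {
        logWarning(s"Failed to delete ${f.getAbsolutePath} backing ExternalList")
      }
    }
  }
}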