diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 4fadfe36cd716..7fdcf22c45f73 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -879,6 +879,7 @@ public LongArray getArray() { * Reset this map to initialized state. */ public void reset() { + updatePeakMemoryUsed(); numKeys = 0; numValues = 0; freeArray(longArray); diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 67e993c7f02e2..7aecd3c9668ea 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -99,7 +99,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private def calcChecksum(block: ByteBuffer): Int = { val adler = new Adler32() if (block.hasArray) { - adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position) + adler.update(block.array, block.arrayOffset + block.position(), block.limit() + - block.position()) } else { val bytes = new Array[Byte](block.remaining()) block.duplicate.get(bytes) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 4c1f92a1bcbf2..9b62e4b1b7150 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -165,11 +165,7 @@ private[spark] class CoarseGrainedExecutorBackend( } if (notifyDriver && driver.nonEmpty) { - driver.get.ask[Boolean]( - RemoveExecutor(executorId, new ExecutorLossReason(reason)) - ).failed.foreach(e => - logWarning(s"Unable to notify the driver due to " + e.getMessage, e) - )(ThreadUtils.sameThread) + driver.get.send(RemoveExecutor(executorId, new ExecutorLossReason(reason))) } System.exit(code) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e3e555eaa0277..af0a0ab656564 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -452,7 +452,7 @@ private[spark] class Executor( // TODO: do not serialize value twice val directResult = new DirectTaskResult(valueBytes, accumUpdates) val serializedDirectResult = ser.serialize(directResult) - val resultSize = serializedDirectResult.limit + val resultSize = serializedDirectResult.limit() // directSend = sending directly back to the driver val serializedResult: ByteBuffer = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index 8f4c1b60920db..b0cd7110a3b47 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -235,7 +235,9 @@ private[spark] case class ConfigBuilder(key: String) { } def fallbackConf[T](fallback: ConfigEntry[T]): ConfigEntry[T] = { - new FallbackConfigEntry(key, _alternatives, _doc, _public, fallback) + val entry = new FallbackConfigEntry(key, _alternatives, _doc, _public, fallback) + 
_onCreate.foreach(_(entry)) + entry } def regexConf: TypedConfigBuilder[Regex] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7bfb4d53c1834..4d75063fbf1c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -95,6 +95,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // The num of current max ExecutorId used to re-register appMaster @volatile protected var currentExecutorIdCounter = 0 + private val reviveThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -103,9 +106,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp protected val addressToExecutorId = new HashMap[RpcAddress, String] - private val reviveThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") - override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") @@ -154,6 +154,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorDataMap.values.foreach { ed => ed.executorEndpoint.send(UpdateDelegationTokens(newDelegationTokens)) } + + case RemoveExecutor(executorId, reason) => + // We will remove the executor's state and cannot restore it. However, the connection + // between the driver and the executor may be still alive so that the executor won't exit + // automatically, so try to tell the executor to stop itself. See SPARK-13519. + executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) + removeExecutor(executorId, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -215,14 +222,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } context.reply(true) - case RemoveExecutor(executorId, reason) => - // We will remove the executor's state and cannot restore it. However, the connection - // between the driver and the executor may be still alive so that the executor won't exit - // automatically, so try to tell the executor to stop itself. See SPARK-13519. - executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) - removeExecutor(executorId, reason) - context.reply(true) - case RemoveWorker(workerId, host, message) => removeWorker(workerId, host, message) context.reply(true) @@ -288,13 +287,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { val serializedTask = TaskDescription.encode(task) - if (serializedTask.limit >= maxRpcMessageSize) { + if (serializedTask.limit() >= maxRpcMessageSize) { scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.rpc.message.maxSize (%d bytes). Consider increasing " + "spark.rpc.message.maxSize or using broadcast variables for large values." 
- msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize) + msg = msg.format(task.taskId, task.index, serializedTask.limit(), maxRpcMessageSize) taskSetMgr.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) @@ -373,10 +372,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp shouldDisable } - - override def onStop() { - reviveThread.shutdownNow() - } } var driverEndpoint: RpcEndpointRef = null @@ -417,6 +412,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def stop() { + reviveThread.shutdownNow() stopExecutors() try { if (driverEndpoint != null) { @@ -465,9 +461,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * at once. */ protected def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { - // Only log the failure since we don't care about the result. - driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).failed.foreach(t => - logError(t.getMessage, t))(ThreadUtils.sameThread) + driverEndpoint.send(RemoveExecutor(executorId, reason)) } protected def removeWorker(workerId: String, host: String, message: String): Unit = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 56d0266b8edad..89a6a71a589a1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -17,6 +17,7 @@ package org.apache.spark.storage +import java.io.IOException import java.util.{HashMap => JHashMap} import scala.collection.JavaConverters._ @@ -159,11 +160,16 @@ class BlockManagerMasterEndpoint( // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. // The dispatcher is used as an implicit argument into the Future sequence construction. 
val removeMsg = RemoveRdd(rddId) - Future.sequence( - blockManagerInfo.values.map { bm => - bm.slaveEndpoint.ask[Int](removeMsg) - }.toSeq - ) + + val futures = blockManagerInfo.values.map { bm => + bm.slaveEndpoint.ask[Int](removeMsg).recover { + case e: IOException => + logWarning(s"Error trying to remove RDD $rddId", e) + 0 // zero blocks were removed + } + }.toSeq + + Future.sequence(futures) } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala index a938cb07724c7..a5ee0ff16b5df 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -54,7 +54,7 @@ class ByteBufferInputStream(private var buffer: ByteBuffer) override def skip(bytes: Long): Long = { if (buffer != null) { val amountToSkip = math.min(bytes, buffer.remaining).toInt - buffer.position(buffer.position + amountToSkip) + buffer.position(buffer.position() + amountToSkip) if (buffer.remaining() == 0) { cleanUp() } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index c28570fb24560..7367af7888bd8 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -65,7 +65,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { for (bytes <- getChunks()) { while (bytes.remaining() > 0) { val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position + ioSize) + bytes.limit(bytes.position() + ioSize) channel.write(bytes) } } @@ -206,7 +206,7 @@ private[spark] class ChunkedByteBufferInputStream( override def skip(bytes: Long): Long = { if (currentChunk != null) { val amountToSkip = math.min(bytes, currentChunk.remaining).toInt - currentChunk.position(currentChunk.position + amountToSkip) + currentChunk.position(currentChunk.position() + amountToSkip) if (currentChunk.remaining() == 0) { if (chunks.hasNext) { currentChunk = chunks.next() diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index bf08276dbf971..02514dc7daef4 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -288,4 +288,24 @@ class ConfigEntrySuite extends SparkFunSuite { conf.remove(testKey("b")) assert(conf.get(iConf) === 3) } + + test("onCreate") { + var onCreateCalled = false + ConfigBuilder(testKey("oc1")).onCreate(_ => onCreateCalled = true).intConf.createWithDefault(1) + assert(onCreateCalled) + + onCreateCalled = false + ConfigBuilder(testKey("oc2")).onCreate(_ => onCreateCalled = true).intConf.createOptional + assert(onCreateCalled) + + onCreateCalled = false + ConfigBuilder(testKey("oc3")).onCreate(_ => onCreateCalled = true).intConf + .createWithDefaultString("1.0") + assert(onCreateCalled) + + val fallback = ConfigBuilder(testKey("oc4")).intConf.createWithDefault(1) + onCreateCalled = false + ConfigBuilder(testKey("oc5")).onCreate(_ => onCreateCalled = true).fallbackConf(fallback) + assert(onCreateCalled) + } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala 
b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index eaec098b8d785..fc78655bf52ec 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -199,7 +199,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) // Check that very long ranges don't get written one element at a time - assert(ser.serialize(t).limit < 100) + assert(ser.serialize(t).limit() < 100) } check(1 to 1000000) check(1 to 1000000 by 2) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 7258fdf5efc0d..efdd02fff7871 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -118,7 +118,7 @@ class DiskStoreSuite extends SparkFunSuite { val chunks = chunkedByteBuffer.chunks assert(chunks.size === 2) for (chunk <- chunks) { - assert(chunk.limit === 10 * 1024) + assert(chunk.limit() === 10 * 1024) } val e = intercept[IllegalArgumentException]{ diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 2c68b73095c4d..1831f3378e852 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -35,7 +35,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.7.jar +commons-compiler-3.0.8.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -96,7 +96,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.7.jar +janino-3.0.8.jar java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar @@ -180,7 +180,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.4.jar +univocity-parsers-2.5.9.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 2aaac600b3ec3..fe14c05987327 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -35,7 +35,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.7.jar +commons-compiler-3.0.8.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -96,7 +96,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.7.jar +janino-3.0.8.jar java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar @@ -181,7 +181,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.4.jar +univocity-parsers-2.5.9.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 2df8352f48660..08db4d827e400 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ 
b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -296,7 +296,9 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props.put("replica.socket.timeout.ms", "1500") props.put("delete.topic.enable", "true") props.put("offsets.topic.num.partitions", "1") - props.putAll(withBrokerProps.asJava) + // Can not use properties.putAll(propsMap.asJava) in scala-2.12 + // See https://github.com/scala/bug/issues/10418 + withBrokerProps.foreach { case (k, v) => props.put(k, v) } props } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 4663f16b5f5dc..730ee9fc08db8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.HasInputCols +import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ /** * Params for [[Imputer]] and [[ImputerModel]]. */ -private[feature] trait ImputerParams extends Params with HasInputCols { +private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCols { /** * The imputation strategy. Currently only "mean" and "median" are supported. @@ -63,16 +63,6 @@ private[feature] trait ImputerParams extends Params with HasInputCols { /** @group getParam */ def getMissingValue: Double = $(missingValue) - /** - * Param for output column names. - * @group param - */ - final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", - "output column names") - - /** @group getParam */ - final def getOutputCols: Array[String] = $(outputCols) - /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" + diff --git a/pom.xml b/pom.xml index 07bca9d267da0..52db79eaf036b 100644 --- a/pom.xml +++ b/pom.xml @@ -170,7 +170,7 @@ 3.5 3.2.10 - 3.0.7 + 3.0.8 2.22.2 2.9.3 3.5.2 diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 3febb2f47cfd4..13c09033a50ee 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -282,7 +282,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn // No more deletion attempts of the executors. // This is graceful termination and should not be detected as a failure. 
verify(podOperations, times(1)).delete(resolvedPod) - verify(driverEndpointRef, times(1)).ask[Boolean]( + verify(driverEndpointRef, times(1)).send( RemoveExecutor("1", ExecutorExited( 0, exitCausedByApp = false, @@ -318,7 +318,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn requestExecutorRunnable.getValue.run() allocatorRunnable.getAllValues.asScala.last.run() verify(podOperations, never()).delete(firstResolvedPod) - verify(driverEndpointRef).ask[Boolean]( + verify(driverEndpointRef).send( RemoveExecutor("1", ExecutorExited( 1, exitCausedByApp = true, @@ -356,7 +356,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) allocatorRunnable.getValue.run() verify(podOperations).delete(firstResolvedPod) - verify(driverEndpointRef).ask[Boolean]( + verify(driverEndpointRef).send( RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 415a29fd887e8..bb615c36cd97f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.atomic.{AtomicBoolean} import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.ExecutionContext.Implicits.global import scala.util.{Failure, Success} import scala.util.control.NonFatal @@ -245,14 +246,7 @@ private[spark] abstract class YarnSchedulerBackend( Future.successful(RemoveExecutor(executorId, SlaveLost("AM is not yet registered."))) } - removeExecutorMessage - .flatMap { message => - driverEndpoint.ask[Boolean](message) - }(ThreadUtils.sameThread) - .onFailure { - case NonFatal(e) => logError( - s"Error requesting driver to remove executor $executorId after disconnection.", e) - }(ThreadUtils.sameThread) + removeExecutorMessage.foreach { message => driverEndpoint.send(message) } } override def receive: PartialFunction[Any, Unit] = { @@ -265,12 +259,10 @@ private[spark] abstract class YarnSchedulerBackend( addWebUIFilter(filterName, filterParams, proxyBase) case r @ RemoveExecutor(executorId, reason) => - logWarning(reason.toString) - driverEndpoint.ask[Boolean](r).onFailure { - case e => - logError("Error requesting driver to remove executor" + - s" $executorId for reason $reason", e) - }(ThreadUtils.sameThread) + if (!stopped.get) { + logWarning(s"Requesting driver to remove executor $executorId for reason $reason") + driverEndpoint.send(r) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 179853032035e..4d26d9819321b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -982,35 +982,28 @@ case class ScalaUDF( // scalastyle:on line.size.limit - // Generate codes used to convert the arguments to Scala type for user-defined functions - private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = { - val converterClassName = 
classOf[Any => Any].getName - val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - val expressionClassName = classOf[Expression].getName - val scalaUDFClassName = classOf[ScalaUDF].getName + private val converterClassName = classOf[Any => Any].getName + private val scalaUDFClassName = classOf[ScalaUDF].getName + private val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + // Generate codes used to convert the arguments to Scala type for user-defined functions + private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): (String, String) = { val converterTerm = ctx.freshName("converter") val expressionIdx = ctx.references.size - 1 - ctx.addMutableState(converterClassName, converterTerm, - s"$converterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + - s"references[$expressionIdx]).getChildren().apply($index))).dataType());") - converterTerm + (converterTerm, + s"$converterClassName $converterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToScalaConverter(((Expression)((($scalaUDFClassName)" + + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") } override def doGenCode( ctx: CodegenContext, ev: ExprCode): ExprCode = { + val scalaUDF = ctx.freshName("scalaUDF") + val scalaUDFRef = ctx.addReferenceMinorObj(this, scalaUDFClassName) - val scalaUDF = ctx.addReferenceObj("scalaUDF", this) - val converterClassName = classOf[Any => Any].getName - val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - - // Generate codes used to convert the returned value of user-defined functions to Catalyst type + // Object to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") - ctx.addMutableState(converterClassName, catalystConverterTerm, - s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1022,8 +1015,6 @@ case class ScalaUDF( val funcClassName = s"scala.Function${children.size}" val funcTerm = ctx.freshName("udf") - ctx.addMutableState(funcClassName, funcTerm, - s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) @@ -1033,34 +1024,45 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => - val eval = evals(i) - val argTerm = ctx.freshName("arg") - val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});" - (convert, argTerm) + val (converters, funcArguments) = converterTerms.zipWithIndex.map { + case ((convName, convInit), i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = + s""" + |$convInit + |Object $argTerm = ${eval.isNull} ? 
null : $convName.apply(${eval.value}); + """.stripMargin + (convert, argTerm) }.unzip val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})" val callFunc = s""" - ${ctx.boxedType(dataType)} $resultTerm = null; - try { - $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); - } catch (Exception e) { - throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); - } - """ + |${ctx.boxedType(dataType)} $resultTerm = null; + |$scalaUDFClassName $scalaUDF = $scalaUDFRef; + |try { + | $funcClassName $funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc(); + | $converterClassName $catalystConverterTerm = ($converterClassName) + | $typeConvertersClassName.createToCatalystConverter($scalaUDF.dataType()); + | $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); + |} catch (Exception e) { + | throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); + |} + """.stripMargin - ev.copy(code = s""" - $evalCode - ${converters.mkString("\n")} - $callFunc - - boolean ${ev.isNull} = $resultTerm == null; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $resultTerm; - }""") + ev.copy(code = + s""" + |$evalCode + |${converters.mkString("\n")} + |$callFunc + | + |boolean ${ev.isNull} = $resultTerm == null; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |if (!${ev.isNull}) { + | ${ev.value} = $resultTerm; + |} + """.stripMargin) } private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 739bd13c5078d..1893eec22b65d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,23 +602,38 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) - def updateEval(eval: ExprCode): String = { + val tmpIsNull = ctx.freshName("leastTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val evals = evalChildren.map(eval => s""" - ${eval.code} - if (!${eval.isNull} && (${ev.isNull} || - ${ctx.genGreater(dataType, ev.value, eval.value)})) { - ${ev.isNull} = false; - ${ev.value} = ${eval.value}; - } - """ - } - val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval)) - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - $codes""") + |${eval.code} + |if (!${eval.isNull} && ($tmpIsNull || + | ${ctx.genGreater(dataType, ev.value, eval.value)})) { + | $tmpIsNull = false; + | ${ev.value} = ${eval.value}; + |} + """.stripMargin + ) + + val resultType = ctx.javaType(dataType) + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = evals, + funcName = "least", + extraArguments = Seq(resultType -> ev.value), + returnType = resultType, + makeSplitFunction = body => + s""" + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + ev.copy(code = + s""" + |$tmpIsNull = true; + 
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$codes + |final boolean ${ev.isNull} = $tmpIsNull; + """.stripMargin) } } @@ -668,22 +683,37 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) - def updateEval(eval: ExprCode): String = { + val tmpIsNull = ctx.freshName("greatestTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val evals = evalChildren.map(eval => s""" - ${eval.code} - if (!${eval.isNull} && (${ev.isNull} || - ${ctx.genGreater(dataType, eval.value, ev.value)})) { - ${ev.isNull} = false; - ${ev.value} = ${eval.value}; - } - """ - } - val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval)) - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - $codes""") + |${eval.code} + |if (!${eval.isNull} && ($tmpIsNull || + | ${ctx.genGreater(dataType, eval.value, ev.value)})) { + | $tmpIsNull = false; + | ${ev.value} = ${eval.value}; + |} + """.stripMargin + ) + + val resultType = ctx.javaType(dataType) + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = evals, + funcName = "greatest", + extraArguments = Seq(resultType -> ev.value), + returnType = resultType, + makeSplitFunction = body => + s""" + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + ev.copy(code = + s""" + |$tmpIsNull = true; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$codes + |final boolean ${ev.isNull} = $tmpIsNull; + """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 670c82eff9286..5c9e604a8d293 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -29,7 +29,7 @@ import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} import org.codehaus.commons.compiler.CompileException -import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler} +import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, InternalCompilerException, SimpleCompiler} import org.codehaus.janino.util.ClassFile import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException} @@ -1240,12 +1240,12 @@ object CodeGenerator extends Logging { evaluator.cook("generated.java", code.body) updateAndGetCompilationStats(evaluator) } catch { - case e: JaninoRuntimeException => + case e: InternalCompilerException => val msg = s"failed to compile: $e" logError(msg, e) val maxLines = SQLConf.get.loggingMaxLinesForCodegen logInfo(s"\n${CodeFormatter.format(code, maxLines)}") - throw new JaninoRuntimeException(msg, e) + throw new InternalCompilerException(msg, e) case e: CompileException => val msg = s"failed to compile: $e" logError(msg, e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 44e7148e5d98f..3dcbb518ba42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -49,8 +49,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val tmpInput = ctx.freshName("tmpInput") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") - // These expressions could be split into multiple functions - ctx.addMutableState("Object[]", values, s"$values = null;") val rowClass = classOf[GenericInternalRow].getName @@ -66,15 +64,15 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val allFields = ctx.splitExpressions( expressions = fieldWriters, funcName = "writeFields", - arguments = Seq("InternalRow" -> tmpInput) + arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values) ) - val code = s""" - final InternalRow $tmpInput = $input; - $values = new Object[${schema.length}]; - $allFields - final InternalRow $output = new $rowClass($values); - $values = null; - """ + val code = + s""" + |final InternalRow $tmpInput = $input; + |final Object[] $values = new Object[${schema.length}]; + |$allFields + |final InternalRow $output = new $rowClass($values); + """.stripMargin ExprCode(code, "false", output) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 087b21043b309..3dc2ee03a86e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -356,22 +356,25 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"$values = null;") + val valCodes = valExprs.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + s""" + |${eval.code} + |if (${eval.isNull}) { + | $values[$i] = null; + |} else { + | $values[$i] = ${eval.value}; + |} + """.stripMargin + } val valuesCode = ctx.splitExpressionsWithCurrentInputs( - valExprs.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - s""" - ${eval.code} - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - }""" - }) + expressions = valCodes, + funcName = "createNamedStruct", + extraArguments = "Object[]" -> values :: Nil) ev.copy(code = s""" - |$values = new Object[${valExprs.size}]; + |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); |$values = null; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ae5f7140847db..53c3b226895ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -180,13 
+180,18 @@ case class CaseWhen( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // This variable represents whether the first successful condition is met or not. - // It is initialized to `false` and it is set to `true` when the first condition which - // evaluates to `true` is met and therefore is not needed to go on anymore on the computation - // of the following conditions. - val conditionMet = ctx.freshName("caseWhenConditionMet") - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) + // This variable holds the state of the result: + // -1 means the condition is not met yet and the result is unknown. + val NOT_MATCHED = -1 + // 0 means the condition is met and result is not null. + val HAS_NONNULL = 0 + // 1 means the condition is met and result is null. + val HAS_NULL = 1 + // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, + // we won't continue evaluating the remaining conditions. + val resultState = ctx.freshName("caseWhenResultState") + val tmpResult = ctx.freshName("caseWhenTmpResult") + ctx.addMutableState(ctx.javaType(dataType), tmpResult) // these blocks are meant to be inside a // do { @@ -200,9 +205,8 @@ case class CaseWhen( |${cond.code} |if (!${cond.isNull} && ${cond.value}) { | ${res.code} - | ${ev.isNull} = ${res.isNull}; - | ${ev.value} = ${res.value}; - | $conditionMet = true; + | $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); + | $tmpResult = ${res.value}; | continue; |} """.stripMargin @@ -212,59 +216,63 @@ case class CaseWhen( val res = elseExpr.genCode(ctx) s""" |${res.code} - |${ev.isNull} = ${res.isNull}; - |${ev.value} = ${res.value}; + |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); + |$tmpResult = ${res.value}; """.stripMargin } val allConditions = cases ++ elseCode // This generates code like: - // conditionMet = caseWhen_1(i); - // if(conditionMet) { + // caseWhenResultState = caseWhen_1(i); + // if(caseWhenResultState != -1) { // continue; // } - // conditionMet = caseWhen_2(i); - // if(conditionMet) { + // caseWhenResultState = caseWhen_2(i); + // if(caseWhenResultState != -1) { // continue; // } // ... // and the declared methods are: - // private boolean caseWhen_1234() { - // boolean conditionMet = false; + // private byte caseWhen_1234() { - // byte caseWhenResultState = -1; // do { // // here the evaluation of the conditions // } while (false); - // return conditionMet; + // return caseWhenResultState; // } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, funcName = "caseWhen", - returnType = ctx.JAVA_BOOLEAN, + returnType = ctx.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BOOLEAN} $conditionMet = false; + |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $func |} while (false); - |return $conditionMet; + |return $resultState; """.stripMargin, foldFunctions = _.map { funcCall => s""" - |$conditionMet = $funcCall; - |if ($conditionMet) { + |$resultState = $funcCall; + |if ($resultState != $NOT_MATCHED) { | continue; |} """.stripMargin }.mkString) - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ctx.JAVA_BOOLEAN} $conditionMet = false; - do { - $codes - } while (false);""") + ev.copy(code = + s""" + |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |$tmpResult = ${ctx.defaultValue(dataType)}; + |do { + | $codes + |} while (false); + |// TRUE if any condition is met and the result is null, or no condition is met.
+ |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL); + |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult; + """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 26c9a41efc9f9..294cdcb2e9546 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) + val tmpIsNull = ctx.freshName("coalesceTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -81,26 +81,30 @@ case class Coalesce(children: Seq[Expression]) extends Expression { s""" |${eval.code} |if (!${eval.isNull}) { - | ${ev.isNull} = false; + | $tmpIsNull = false; | ${ev.value} = ${eval.value}; | continue; |} """.stripMargin } + val resultType = ctx.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", + returnType = resultType, makeSplitFunction = func => s""" + |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $func |} while (false); + |return ${ev.value}; """.stripMargin, foldFunctions = _.map { funcCall => s""" - |$funcCall; - |if (!${ev.isNull}) { + |${ev.value} = $funcCall; + |if (!$tmpIsNull) { | continue; |} """.stripMargin @@ -109,11 +113,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" - |${ev.isNull} = true; - |${ev.value} = ${ctx.defaultValue(dataType)}; + |$tmpIsNull = true; + |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); + |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 730b2ff96da6c..349afece84d5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1106,27 +1106,31 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values) val childrenCodes = children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - } - """ + s""" + |${eval.code} + |if (${eval.isNull}) { + | $values[$i] = null; + |} else { + | $values[$i] = ${eval.value}; + |} + """.stripMargin } - val childrenCode = ctx.splitExpressionsWithCurrentInputs(childrenCodes) - val schemaField = ctx.addReferenceObj("schema", schema) + val childrenCode = ctx.splitExpressionsWithCurrentInputs( + expressions = childrenCodes, + funcName = "createExternalRow", + extraArguments = "Object[]" -> values :: Nil) + val 
schemaField = ctx.addReferenceMinorObj(schema) - val code = s""" - $values = new Object[${children.size}]; - $childrenCode - final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); - """ + val code = + s""" + |Object[] $values = new Object[${children.size}]; + |$childrenCode + |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); + """.stripMargin ev.copy(code = code, isNull = "false") } } @@ -1244,25 +1248,28 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val javaBeanInstance = ctx.freshName("javaBean") val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) - ctx.addMutableState(beanInstanceJavaType, javaBeanInstance) val initialize = setters.map { case (setterMethod, fieldValue) => val fieldGen = fieldValue.genCode(ctx) s""" - ${fieldGen.code} - ${javaBeanInstance}.$setterMethod(${fieldGen.value}); - """ + |${fieldGen.code} + |$javaBeanInstance.$setterMethod(${fieldGen.value}); + """.stripMargin } - val initializeCode = ctx.splitExpressionsWithCurrentInputs(initialize.toSeq) + val initializeCode = ctx.splitExpressionsWithCurrentInputs( + expressions = initialize.toSeq, + funcName = "initializeJavaBean", + extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) - val code = s""" - ${instanceGen.code} - ${javaBeanInstance} = ${instanceGen.value}; - if (!${instanceGen.isNull}) { - $initializeCode - } - """ + val code = + s""" + |${instanceGen.code} + |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; + |if (!${instanceGen.isNull}) { + | $initializeCode + |} + """.stripMargin ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 04e669492ec6d..8eb41addaf689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -237,8 +237,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val javaDataType = ctx.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + // inTmpResult has 3 possible values: + // -1 means no matches found and there is at least one value in the list evaluated to null + val HAS_NULL = -1 + // 0 means no matches found and all values in the list are not null + val NOT_MATCHED = 0 + // 1 means one value in the list is matched + val MATCHED = 1 + val tmpResult = ctx.freshName("inTmpResult") val valueArg = ctx.freshName("valueArg") // All the blocks are meant to be inside a do { ... } while (false); loop. // The evaluation of variables can be stopped when we find a matching value. 
@@ -246,10 +252,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { s""" |${x.code} |if (${x.isNull}) { - | ${ev.isNull} = true; + | $tmpResult = $HAS_NULL; // ${ev.isNull} = true; |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - | ${ev.isNull} = false; - | ${ev.value} = true; + | $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true; | continue; |} """.stripMargin) @@ -257,17 +262,19 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: Nil, + extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil, + returnType = ctx.JAVA_BYTE, makeSplitFunction = body => s""" |do { | $body |} while (false); + |return $tmpResult; """.stripMargin, foldFunctions = _.map { funcCall => s""" - |$funcCall; - |if (${ev.value}) { + |$tmpResult = $funcCall; + |if ($tmpResult == $MATCHED) { | continue; |} """.stripMargin @@ -276,14 +283,16 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { ev.copy(code = s""" |${valueGen.code} - |${ev.value} = false; - |${ev.isNull} = ${valueGen.isNull}; - |if (!${ev.isNull}) { + |byte $tmpResult = $HAS_NULL; + |if (!${valueGen.isNull}) { + | $tmpResult = 0; | $javaDataType $valueArg = ${valueGen.value}; | do { | $codes | } while (false); |} + |final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL); + |final boolean ${ev.value} = ($tmpResult == $MATCHED); """.stripMargin) } @@ -344,17 +353,17 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } else { "" } - ctx.addMutableState(setName, setTerm, - s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();") - ev.copy(code = s""" - ${childGen.code} - boolean ${ev.isNull} = ${childGen.isNull}; - boolean ${ev.value} = false; - if (!${ev.isNull}) { - ${ev.value} = $setTerm.contains(${childGen.value}); - $setNull - } - """) + ev.copy(code = + s""" + |${childGen.code} + |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${ctx.JAVA_BOOLEAN} ${ev.value} = false; + |if (!${ev.isNull}) { + | $setName $setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet(); + | ${ev.value} = $setTerm.contains(${childGen.value}); + | $setNull + |} + """.stripMargin) } override def sql: String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 785e815b41185..6305b6c84bae3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -64,49 +64,89 @@ object ConstantFolding extends Rule[LogicalPlan] { * }}} * * Approach used: - * - Start from AND operator as the root - * - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they - * don't have a `NOT` or `OR` operator in them * - Populate a mapping of attribute => constant value by looking at all the equals predicates * - Using this mapping, replace occurrence of the attributes with the corresponding constant values * in the AND node. 
*/ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { - private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find { - case _: Not | _: Or => true - case _ => false - }.isDefined - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f: Filter => f transformExpressionsUp { - case and: And => - val conjunctivePredicates = - splitConjunctivePredicates(and) - .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe]) - .filterNot(expr => containsNonConjunctionPredicates(expr)) - - val equalityPredicates = conjunctivePredicates.collect { - case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e) - case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e) - case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e) - case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e) - } + case f: Filter => + val (newCondition, _) = traverse(f.condition, replaceChildren = true) + if (newCondition.isDefined) { + f.copy(condition = newCondition.get) + } else { + f + } + } - val constantsMap = AttributeMap(equalityPredicates.map(_._1)) - val predicates = equalityPredicates.map(_._2).toSet + type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] - def replaceConstants(expression: Expression) = expression transform { - case a: AttributeReference => - constantsMap.get(a) match { - case Some(literal) => literal - case None => a - } + /** + * Traverse a condition as a tree and replace attributes with constant values. + * - On matching [[And]], recursively traverse each child and get the propagated mappings. + * If the current node is not a child of another [[And]], replace all occurrences of the + * attributes with the corresponding constant values. + * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping + * of attribute => constant. + * - On matching [[Or]] or [[Not]], recursively traverse each child, propagating an empty mapping. + * - Otherwise, stop traversal and propagate an empty mapping. + * @param condition condition to be traversed + * @param replaceChildren whether to replace attributes with constant values in children + * @return A tuple including: + * 1. Option[Expression]: optional changed condition after traversal + * 2.
EqualityPredicates: propagated mapping of attribute => constant + */ + private def traverse(condition: Expression, replaceChildren: Boolean) + : (Option[Expression], EqualityPredicates) = + condition match { + case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e))) + case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e))) + case e @ EqualNullSafe(left: AttributeReference, right: Literal) => + (None, Seq(((left, right), e))) + case e @ EqualNullSafe(left: Literal, right: AttributeReference) => + (None, Seq(((right, left), e))) + case a: And => + val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false) + val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false) + val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight + val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) { + Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates), + replaceConstants(newRight.getOrElse(a.right), equalityPredicates))) + } else { + if (newLeft.isDefined || newRight.isDefined) { + Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right))) + } else { + None + } } - - and transform { - case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e) - case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e) + (newSelf, equalityPredicates) + case o: Or => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newLeft, _) = traverse(o.left, replaceChildren = true) + val (newRight, _) = traverse(o.right, replaceChildren = true) + val newSelf = if (newLeft.isDefined || newRight.isDefined) { + Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right)))) + } else { + None } + (newSelf, Seq.empty) + case n: Not => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newChild, _) = traverse(n.child, replaceChildren = true) + (newChild.map(Not), Seq.empty) + case _ => (None, Seq.empty) + } + + private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) + : Expression = { + val constantsMap = AttributeMap(equalityPredicates.map(_._1)) + val predicates = equalityPredicates.map(_._2).toSet + def replaceConstants0(expression: Expression) = expression transform { + case a: AttributeReference => constantsMap.getOrElse(a, a) + } + condition transform { + case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) + case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala index 06196b5afb031..7a927e1e083b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -38,7 +38,7 @@ object EventTimeWatermark { case class EventTimeWatermark( eventTime: Attribute, delay: CalendarInterval, - child: LogicalPlan) extends LogicalPlan { + child: LogicalPlan) extends UnaryNode { // Update the metadata on the eventTime column to include the desired delay. 
override val output: Seq[Attribute] = child.output.map { a => @@ -60,6 +60,4 @@ case class EventTimeWatermark( a } } - - override val children: Seq[LogicalPlan] = child :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index fb759eba6a9e2..be638d80e45d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -343,4 +344,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Least(inputsExpr), "s" * 1, EmptyRow) checkEvaluation(Greatest(inputsExpr), "s" * N, EmptyRow) } + + test("SPARK-22704: Least and greatest use less global variables") { + val ctx1 = new CodegenContext() + Least(Seq(Literal(1), Literal(1))).genCode(ctx1) + assert(ctx1.mutableStates.size == 1) + + val ctx2 = new CodegenContext() + Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2) + assert(ctx2.mutableStates.size == 1) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index a4198f826cedb..40bf29bb3b573 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType} +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -380,4 +380,18 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd") } } + + test("SPARK-22696: CreateExternalRow should not use global variables") { + val ctx = new CodegenContext + val schema = new StructType().add("a", IntegerType).add("b", StringType) + CreateExternalRow(Seq(Literal(1), Literal("x")), schema).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } + + test("SPARK-22696: InitializeJavaBean should not use global variables") { + val ctx = new CodegenContext + InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b0eaad1c80f89..6dfca7d73a3df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -299,4 +300,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { new StringToMap(Literal("a=1_b=2_c=3"), Literal("_"), NonFoldableLiteral("=")) .checkInputDataTypes().isFailure) } + + test("SPARK-22693: CreateNamedStruct should not use global variables") { + val ctx = new CodegenContext + CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 3e11c3d2d4fe3..60d84aae1fa3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -145,4 +146,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper IndexedSeq((Literal(12) === Literal(1), Literal(42)), (Literal(12) === Literal(42), Literal(1)))) } + + test("SPARK-22705: case when should use less global variables") { + val ctx = new CodegenContext() + CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx) + assert(ctx.mutableStates.size == 1) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 40ef7770da33f..a23cd95632770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.types._ @@ -155,6 +156,12 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Coalesce(inputs), "x_1") } + test("SPARK-22705: Coalesce should use less global variables") { + val ctx = new CodegenContext() + 
Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx) + assert(ctx.mutableStates.size == 1) + } + test("AtLeastNNonNulls should not throw 64kb exception") { val inputs = (1 to 4000).map(x => Literal(s"x_$x")) checkEvaluation(AtLeastNNonNulls(1, inputs), true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 0079e4e8d6f74..15cb0bea08f17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -245,6 +246,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(In(Literal(1.0D), sets), true) } + test("SPARK-22705: In should use less global variables") { + val ctx = new CodegenContext() + In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } + test("INSET") { val hS = HashSet[Any]() + 1 + 2 val nS = HashSet[Any]() + 1 + 2 + null @@ -429,4 +436,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { val infinity = Literal(Double.PositiveInfinity) checkEvaluation(EqualTo(infinity, infinity), true) } + + test("SPARK-22693: InSet should not use global variables") { + val ctx = new CodegenContext + InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 13bd363c8b692..70dea4b39d55d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types.{IntegerType, StringType} class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -47,4 +48,9 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { assert(e2.getMessage.contains("Failed to execute user defined function")) } + test("SPARK-22695: ScalaUDF should not use global variables") { + val ctx = new CodegenContext + ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 0cd0d8859145f..6031bdf19e957 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -208,4 +208,15 @@ class GeneratedProjectionSuite extends SparkFunSuite { unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b")))) assert(row.getStruct(0, 1).getString(0).toString == "a") } + + test("SPARK-22699: GenerateSafeProjection should not use global variables for struct") { + val safeProj = GenerateSafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", IntegerType), true))) + val globalVariables = safeProj.getClass.getDeclaredFields + // We need always 3 variables: + // - one is a reference to this + // - one is the references object + // - one is the mutableRow + assert(globalVariables.length == 3) + } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 4db3fea008ee9..93010c606cf45 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -38,7 +38,7 @@ com.univocity univocity-parsers - 2.5.4 + 2.5.9 jar diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java index 9467435435d1f..24260b05194a7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -41,7 +41,7 @@ public class AggregateHashMap { private OnHeapColumnVector[] columnVectors; - private ColumnarBatch batch; + private MutableColumnarRow aggBufferRow; private int[] buckets; private int numBuckets; private int numRows = 0; @@ -63,7 +63,7 @@ public AggregateHashMap(StructType schema, int capacity, double loadFactor, int this.maxSteps = maxSteps; numBuckets = (int) (capacity / loadFactor); columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema); - batch = new ColumnarBatch(schema, columnVectors, capacity); + aggBufferRow = new MutableColumnarRow(columnVectors); buckets = new int[numBuckets]; Arrays.fill(buckets, -1); } @@ -72,14 +72,15 @@ public AggregateHashMap(StructType schema) { this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS); } - public ColumnarRow findOrInsert(long key) { + public MutableColumnarRow findOrInsert(long key) { int idx = find(key); if (idx != -1 && buckets[idx] == -1) { columnVectors[0].putLong(numRows, key); columnVectors[1].putLong(numRows, 0); buckets[idx] = numRows++; } - return batch.getRow(buckets[idx]); + aggBufferRow.rowId = buckets[idx]; + return aggBufferRow; } @VisibleForTesting diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 0071bd66760be..1f1347ccd315e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -323,7 +323,6 @@ public ArrowColumnVector(ValueVector vector) { for (int i = 0; i < childColumns.length; ++i) { childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); } - resultStruct = new ColumnarRow(childColumns); } else { throw new UnsupportedOperationException(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index cca14911fbb28..e6b87519239dd 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -157,18 +157,16 @@ public abstract class ColumnVector implements AutoCloseable { /** * Returns a utility object to get structs. */ - public ColumnarRow getStruct(int rowId) { - resultStruct.rowId = rowId; - return resultStruct; + public final ColumnarRow getStruct(int rowId) { + return new ColumnarRow(this, rowId); } /** * Returns a utility object to get structs. * provided to keep API compatibility with InternalRow for code generation */ - public ColumnarRow getStruct(int rowId, int size) { - resultStruct.rowId = rowId; - return resultStruct; + public final ColumnarRow getStruct(int rowId, int size) { + return getStruct(rowId); } /** @@ -216,11 +214,6 @@ public MapData getMap(int ordinal) { */ protected DataType type; - /** - * Reusable Struct holder for getStruct(). - */ - protected ColumnarRow resultStruct; - /** * The Dictionary for this column. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 2f5fb360b226f..a9d09aa679726 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -18,6 +18,7 @@ import java.util.*; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.StructType; /** @@ -40,10 +41,10 @@ public final class ColumnarBatch { private final StructType schema; private final int capacity; private int numRows; - final ColumnVector[] columns; + private final ColumnVector[] columns; - // Staging row returned from getRow. - final ColumnarRow row; + // Staging row returned from `getRow`. + private final MutableColumnarRow row; /** * Called to close all the columns in this batch. It is not valid to access the data after @@ -58,10 +59,10 @@ public void close() { /** * Returns an iterator over the rows in this batch. This skips rows that are filtered out. */ - public Iterator rowIterator() { + public Iterator rowIterator() { final int maxRows = numRows; - final ColumnarRow row = new ColumnarRow(columns); - return new Iterator() { + final MutableColumnarRow row = new MutableColumnarRow(columns); + return new Iterator() { int rowId = 0; @Override @@ -70,7 +71,7 @@ public boolean hasNext() { } @Override - public ColumnarRow next() { + public InternalRow next() { if (rowId >= maxRows) { throw new NoSuchElementException(); } @@ -133,9 +134,8 @@ public void setNumRows(int numRows) { /** * Returns the row in this batch at `rowId`. Returned row is reused across calls. 
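A usage note on the `getStruct` change above: since the shared `resultStruct` is gone, every call now returns an independent read-only `ColumnarRow` bound to one `rowId`, so struct values read for different rows no longer alias each other. A small sketch under that assumption (the helper name and the populated `structVector` are hypothetical):

```scala
import org.apache.spark.sql.execution.vectorized.ColumnVector

// Read the first int field of a struct column for two rows; each getStruct call
// yields a fresh view, so both rows can be inspected side by side.
def firstFieldOfFirstTwoRows(structVector: ColumnVector): (Option[Int], Option[Int]) = {
  def intAt(rowId: Int): Option[Int] =
    if (structVector.isNullAt(rowId)) {
      None
    } else {
      val struct = structVector.getStruct(rowId) // fresh ColumnarRow for this rowId
      if (struct.isNullAt(0)) None else Some(struct.getInt(0))
    }
  (intAt(0), intAt(1))
}
```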
*/ - public ColumnarRow getRow(int rowId) { - assert(rowId >= 0); - assert(rowId < numRows); + public InternalRow getRow(int rowId) { + assert(rowId >= 0 && rowId < numRows); row.rowId = rowId; return row; } @@ -144,6 +144,6 @@ public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { this.schema = schema; this.columns = columns; this.capacity = capacity; - this.row = new ColumnarRow(columns); + this.row = new MutableColumnarRow(columns); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java index cabb7479525d9..95c0d09873d67 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java @@ -28,30 +28,32 @@ * to be reused, callers should copy the data out if it needs to be stored. */ public final class ColumnarRow extends InternalRow { - protected int rowId; - private final ColumnVector[] columns; - - // Ctor used if this is a struct. - ColumnarRow(ColumnVector[] columns) { - this.columns = columns; + // The data for this row. E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. + private final ColumnVector data; + private final int rowId; + private final int numFields; + + ColumnarRow(ColumnVector data, int rowId) { + assert (data.dataType() instanceof StructType); + this.data = data; + this.rowId = rowId; + this.numFields = ((StructType) data.dataType()).size(); } - public ColumnVector[] columns() { return columns; } - @Override - public int numFields() { return columns.length; } + public int numFields() { return numFields; } /** * Revisit this. This is expensive. This is currently only used in test paths. 
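Because `getRow` and `rowIterator` now hand out a single `MutableColumnarRow` that is repositioned on each call, callers that buffer rows must copy them first (using the admittedly expensive `copy()` mentioned above). A minimal sketch of that rule, assuming an already populated batch:

```scala
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.ColumnarBatch

// Materialize all rows of a batch. The iterator reuses one staging row, so each row is
// copied before being stored; without copy() the buffer would end up holding many
// references to the same (last) row.
def materialize(batch: ColumnarBatch): Seq[InternalRow] = {
  val rows = new ArrayBuffer[InternalRow]
  val it = batch.rowIterator()
  while (it.hasNext) {
    rows += it.next().copy()
  }
  rows
}
```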
*/ @Override public InternalRow copy() { - GenericInternalRow row = new GenericInternalRow(columns.length); + GenericInternalRow row = new GenericInternalRow(numFields); for (int i = 0; i < numFields(); i++) { if (isNullAt(i)) { row.setNullAt(i); } else { - DataType dt = columns[i].dataType(); + DataType dt = data.getChildColumn(i).dataType(); if (dt instanceof BooleanType) { row.setBoolean(i, getBoolean(i)); } else if (dt instanceof ByteType) { @@ -91,65 +93,65 @@ public boolean anyNull() { } @Override - public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } + public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); } @Override - public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } + public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); } @Override - public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } + public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); } @Override - public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } + public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); } @Override - public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } + public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); } @Override - public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } + public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); } @Override - public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } + public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); } @Override - public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } + public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getDecimal(rowId, precision, scale); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getUTF8String(rowId); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + return data.getChildColumn(ordinal).getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getBinary(rowId); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + return data.getChildColumn(ordinal).getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChildColumn(0).getInt(rowId); - final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId); + final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId); return new CalendarInterval(months, microseconds); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (columns[ordinal].isNullAt(rowId)) return null; - 
return columns[ordinal].getStruct(rowId); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + return data.getChildColumn(ordinal).getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getArray(rowId); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + return data.getChildColumn(ordinal).getArray(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index f272cc163611b..06602c147dfe9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -28,17 +28,24 @@ /** * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash - * aggregate. + * aggregate, and {@link ColumnarBatch} to save object creation. * * Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to * avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes. */ public final class MutableColumnarRow extends InternalRow { public int rowId; - private final WritableColumnVector[] columns; + private final ColumnVector[] columns; + private final WritableColumnVector[] writableColumns; - public MutableColumnarRow(WritableColumnVector[] columns) { + public MutableColumnarRow(ColumnVector[] columns) { this.columns = columns; + this.writableColumns = null; + } + + public MutableColumnarRow(WritableColumnVector[] writableColumns) { + this.columns = writableColumns; + this.writableColumns = writableColumns; } @Override @@ -225,54 +232,54 @@ public void update(int ordinal, Object value) { @Override public void setNullAt(int ordinal) { - columns[ordinal].putNull(rowId); + writableColumns[ordinal].putNull(rowId); } @Override public void setBoolean(int ordinal, boolean value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putBoolean(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putBoolean(rowId, value); } @Override public void setByte(int ordinal, byte value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putByte(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putByte(rowId, value); } @Override public void setShort(int ordinal, short value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putShort(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putShort(rowId, value); } @Override public void setInt(int ordinal, int value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putInt(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putInt(rowId, value); } @Override public void setLong(int ordinal, long value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putLong(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putLong(rowId, value); } @Override public void setFloat(int ordinal, float value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putFloat(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putFloat(rowId, value); } @Override public void setDouble(int ordinal, double value) { - columns[ordinal].putNotNull(rowId); - 
columns[ordinal].putDouble(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putDouble(rowId, value); } @Override public void setDecimal(int ordinal, Decimal value, int precision) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putDecimal(rowId, value, precision); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putDecimal(rowId, value, precision); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 806d0291a6c49..5f1b9885334b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -547,7 +547,7 @@ protected void reserveInternal(int newCapacity) { } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); - } else if (resultStruct != null) { + } else if (childColumns != null) { // Nothing to store. } else { throw new RuntimeException("Unhandled " + type); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 6e7f74ce12f16..f12772ede575d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -558,7 +558,7 @@ protected void reserveInternal(int newCapacity) { if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity); doubleData = newData; } - } else if (resultStruct != null) { + } else if (childColumns != null) { // Nothing to store. } else { throw new RuntimeException("Unhandled " + type); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 0bea4cc97142d..7c053b579442c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -74,7 +74,6 @@ public void close() { dictionaryIds = null; } dictionary = null; - resultStruct = null; } public void reserve(int requiredCapacity) { @@ -673,23 +672,19 @@ protected WritableColumnVector(int capacity, DataType type) { } this.childColumns = new WritableColumnVector[1]; this.childColumns[0] = reserveNewColumn(childCapacity, childType); - this.resultStruct = null; } else if (type instanceof StructType) { StructType st = (StructType)type; this.childColumns = new WritableColumnVector[st.fields().length]; for (int i = 0; i < childColumns.length; ++i) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } - this.resultStruct = new ColumnarRow(this.childColumns); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. 
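To illustrate the writable mode added above (the `WritableColumnVector[]` constructor), here is a rough sketch in the spirit of the hash-aggregate fast path; the schema, capacity, and values are invented for this note and it is not the code Spark actually generates:

```scala
import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector, WritableColumnVector}
import org.apache.spark.sql.types._

// An aggregation-buffer layout: one grouping key column and one running count.
val schema = new StructType().add("key", LongType).add("count", LongType)
val vectors: Array[WritableColumnVector] =
  OnHeapColumnVector.allocateColumns(64, schema).toArray[WritableColumnVector]

// Backed by writable vectors, the row can both read and update a bucket in place.
val aggBufferRow = new MutableColumnarRow(vectors)
aggBufferRow.rowId = 0                                 // point the row at bucket 0
aggBufferRow.setLong(0, 7L)                            // grouping key
aggBufferRow.setLong(1, aggBufferRow.getLong(1) + 1L)  // bump the count
```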
this.childColumns = new WritableColumnVector[2]; this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType); - this.resultStruct = new ColumnarRow(this.childColumns); } else { this.childColumns = null; - this.resultStruct = null; } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java index 9a89c8193dd6e..b2c908dc73a61 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java @@ -49,4 +49,35 @@ public DataSourceV2Options(Map originalMap) { public Optional get(String key) { return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key))); } + + /** + * Returns the boolean value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public boolean getBoolean(String key, boolean defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + + /** + * Returns the integer value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public int getInt(String key, int defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + + /** + * Returns the long value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public long getLong(String key, long defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? 
+ Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 657b265260135..787c1cfbfb3d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext import org.codehaus.commons.compiler.CompileException -import org.codehaus.janino.JaninoRuntimeException +import org.codehaus.janino.InternalCompilerException import org.apache.spark.{broadcast, SparkEnv} import org.apache.spark.internal.Logging @@ -385,7 +385,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ try { GeneratePredicate.generate(expression, inputSchema) } catch { - case _ @ (_: JaninoRuntimeException | _: CompileException) if codeGenFallBack => + case _ @ (_: InternalCompilerException | _: CompileException) if codeGenFallBack => genInterpretedPredicate(expression, inputSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 26d8cd7278353..9cadd13999e72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -595,9 +595,7 @@ case class HashAggregateExec( ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName();") - ctx.addMutableState( - s"java.util.Iterator<${classOf[ColumnarRow].getName}>", - iterTermForFastHashMap) + ctx.addMutableState(s"java.util.Iterator", iterTermForFastHashMap) } else { val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, fastHashMapClassName, groupingKeySchema, bufferSchema).generate() @@ -674,7 +672,7 @@ case class HashAggregateExec( """.stripMargin } - // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow + // Iterate over the aggregate rows and convert them from InternalRow to UnsafeRow def outputFromVectorizedMap: String = { val row = ctx.freshName("fastHashMapRow") ctx.currentVars = null @@ -687,10 +685,9 @@ case class HashAggregateExec( bufferSchema.toAttributes.zipWithIndex.map { case (attr, i) => BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) }) - val columnarRowCls = classOf[ColumnarRow].getName s""" |while ($iterTermForFastHashMap.hasNext()) { - | $columnarRowCls $row = ($columnarRowCls) $iterTermForFastHashMap.next(); + | InternalRow $row = (InternalRow) $iterTermForFastHashMap.next(); | ${generateKeyRow.code} | ${generateBufferRow.code} | $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value}); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 44ba539ebf7c2..f04cd48072f17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.aggregate +import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector} +import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ /** @@ -231,7 +232,7 @@ class VectorizedHashMapGenerator( protected def generateRowIterator(): String = { s""" - |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() { + |public java.util.Iterator<${classOf[InternalRow].getName}> rowIterator() { | batch.setNumRows(numRows); | return batch.rowIterator(); |} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 2f09757aa341c..341ade1a5c613 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -35,7 +35,7 @@ private[columnar] trait NullableColumnAccessor extends ColumnAccessor { nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 pos = 0 - underlyingBuffer.position(underlyingBuffer.position + 4 + nullCount * 4) + underlyingBuffer.position(underlyingBuffer.position() + 4 + nullCount * 4) super.initialize() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index bf00ad997c76e..79dcf3a6105ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -112,7 +112,7 @@ private[columnar] case object PassThrough extends CompressionScheme { var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else capacity var pos = 0 var seenNulls = 0 - var bufferPos = buffer.position + var bufferPos = buffer.position() while (pos < capacity) { if (pos != nextNullIndex) { val len = nextNullIndex - pos diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 75c42213db3c8..f7471cd7debce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -124,7 +125,7 @@ class OrcFileFormat true } - override def buildReader( + override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, @@ -167,9 +168,17 @@ class OrcFileFormat val iter = new 
RecordReaderIterator[OrcStruct](orcRecordReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) - val unsafeProjection = UnsafeProjection.create(requiredSchema) + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) - iter.map(value => unsafeProjection(deserializer.deserialize(value))) + + if (partitionSchema.length == 0) { + iter.map(value => unsafeProjection(deserializer.deserialize(value))) + } else { + val joinedRow = new JoinedRow() + iter.map(value => + unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + } } } } diff --git a/sql/core/src/test/resources/test-data/comments.csv b/sql/core/src/test/resources/test-data/comments.csv index 6275be7285b36..c0ace46db8c00 100644 --- a/sql/core/src/test/resources/test-data/comments.csv +++ b/sql/core/src/test/resources/test-data/comments.csv @@ -4,3 +4,4 @@ 6,7,8,9,0,2015-08-21 16:58:01 ~0,9,8,7,6,2015-08-22 17:59:02 1,2,3,4,5,2015-08-23 18:00:42 +~ comment in last line to test SPARK-22516 - do not add empty line at the end of this file! \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8ddddbeee598f..5e077285ade55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2775,32 +2775,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } - - test("SPARK-21791 ORC should support column names with dot") { - val orc = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName - withTempDir { dir => - val path = new File(dir, "orc").getCanonicalPath - Seq(Some(1), None).toDF("col.dots").write.format(orc).save(path) - assert(spark.read.format(orc).load(path).collect().length == 2) - } - } - - test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") { - withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "hive") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE spark_20728(a INT) USING ORC") - } - assert(e.message.contains("Hive built-in ORC data source must be used with Hive support")) - } - - withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "native") { - withTable("spark_20728") { - sql("CREATE TABLE spark_20728(a INT) USING ORC") - val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst { - case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass - } - assert(fileFormat == Some(classOf[OrcFileFormat])) - } - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index e439699605abb..4fe45420b4e77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -483,18 +483,21 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("commented lines in CSV data") { - val results = spark.read - .format("csv") - .options(Map("comment" -> "~", "header" -> "false")) - .load(testFile(commentsFile)) - .collect() + Seq("false", "true").foreach { 
multiLine => - val expected = - Seq(Seq("1", "2", "3", "4", "5.01", "2015-08-20 15:57:00"), - Seq("6", "7", "8", "9", "0", "2015-08-21 16:58:01"), - Seq("1", "2", "3", "4", "5", "2015-08-23 18:00:42")) + val results = spark.read + .format("csv") + .options(Map("comment" -> "~", "header" -> "false", "multiLine" -> multiLine)) + .load(testFile(commentsFile)) + .collect() - assert(results.toSeq.map(_.toSeq) === expected) + val expected = + Seq(Seq("1", "2", "3", "4", "5.01", "2015-08-20 15:57:00"), + Seq("6", "7", "8", "9", "0", "2015-08-21 16:58:01"), + Seq("1", "2", "3", "4", "5", "2015-08-23 18:00:42")) + + assert(results.toSeq.map(_.toSeq) === expected) + } } test("inferring schema with commented lines in CSV data") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala similarity index 87% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index de6f0d67f1734..a5f6b68ee862e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -15,25 +15,32 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ -import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} -import org.apache.spark.sql.{Column, DataFrame, QueryTest} +import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ /** - * A test suite that tests ORC filter API based filter pushdown optimization. + * A test suite that tests Apache ORC filter API based filter pushdown optimization. + * OrcFilterSuite and HiveOrcFilterSuite is logically duplicated to provide the same test coverage. + * The difference are the packages containing 'Predicate' and 'SearchArgument' classes. + * - OrcFilterSuite uses 'org.apache.orc.storage.ql.io.sarg' package. + * - HiveOrcFilterSuite uses 'org.apache.hadoop.hive.ql.io.sarg' package. 
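As background for the assertions that follow, a hedged sketch of how `OrcFilters.createFilter` is driven: data source `Filter`s are translated into an ORC `SearchArgument`, whose string form is what the tests compare against. The object may be package-private, so the sketch sits in the suite's package; the schema and filters are made up for illustration:

```scala
package org.apache.spark.sql.execution.datasources.orc

import org.apache.spark.sql.sources.{GreaterThan, IsNotNull}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object OrcFilterSketch {
  def main(args: Array[String]): Unit = {
    val schema = new StructType()
      .add("id", IntegerType)
      .add("name", StringType)

    // Only convertible predicates end up in the SearchArgument; the Option is empty
    // when nothing can be pushed down.
    val pushed = OrcFilters.createFilter(schema, Seq(IsNotNull("id"), GreaterThan("id", 5)))
    pushed.foreach(sarg => println(sarg))
  }
}
```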
*/ -class OrcFilterSuite extends QueryTest with OrcTest { +class OrcFilterSuite extends OrcTest with SharedSQLContext { + private def checkFilterPredicate( df: DataFrame, predicate: Predicate, @@ -55,7 +62,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) assert(selectedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") checker(maybeFilter.get) } @@ -99,7 +106,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) assert(selectedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters") } @@ -284,40 +291,27 @@ class OrcFilterSuite extends QueryTest with OrcTest { test("filter pushdown - combinations with logical operators") { withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => - // Because `ExpressionTree` is not accessible at Hive 1.2.x, this should be checked - // in string form in order to check filter creation including logical operators - // such as `and`, `or` or `not`. So, this function uses `SearchArgument.toString()` - // to produce string expression and then compare it to given string expression below. - // This might have to be changed after Hive version is upgraded. checkFilterPredicate( '_1.isNotNull, - """leaf-0 = (IS_NULL _1) - |expr = (not leaf-0)""".stripMargin.trim + "leaf-0 = (IS_NULL _1), expr = (not leaf-0)" ) checkFilterPredicate( '_1 =!= 1, - """leaf-0 = (IS_NULL _1) - |leaf-1 = (EQUALS _1 1) - |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + "leaf-0 = (IS_NULL _1), leaf-1 = (EQUALS _1 1), expr = (and (not leaf-0) (not leaf-1))" ) checkFilterPredicate( !('_1 < 4), - """leaf-0 = (IS_NULL _1) - |leaf-1 = (LESS_THAN _1 4) - |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 4), expr = (and (not leaf-0) (not leaf-1))" ) checkFilterPredicate( '_1 < 2 || '_1 > 3, - """leaf-0 = (LESS_THAN _1 2) - |leaf-1 = (LESS_THAN_EQUALS _1 3) - |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim + "leaf-0 = (LESS_THAN _1 2), leaf-1 = (LESS_THAN_EQUALS _1 3), " + + "expr = (or leaf-0 (not leaf-1))" ) checkFilterPredicate( '_1 < 2 && '_1 > 3, - """leaf-0 = (IS_NULL _1) - |leaf-1 = (LESS_THAN _1 2) - |leaf-2 = (LESS_THAN_EQUALS _1 3) - |expr = (and (not leaf-0) leaf-1 (not leaf-2))""".stripMargin.trim + "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 2), leaf-2 = (LESS_THAN_EQUALS _1 3), " + + "expr = (and (not leaf-0) leaf-1 (not leaf-2))" ) } } @@ -344,4 +338,30 @@ class OrcFilterSuite extends QueryTest with OrcTest { checkNoFilterPredicate('_1.isNotNull) } } + + test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { + import org.apache.spark.sql.sources._ + // The `LessThan` should be converted while the `StringContains` shouldn't + val schema = new StructType( + Array( + StructField("a", IntegerType, nullable = true), + StructField("b", StringType, nullable = true))) + assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") { + 
OrcFilters.createFilter(schema, Array( + LessThan("a", 10), + StringContains("b", "prefix") + )).get.toString + } + + // The `LessThan` should be converted while the whole inner `And` shouldn't + assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") { + OrcFilters.createFilter(schema, Array( + LessThan("a", 10), + Not(And( + GreaterThan("a", 1), + StringContains("b", "prefix") + )) + )).get.toString + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala similarity index 82% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala index d1ce3f1e2f058..d1911ea7f32a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala @@ -15,19 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc import java.io.File -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag - -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.util.Utils +import org.apache.spark.sql.test.SharedSQLContext // The data where the partitioning key exists only in the directory structure. case class OrcParData(intField: Int, stringField: String) @@ -35,28 +28,8 @@ case class OrcParData(intField: Int, stringField: String) // The data that also includes the partitioning key case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) -// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { - import spark._ - import spark.implicits._ - - val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal - - def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) - } - - def makeOrcFile[T <: Product: ClassTag: TypeTag]( - data: Seq[T], path: File): Unit = { - data.toDF().write.mode("overwrite").orc(path.getCanonicalPath) - } - - - def makeOrcFile[T <: Product: ClassTag: TypeTag]( - df: DataFrame, path: File): Unit = { - df.write.mode("overwrite").orc(path.getCanonicalPath) - } +abstract class OrcPartitionDiscoveryTest extends OrcTest { + val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" protected def withTempTable(tableName: String)(f: => Unit): Unit = { try f finally spark.catalog.dropTempView(tableName) @@ -90,7 +63,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.orc(base.getCanonicalPath).createOrReplaceTempView("t") + spark.read.orc(base.getCanonicalPath).createOrReplaceTempView("t") withTempTable("t") { checkAnswer( @@ -137,7 +110,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.orc(base.getCanonicalPath).createOrReplaceTempView("t") + 
spark.read.orc(base.getCanonicalPath).createOrReplaceTempView("t") withTempTable("t") { checkAnswer( @@ -186,8 +159,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read - .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + spark.read + .option("hive.exec.default.partition.name", defaultPartitionName) .orc(base.getCanonicalPath) .createOrReplaceTempView("t") @@ -228,8 +201,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read - .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + spark.read + .option("hive.exec.default.partition.name", defaultPartitionName) .orc(base.getCanonicalPath) .createOrReplaceTempView("t") @@ -253,3 +226,4 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B } } +class OrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQLContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala similarity index 68% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 1ffaf30311037..e00e057a18cc6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -15,24 +15,27 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc +import java.io.File import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc.{OrcConf, OrcFile} import org.apache.orc.OrcConf.COMPRESS -import org.scalatest.BeforeAndAfterAll +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} -import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.util.Utils @@ -57,7 +60,8 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { +abstract class OrcQueryTest extends OrcTest { + import testImplicits._ test("Read/write All Types") { val data = (0 to 255).map { i => @@ -73,7 +77,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("Read/write binary data") 
{ withOrcFile(BinaryData("test".getBytes(StandardCharsets.UTF_8)) :: Nil) { file => - val bytes = read.orc(file).head().getAs[Array[Byte]](0) + val bytes = spark.read.orc(file).head().getAs[Array[Byte]](0) assert(new String(bytes, StandardCharsets.UTF_8) === "test") } } @@ -91,7 +95,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.orc(file), + spark.read.orc(file), data.toDF().collect()) } } @@ -172,7 +176,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.orc(file), + spark.read.orc(file), Row(Seq.fill(5)(null): _*)) } } @@ -183,9 +187,13 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { spark.range(0, 10).write .option(COMPRESS.getAttribute, "ZLIB") .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("ZLIB" === expectedCompressionKind.name()) + + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) } // `compression` overrides `orc.compress`. @@ -194,9 +202,13 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { .option("compression", "ZLIB") .option(COMPRESS.getAttribute, "SNAPPY") .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("ZLIB" === expectedCompressionKind.name()) + + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) } } @@ -206,39 +218,39 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { spark.range(0, 10).write .option("compression", "ZLIB") .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("ZLIB" === expectedCompressionKind.name()) + + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) } withTempPath { file => spark.range(0, 10).write .option("compression", "SNAPPY") .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("SNAPPY" === expectedCompressionKind.name()) + + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".snappy.orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("SNAPPY" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) } withTempPath { file => spark.range(0, 10).write .option("compression", "NONE") .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("NONE" === 
expectedCompressionKind.name()) - } - } - // Following codec is not supported in Hive 1.2.1, ignore it now - ignore("LZO compression options for writing to an ORC file not supported in Hive 1.2.1") { - withTempPath { file => - spark.range(0, 10).write - .option("compression", "LZO") - .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("LZO" === expectedCompressionKind.name()) + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("NONE" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) } } @@ -256,22 +268,28 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") + spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") withOrcTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(spark.table("t"), (data ++ data).map(Row.fromTuple)) } - sessionState.catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = true, purge = false) + spark.sessionState.catalog.dropTable( + TableIdentifier("tmp"), + ignoreIfNotExists = true, + purge = false) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") + spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") withOrcTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), data.map(Row.fromTuple)) + checkAnswer(spark.table("t"), data.map(Row.fromTuple)) } - sessionState.catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = true, purge = false) + spark.sessionState.catalog.dropTable( + TableIdentifier("tmp"), + ignoreIfNotExists = true, + purge = false) } test("self-join") { @@ -334,60 +352,16 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { dir => val path = dir.getCanonicalPath - spark.range(0, 10).select('id as "Acol").write.format("orc").save(path) - spark.read.format("orc").load(path).schema("Acol") + spark.range(0, 10).select('id as "Acol").write.orc(path) + spark.read.orc(path).schema("Acol") intercept[IllegalArgumentException] { - spark.read.format("orc").load(path).schema("acol") + spark.read.orc(path).schema("acol") } - checkAnswer(spark.read.format("orc").load(path).select("acol").sort("acol"), + checkAnswer(spark.read.orc(path).select("acol").sort("acol"), (0 until 10).map(Row(_))) } } - test("SPARK-8501: Avoids discovery schema from empty ORC files") { - withTempPath { dir => - val path = dir.getCanonicalPath - - withTable("empty_orc") { - withTempView("empty", "single") { - spark.sql( - s"""CREATE TABLE empty_orc(key INT, value STRING) - |STORED AS ORC - |LOCATION '${dir.toURI}' - """.stripMargin) - - val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) - emptyDF.createOrReplaceTempView("empty") - - // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because - // Spark SQL ORC data source always avoids write empty ORC files. 
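Recapping the compression checks rewritten earlier in this suite: the codec is now verified by opening the produced file with the Apache ORC reader instead of the Hive `OrcFileOperator`. A standalone sketch of the same check, assuming an active `SparkSession` named `spark` (the output path is illustrative):

```scala
import java.io.File

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.orc.OrcFile

val dir = "/tmp/orc_compression_demo"
spark.range(0, 10).write.mode("overwrite").option("compression", "ZLIB").orc(dir)

// ZLIB-compressed files carry a ".zlib.orc" suffix; open one and ask the ORC reader
// which codec was actually used.
val orcFile = new File(dir).listFiles().find(_.getName.endsWith(".zlib.orc")).get
val reader = OrcFile.createReader(
  new Path(orcFile.getAbsolutePath), OrcFile.readerOptions(new Configuration()))
assert(reader.getCompressionKind.name == "ZLIB")
```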
- spark.sql( - s"""INSERT INTO TABLE empty_orc - |SELECT key, value FROM empty - """.stripMargin) - - val errorMessage = intercept[AnalysisException] { - spark.read.orc(path) - }.getMessage - - assert(errorMessage.contains("Unable to infer schema for ORC")) - - val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) - singleRowDF.createOrReplaceTempView("single") - - spark.sql( - s"""INSERT INTO TABLE empty_orc - |SELECT key, value FROM single - """.stripMargin) - - val df = spark.read.orc(path) - assert(df.schema === singleRowDF.schema.asNullable) - checkAnswer(df, singleRowDF) - } - } - } - } - test("SPARK-10623 Enable ORC PPD") { withTempPath { dir => withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { @@ -405,7 +379,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } // It needs to repartition data so that we can have several ORC files // in order to skip stripes in ORC. - createDataFrame(data).toDF("a", "b").repartition(10).write.orc(path) + spark.createDataFrame(data).toDF("a", "b").repartition(10).write.orc(path) val df = spark.read.orc(path) def checkPredicate(pred: Column, answer: Seq[Row]): Unit = { @@ -440,77 +414,6 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("Verify the ORC conversion parameter: CONVERT_METASTORE_ORC") { - withTempView("single") { - val singleRowDF = Seq((0, "foo")).toDF("key", "value") - singleRowDF.createOrReplaceTempView("single") - - Seq("true", "false").foreach { orcConversion => - withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> orcConversion) { - withTable("dummy_orc") { - withTempPath { dir => - val path = dir.getCanonicalPath - spark.sql( - s""" - |CREATE TABLE dummy_orc(key INT, value STRING) - |STORED AS ORC - |LOCATION '${dir.toURI}' - """.stripMargin) - - spark.sql( - s""" - |INSERT INTO TABLE dummy_orc - |SELECT key, value FROM single - """.stripMargin) - - val df = spark.sql("SELECT * FROM dummy_orc WHERE key=0") - checkAnswer(df, singleRowDF) - - val queryExecution = df.queryExecution - if (orcConversion == "true") { - queryExecution.analyzed.collectFirst { - case _: LogicalRelation => () - }.getOrElse { - fail(s"Expecting the query plan to convert orc to data sources, " + - s"but got:\n$queryExecution") - } - } else { - queryExecution.analyzed.collectFirst { - case _: HiveTableRelation => () - }.getOrElse { - fail(s"Expecting no conversion from orc to data sources, " + - s"but got:\n$queryExecution") - } - } - } - } - } - } - } - } - - test("converted ORC table supports resolving mixed case field") { - withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { - withTable("dummy_orc") { - withTempPath { dir => - val df = spark.range(5).selectExpr("id", "id as valueField", "id as partitionValue") - df.write - .partitionBy("partitionValue") - .mode("overwrite") - .orc(dir.getAbsolutePath) - - spark.sql(s""" - |create external table dummy_orc (id long, valueField long) - |partitioned by (partitionValue int) - |stored as orc - |location "${dir.toURI}"""".stripMargin) - spark.sql(s"msck repair table dummy_orc") - checkAnswer(spark.sql("select * from dummy_orc"), df) - } - } - } - } - test("SPARK-14962 Produce correct results on array type with isnotnull") { withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { val data = (0 until 10).map(i => Tuple1(Array(i))) @@ -544,7 +447,8 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { file => // It needs to repartition data so that we can have several 
ORC files // in order to skip stripes in ORC. - createDataFrame(data).toDF("a").repartition(10).write.orc(file.getCanonicalPath) + spark.createDataFrame(data).toDF("a").repartition(10) + .write.orc(file.getCanonicalPath) val df = spark.read.orc(file.getCanonicalPath).where("a == 2") val actual = stripSparkFilter(df).count() @@ -563,7 +467,8 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { file => // It needs to repartition data so that we can have several ORC files // in order to skip stripes in ORC. - createDataFrame(data).toDF("a").repartition(10).write.orc(file.getCanonicalPath) + spark.createDataFrame(data).toDF("a").repartition(10) + .write.orc(file.getCanonicalPath) val df = spark.read.orc(file.getCanonicalPath).where(s"a == '$timeString'") val actual = stripSparkFilter(df).count() @@ -596,14 +501,18 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("Empty schema does not read data from ORC file") { val data = Seq((1, 1), (2, 2)) withOrcFile(data) { path => - val requestedSchema = StructType(Nil) val conf = new Configuration() - val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get - OrcFileFormat.setRequiredColumns(conf, physicalSchema, requestedSchema) - val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) - assert(maybeOrcReader.isDefined) - val orcRecordReader = new SparkOrcNewRecordReader( - maybeOrcReader.get, conf, 0, maybeOrcReader.get.getContentLength) + conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, "") + conf.setBoolean("hive.io.file.read.all.columns", false) + + val orcRecordReader = { + val file = new File(path).listFiles().find(_.getName.endsWith(".snappy.orc")).head + val split = new FileSplit(new Path(file.toURI), 0, file.length, Array.empty[String]) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + val oif = new OrcInputFormat[OrcStruct] + oif.createRecordReader(split, hadoopAttemptContext) + } val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) try { @@ -614,27 +523,88 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("read from multiple orc input paths") { - val path1 = Utils.createTempDir() - val path2 = Utils.createTempDir() - makeOrcFile((1 to 10).map(Tuple1.apply), path1) - makeOrcFile((1 to 10).map(Tuple1.apply), path2) - assertResult(20)(read.orc(path1.getCanonicalPath, path2.getCanonicalPath).count()) - } + test("read from multiple orc input paths") { + val path1 = Utils.createTempDir() + val path2 = Utils.createTempDir() + makeOrcFile((1 to 10).map(Tuple1.apply), path1) + makeOrcFile((1 to 10).map(Tuple1.apply), path2) + val df = spark.read.orc(path1.getCanonicalPath, path2.getCanonicalPath) + assert(df.count() == 20) + } +} + +class OrcQuerySuite extends OrcQueryTest with SharedSQLContext { + import testImplicits._ + + test("LZO compression options for writing to an ORC file") { + withTempPath { file => + spark.range(0, 10).write + .option("compression", "LZO") + .orc(file.getCanonicalPath) + + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".lzo.orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("LZO" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) + } + } + + test("Schema discovery on empty ORC files") { + // SPARK-8501 is 
fixed. + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("empty_orc") { + withTempView("empty", "single") { + spark.sql( + s"""CREATE TABLE empty_orc(key INT, value STRING) + |USING ORC + |LOCATION '${dir.toURI}' + """.stripMargin) + + val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) + emptyDF.createOrReplaceTempView("empty") + + // This creates 1 empty ORC file with ORC SerDe. We are using this trick because + // Spark SQL ORC data source always avoids write empty ORC files. + spark.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM empty + """.stripMargin) + + val df = spark.read.orc(path) + assert(df.schema === emptyDF.schema.asNullable) + checkAnswer(df, emptyDF) + } + } + } + } + + test("SPARK-21791 ORC should support column names with dot") { + withTempDir { dir => + val path = new File(dir, "orc").getCanonicalPath + Seq(Some(1), None).toDF("col.dots").write.orc(path) + assert(spark.read.orc(path).collect().length == 2) + } + } test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") { - Seq( - ("native", classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat]), - ("hive", classOf[org.apache.spark.sql.hive.orc.OrcFileFormat])).foreach { case (i, format) => - - withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> i) { - withTable("spark_20728") { - sql("CREATE TABLE spark_20728(a INT) USING ORC") - val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst { - case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass - } - assert(fileFormat == Some(format)) + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "hive") { + val e = intercept[AnalysisException] { + sql("CREATE TABLE spark_20728(a INT) USING ORC") + } + assert(e.message.contains("Hive built-in ORC data source must be used with Hive support")) + } + + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "native") { + withTable("spark_20728") { + sql("CREATE TABLE spark_20728(a INT) USING ORC") + val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst { + case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass } + assert(fileFormat == Some(classOf[OrcFileFormat])) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala similarity index 63% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 2a086be57f517..6f5f2fd795f74 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc import java.io.File import java.util.Locale @@ -23,50 +23,30 @@ import java.util.Locale import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.execution.datasources.orc.OrcOptions -import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.Row import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { - import spark._ +abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { + import testImplicits._ var orcTableDir: File = null var orcTableAsDir: File = null - override def beforeAll(): Unit = { + protected override def beforeAll(): Unit = { super.beforeAll() orcTableAsDir = Utils.createTempDir("orctests", "sparksql") - - // Hack: to prepare orc data files using hive external tables orcTableDir = Utils.createTempDir("orctests", "sparksql") - import org.apache.spark.sql.hive.test.TestHive.implicits._ sparkContext .makeRDD(1 to 10) .map(i => OrcData(i, s"part-$i")) .toDF() - .createOrReplaceTempView(s"orc_temp_table") - - sql( - s"""CREATE EXTERNAL TABLE normal_orc( - | intField INT, - | stringField STRING - |) - |STORED AS ORC - |LOCATION '${orcTableAsDir.toURI}' - """.stripMargin) - - sql( - s"""INSERT INTO TABLE normal_orc - |SELECT intField, stringField FROM orc_temp_table - """.stripMargin) + .createOrReplaceTempView("orc_temp_table") } test("create temporary orc table") { @@ -152,56 +132,13 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - val conf = sqlContext.sessionState.conf + val conf = spark.sessionState.conf val option = new OrcOptions(Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> "NONE"), conf) assert(option.compressionCodec == "NONE") } - test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { - val location = Utils.createTempDir() - val uri = location.toURI - try { - hiveClient.runSqlHive("USE default") - hiveClient.runSqlHive( - """ - |CREATE EXTERNAL TABLE hive_orc( - | a STRING, - | b CHAR(10), - | c VARCHAR(10), - | d ARRAY) - |STORED AS orc""".stripMargin) - // Hive throws an exception if I assign the location in the create table statement. - hiveClient.runSqlHive( - s"ALTER TABLE hive_orc SET LOCATION '$uri'") - hiveClient.runSqlHive( - """ - |INSERT INTO TABLE hive_orc - |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3))) - |FROM (SELECT 1) t""".stripMargin) - - // We create a different table in Spark using the same schema which points to - // the same location. 
- spark.sql( - s""" - |CREATE EXTERNAL TABLE spark_orc( - | a STRING, - | b CHAR(10), - | c VARCHAR(10), - | d ARRAY) - |STORED AS orc - |LOCATION '$uri'""".stripMargin) - val result = Row("a", "b ", "c", Seq("d ")) - checkAnswer(spark.table("hive_orc"), result) - checkAnswer(spark.table("spark_orc"), result) - } finally { - hiveClient.runSqlHive("DROP TABLE IF EXISTS hive_orc") - hiveClient.runSqlHive("DROP TABLE IF EXISTS spark_orc") - Utils.deleteRecursively(location) - } - } - test("SPARK-21839: Add SQL config for ORC compression") { - val conf = sqlContext.sessionState.conf + val conf = spark.sessionState.conf // Test if the default of spark.sql.orc.compression.codec is snappy assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "SNAPPY") @@ -225,13 +162,28 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA } } -class OrcSourceSuite extends OrcSuite { - override def beforeAll(): Unit = { +class OrcSourceSuite extends OrcSuite with SharedSQLContext { + + protected override def beforeAll(): Unit = { super.beforeAll() + sql( + s"""CREATE TABLE normal_orc( + | intField INT, + | stringField STRING + |) + |USING ORC + |LOCATION '${orcTableAsDir.toURI}' + """.stripMargin) + + sql( + s"""INSERT INTO TABLE normal_orc + |SELECT intField, stringField FROM orc_temp_table + """.stripMargin) + spark.sql( s"""CREATE TEMPORARY VIEW normal_orc_source - |USING org.apache.spark.sql.hive.orc + |USING ORC |OPTIONS ( | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' |) @@ -239,43 +191,10 @@ class OrcSourceSuite extends OrcSuite { spark.sql( s"""CREATE TEMPORARY VIEW normal_orc_as_source - |USING org.apache.spark.sql.hive.orc + |USING ORC |OPTIONS ( | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' |) """.stripMargin) } - - test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { - // The `LessThan` should be converted while the `StringContains` shouldn't - val schema = new StructType( - Array( - StructField("a", IntegerType, nullable = true), - StructField("b", StringType, nullable = true))) - assertResult( - """leaf-0 = (LESS_THAN a 10) - |expr = leaf-0 - """.stripMargin.trim - ) { - OrcFilters.createFilter(schema, Array( - LessThan("a", 10), - StringContains("b", "prefix") - )).get.toString - } - - // The `LessThan` should be converted while the whole inner `And` shouldn't - assertResult( - """leaf-0 = (LESS_THAN a 10) - |expr = leaf-0 - """.stripMargin.trim - ) { - OrcFilters.createFilter(schema, Array( - LessThan("a", 10), - Not(And( - GreaterThan("a", 1), - StringContains("b", "prefix") - )) - )).get.toString - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala similarity index 72% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index a2f08c5ba72c6..d94cb850ed2a2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -15,20 +15,51 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils -private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { +/** + * OrcTest + * -> OrcSuite + * -> OrcSourceSuite + * -> HiveOrcSourceSuite + * -> OrcQueryTests + * -> OrcQuerySuite + * -> HiveOrcQuerySuite + * -> OrcPartitionDiscoveryTest + * -> OrcPartitionDiscoverySuite + * -> HiveOrcPartitionDiscoverySuite + * -> OrcFilterSuite + * -> HiveOrcFilterSuite + */ +abstract class OrcTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { import testImplicits._ + val orcImp: String = "native" + + private var originalConfORCImplementation = "native" + + protected override def beforeAll(): Unit = { + super.beforeAll() + originalConfORCImplementation = conf.getConf(ORC_IMPLEMENTATION) + conf.setConf(ORC_IMPLEMENTATION, orcImp) + } + + protected override def afterAll(): Unit = { + conf.setConf(ORC_IMPLEMENTATION, originalConfORCImplementation) + super.afterAll() + } + /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` * returns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 0ae4f2d117609..c9c6bee513b53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -751,11 +751,6 @@ class ColumnarBatchSuite extends SparkFunSuite { c2.putDouble(1, 5.67) val s = column.getStruct(0) - assert(s.columns()(0).getInt(0) == 123) - assert(s.columns()(0).getInt(1) == 456) - assert(s.columns()(1).getDouble(0) == 3.45) - assert(s.columns()(1).getDouble(1) == 5.67) - assert(s.getInt(0) == 123) assert(s.getDouble(1) == 3.45) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala index 933f4075bcc8a..752d3c193cc74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala @@ -37,4 +37,35 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava) assert(options.get("foo").get == "bAr") } + + test("getInt") { + val options = new DataSourceV2Options(Map("numFOo" -> "1", "foo" -> "bar").asJava) + assert(options.getInt("numFOO", 10) == 1) + assert(options.getInt("numFOO2", 10) == 10) + + intercept[NumberFormatException]{ + options.getInt("foo", 1) + } + } + + test("getBoolean") { + val options = new DataSourceV2Options( + Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava) + assert(options.getBoolean("isFoo", false)) + assert(!options.getBoolean("isFoo2", true)) + assert(options.getBoolean("isBar", true)) + assert(!options.getBoolean("isBar", false)) + assert(!options.getBoolean("FOO", true)) + } + + test("getLong") { + val options = new DataSourceV2Options(Map("numFoo" -> 
"9223372036854775807", + "foo" -> "bar").asJava) + assert(options.getLong("numFOO", 0L) == 9223372036854775807L) + assert(options.getLong("numFoo2", -1L) == -1L) + + intercept[NumberFormatException]{ + options.getLong("foo", 0L) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala index d786a610f1535..3328400b214fb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala @@ -412,7 +412,9 @@ case class HiveScriptIOSchema ( propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) val properties = new Properties() - properties.putAll(propsMap.asJava) + // Can not use properties.putAll(propsMap.asJava) in scala-2.12 + // See https://github.com/scala/bug/issues/10418 + propsMap.foreach { case (k, v) => properties.put(k, v) } serde.initialize(null, properties) serde @@ -424,7 +426,9 @@ case class HiveScriptIOSchema ( recordReaderClass.map { klass => val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] val props = new Properties() - props.putAll(outputSerdeProps.toMap.asJava) + // Can not use props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 + // See https://github.com/scala/bug/issues/10418 + outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) } instance.initialize(inputStream, conf, props) instance } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala new file mode 100644 index 0000000000000..283037caf4a9b --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala @@ -0,0 +1,387 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.orc + +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} + +import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.orc.OrcTest +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.types._ + +/** + * A test suite that tests Hive ORC filter API based filter pushdown optimization. + */ +class HiveOrcFilterSuite extends OrcTest with TestHiveSingleton { + + override val orcImp: String = "hive" + + private def checkFilterPredicate( + df: DataFrame, + predicate: Predicate, + checker: (SearchArgument) => Unit): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => + maybeRelation = Some(orcRelation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters, _) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") + checker(maybeFilter.get) + } + + private def checkFilterPredicate + (predicate: Predicate, filterOperator: PredicateLeaf.Operator) + (implicit df: DataFrame): Unit = { + def checkComparisonOperator(filter: SearchArgument) = { + val operator = filter.getLeaves.asScala + assert(operator.map(_.getOperator).contains(filterOperator)) + } + checkFilterPredicate(df, predicate, checkComparisonOperator) + } + + private def checkFilterPredicate + (predicate: Predicate, stringExpr: String) + (implicit df: DataFrame): Unit = { + def checkLogicalOperator(filter: SearchArgument) = { + assert(filter.toString == stringExpr) + } + checkFilterPredicate(df, predicate, checkLogicalOperator) + } + + private def checkNoFilterPredicate + (predicate: Predicate) + (implicit df: DataFrame): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => + maybeRelation = Some(orcRelation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters, _) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = 
OrcFilters.createFilter(query.schema, selectedFilters.toArray) + assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters") + } + + test("filter pushdown - integer") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - long") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - float") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - double") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => + 
checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - string") { + withOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === "1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < "2", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= "4", PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal("1") === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal("1") <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal("2") > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal("3") < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("1") >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("4") <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - boolean") { + withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= false, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(false) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - decimal") { + withOrcDataFrame((1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 
BigDecimal.valueOf(1), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > BigDecimal.valueOf(3), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= BigDecimal.valueOf(1), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(2)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(BigDecimal.valueOf(3)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(4)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - timestamp") { + val timeString = "2015-08-20 14:57:00" + val timestamps = (1 to 4).map { i => + val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 + new Timestamp(milliseconds) + } + withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(timestamps(0)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(timestamps(2)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(0)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - combinations with logical operators") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + // Because `ExpressionTree` is not accessible at Hive 1.2.x, this should be checked + // in string form in order to check filter creation including logical operators + // such as `and`, `or` or `not`. So, this function uses `SearchArgument.toString()` + // to produce string expression and then compare it to given string expression below. + // This might have to be changed after Hive version is upgraded. 
+ checkFilterPredicate( + '_1.isNotNull, + """leaf-0 = (IS_NULL _1) + |expr = (not leaf-0)""".stripMargin.trim + ) + checkFilterPredicate( + '_1 =!= 1, + """leaf-0 = (IS_NULL _1) + |leaf-1 = (EQUALS _1 1) + |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + ) + checkFilterPredicate( + !('_1 < 4), + """leaf-0 = (IS_NULL _1) + |leaf-1 = (LESS_THAN _1 4) + |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + ) + checkFilterPredicate( + '_1 < 2 || '_1 > 3, + """leaf-0 = (LESS_THAN _1 2) + |leaf-1 = (LESS_THAN_EQUALS _1 3) + |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim + ) + checkFilterPredicate( + '_1 < 2 && '_1 > 3, + """leaf-0 = (IS_NULL _1) + |leaf-1 = (LESS_THAN _1 2) + |leaf-2 = (LESS_THAN_EQUALS _1 3) + |expr = (and (not leaf-0) leaf-1 (not leaf-2))""".stripMargin.trim + ) + } + } + + test("no filter pushdown - non-supported types") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) + } + // ArrayType + withOrcDataFrame((1 to 4).map(i => Tuple1(Array(i)))) { implicit df => + checkNoFilterPredicate('_1.isNull) + } + // BinaryType + withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + checkNoFilterPredicate('_1 <=> 1.b) + } + // DateType + val stringDate = "2015-01-01" + withOrcDataFrame(Seq(Tuple1(Date.valueOf(stringDate)))) { implicit df => + checkNoFilterPredicate('_1 === Date.valueOf(stringDate)) + } + // MapType + withOrcDataFrame((1 to 4).map(i => Tuple1(Map(i -> i)))) { implicit df => + checkNoFilterPredicate('_1.isNotNull) + } + } + + test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { + import org.apache.spark.sql.sources._ + // The `LessThan` should be converted while the `StringContains` shouldn't + val schema = new StructType( + Array( + StructField("a", IntegerType, nullable = true), + StructField("b", StringType, nullable = true))) + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(schema, Array( + LessThan("a", 10), + StringContains("b", "prefix") + )).get.toString + } + + // The `LessThan` should be converted while the whole inner `And` shouldn't + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(schema, Array( + LessThan("a", 10), + Not(And( + GreaterThan("a", 1), + StringContains("b", "prefix") + )) + )).get.toString + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcPartitionDiscoverySuite.scala new file mode 100644 index 0000000000000..ab9b492f347c6 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcPartitionDiscoverySuite.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.spark.sql.execution.datasources.orc.OrcPartitionDiscoveryTest +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveOrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with TestHiveSingleton { + override val orcImp: String = "hive" +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala new file mode 100644 index 0000000000000..7244c369bd3f4 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.orc.OrcQueryTest +import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf + +class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton { + import testImplicits._ + + override val orcImp: String = "hive" + + test("SPARK-8501: Avoids discovery schema from empty ORC files") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("empty_orc") { + withTempView("empty", "single") { + spark.sql( + s"""CREATE TABLE empty_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '${dir.toURI}' + """.stripMargin) + + val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) + emptyDF.createOrReplaceTempView("empty") + + // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because + // Spark SQL ORC data source always avoids write empty ORC files. 
+ spark.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM empty + """.stripMargin) + + val errorMessage = intercept[AnalysisException] { + spark.read.orc(path) + }.getMessage + + assert(errorMessage.contains("Unable to infer schema for ORC")) + + val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) + singleRowDF.createOrReplaceTempView("single") + + spark.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = spark.read.orc(path) + assert(df.schema === singleRowDF.schema.asNullable) + checkAnswer(df, singleRowDF) + } + } + } + } + + test("Verify the ORC conversion parameter: CONVERT_METASTORE_ORC") { + withTempView("single") { + val singleRowDF = Seq((0, "foo")).toDF("key", "value") + singleRowDF.createOrReplaceTempView("single") + + Seq("true", "false").foreach { orcConversion => + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> orcConversion) { + withTable("dummy_orc") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.sql( + s""" + |CREATE TABLE dummy_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '${dir.toURI}' + """.stripMargin) + + spark.sql( + s""" + |INSERT INTO TABLE dummy_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = spark.sql("SELECT * FROM dummy_orc WHERE key=0") + checkAnswer(df, singleRowDF) + + val queryExecution = df.queryExecution + if (orcConversion == "true") { + queryExecution.analyzed.collectFirst { + case _: LogicalRelation => () + }.getOrElse { + fail(s"Expecting the query plan to convert orc to data sources, " + + s"but got:\n$queryExecution") + } + } else { + queryExecution.analyzed.collectFirst { + case _: HiveTableRelation => () + }.getOrElse { + fail(s"Expecting no conversion from orc to data sources, " + + s"but got:\n$queryExecution") + } + } + } + } + } + } + } + } + + test("converted ORC table supports resolving mixed case field") { + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { + withTable("dummy_orc") { + withTempPath { dir => + val df = spark.range(5).selectExpr("id", "id as valueField", "id as partitionValue") + df.write + .partitionBy("partitionValue") + .mode("overwrite") + .orc(dir.getAbsolutePath) + + spark.sql(s""" + |create external table dummy_orc (id long, valueField long) + |partitioned by (partitionValue int) + |stored as orc + |location "${dir.toURI}"""".stripMargin) + spark.sql(s"msck repair table dummy_orc") + checkAnswer(spark.sql("select * from dummy_orc"), df) + } + } + } + } + + test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") { + Seq( + ("native", classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat]), + ("hive", classOf[org.apache.spark.sql.hive.orc.OrcFileFormat])).foreach { + case (orcImpl, format) => + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> orcImpl) { + withTable("spark_20728") { + sql("CREATE TABLE spark_20728(a INT) USING ORC") + val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst { + case l: LogicalRelation => + l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass + } + assert(fileFormat == Some(format)) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala new file mode 100644 index 0000000000000..17b7d8cfe127e --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.orc.OrcSuite +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.util.Utils + +class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { + + override val orcImp: String = "hive" + + override def beforeAll(): Unit = { + super.beforeAll() + + sql( + s"""CREATE EXTERNAL TABLE normal_orc( + | intField INT, + | stringField STRING + |) + |STORED AS ORC + |LOCATION '${orcTableAsDir.toURI}' + """.stripMargin) + + sql( + s"""INSERT INTO TABLE normal_orc + |SELECT intField, stringField FROM orc_temp_table + """.stripMargin) + + spark.sql( + s"""CREATE TEMPORARY VIEW normal_orc_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' + |) + """.stripMargin) + + spark.sql( + s"""CREATE TEMPORARY VIEW normal_orc_as_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' + |) + """.stripMargin) + } + + test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { + val location = Utils.createTempDir() + val uri = location.toURI + try { + hiveClient.runSqlHive("USE default") + hiveClient.runSqlHive( + """ + |CREATE EXTERNAL TABLE hive_orc( + | a STRING, + | b CHAR(10), + | c VARCHAR(10), + | d ARRAY) + |STORED AS orc""".stripMargin) + // Hive throws an exception if I assign the location in the create table statement. + hiveClient.runSqlHive( + s"ALTER TABLE hive_orc SET LOCATION '$uri'") + hiveClient.runSqlHive( + """ + |INSERT INTO TABLE hive_orc + |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3))) + |FROM (SELECT 1) t""".stripMargin) + + // We create a different table in Spark using the same schema which points to + // the same location. 
+ spark.sql( + s""" + |CREATE EXTERNAL TABLE spark_orc( + | a STRING, + | b CHAR(10), + | c VARCHAR(10), + | d ARRAY) + |STORED AS orc + |LOCATION '$uri'""".stripMargin) + val result = Row("a", "b ", "c", Seq("d ")) + checkAnswer(spark.table("hive_orc"), result) + checkAnswer(spark.table("spark_orc"), result) + } finally { + hiveClient.runSqlHive("DROP TABLE IF EXISTS hive_orc") + hiveClient.runSqlHive("DROP TABLE IF EXISTS spark_orc") + Utils.deleteRecursively(location) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index ba0a7605da71c..f87162f94c01a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -30,7 +30,8 @@ import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { import testImplicits._ - override val dataSourceName: String = classOf[OrcFileFormat].getCanonicalName + override val dataSourceName: String = + classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName // ORC does not play well with NullType and UDT. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { @@ -116,3 +117,8 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { } } } + +class HiveOrcHadoopFsRelationSuite extends OrcHadoopFsRelationSuite { + override val dataSourceName: String = + classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index b2ec33e82ddaa..b22bbb79a5cc9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -98,7 +98,7 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) /** Read a buffer fully from a given Channel */ private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { - while (dest.position < dest.limit) { + while (dest.position() < dest.limit()) { if (channel.read(dest) == -1) { throw new EOFException("End of channel") }