diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 4fadfe36cd716..7fdcf22c45f73 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -879,6 +879,7 @@ public LongArray getArray() {
* Reset this map to initialized state.
*/
public void reset() {
+ updatePeakMemoryUsed();
numKeys = 0;
numValues = 0;
freeArray(longArray);
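A minimal, hedged sketch of the pattern behind this one-line fix (illustrative names only, not the real BytesToBytesMap internals): when the peak is tracked lazily, the current consumption has to be folded into the peak before reset() frees the backing memory, otherwise memory held just before a reset never contributes to the reported peak.

```scala
// Illustrative lazily-updated peak tracker; not Spark code.
class PeakTrackingBuffer {
  private var used = 0L
  private var peak = 0L

  def allocate(bytes: Long): Unit = used += bytes

  // The peak is only refreshed when explicitly requested.
  private def updatePeakUsed(): Unit = peak = math.max(peak, used)

  def reset(): Unit = {
    updatePeakUsed()   // snapshot before releasing, mirroring the fix above
    used = 0L
  }

  def peakMemoryUsed: Long = { updatePeakUsed(); peak }
}
```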
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 67e993c7f02e2..7aecd3c9668ea 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -99,7 +99,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
private def calcChecksum(block: ByteBuffer): Int = {
val adler = new Adler32()
if (block.hasArray) {
- adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position)
+ adler.update(block.array, block.arrayOffset + block.position(), block.limit()
+ - block.position())
} else {
val bytes = new Array[Byte](block.remaining())
block.duplicate.get(bytes)
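The explicit parentheses here, and in the other ByteBuffer call sites further down, follow one pattern; a small sketch of the idea, assuming this is the Scala 2.12 overload-ambiguity issue the change is working around:

```scala
// Sketch, not Spark code: ByteBuffer has both an accessor limit() and a setter
// limit(int). Under Scala 2.12 a bare `buf.limit` can be reported as an ambiguous
// reference to an overloaded definition, so the accessors are called with
// explicit parentheses.
import java.nio.ByteBuffer

val buf: ByteBuffer = ByteBuffer.allocate(64)
val remaining = buf.limit() - buf.position() // accessors invoked explicitly
buf.position(buf.position() + remaining)     // the int-argument setter is unaffected
```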
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 4c1f92a1bcbf2..9b62e4b1b7150 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -165,11 +165,7 @@ private[spark] class CoarseGrainedExecutorBackend(
}
if (notifyDriver && driver.nonEmpty) {
- driver.get.ask[Boolean](
- RemoveExecutor(executorId, new ExecutorLossReason(reason))
- ).failed.foreach(e =>
- logWarning(s"Unable to notify the driver due to " + e.getMessage, e)
- )(ThreadUtils.sameThread)
+ driver.get.send(RemoveExecutor(executorId, new ExecutorLossReason(reason)))
}
System.exit(code)
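A rough sketch of the RPC semantics behind switching from `ask` to `send` (simplified; `RpcEndpointRef`, `RemoveExecutor`, and `ExecutorLossReason` are the Spark-internal types used above): `ask` returns a `Future` reply handled by `receiveAndReply`, while `send` is one-way and handled by `receive`, which is sufficient here because no acknowledgement is needed while the executor is shutting down.

```scala
// Simplified illustration of the two calling styles; not a complete executor backend.
import scala.concurrent.ExecutionContext

def notifyDriverOneWay(
    driver: RpcEndpointRef,
    executorId: String,
    reason: ExecutorLossReason): Unit = {
  // Fire-and-forget: handled by the driver endpoint's receive, no reply expected.
  driver.send(RemoveExecutor(executorId, reason))
}

def notifyDriverWithReply(
    driver: RpcEndpointRef,
    executorId: String,
    reason: ExecutorLossReason)(implicit ec: ExecutionContext): Unit = {
  // Request/reply: handled by receiveAndReply; the returned Future must be observed.
  driver.ask[Boolean](RemoveExecutor(executorId, reason))
    .failed.foreach(e => println(s"driver did not acknowledge removal: ${e.getMessage}"))
}
```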
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e3e555eaa0277..af0a0ab656564 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -452,7 +452,7 @@ private[spark] class Executor(
// TODO: do not serialize value twice
val directResult = new DirectTaskResult(valueBytes, accumUpdates)
val serializedDirectResult = ser.serialize(directResult)
- val resultSize = serializedDirectResult.limit
+ val resultSize = serializedDirectResult.limit()
// directSend = sending directly back to the driver
val serializedResult: ByteBuffer = {
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
index 8f4c1b60920db..b0cd7110a3b47 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
@@ -235,7 +235,9 @@ private[spark] case class ConfigBuilder(key: String) {
}
def fallbackConf[T](fallback: ConfigEntry[T]): ConfigEntry[T] = {
- new FallbackConfigEntry(key, _alternatives, _doc, _public, fallback)
+ val entry = new FallbackConfigEntry(key, _alternatives, _doc, _public, fallback)
+ _onCreate.foreach(_(entry))
+ entry
}
def regexConf: TypedConfigBuilder[Regex] = {
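A hedged sketch of why the `onCreate` hook matters (the `register`/`buildConf` names below are illustrative assumptions, not necessarily what Spark uses): a config registry can attach the callback to every builder so that each created entry is recorded, and with the fix above entries produced through `fallbackConf` now reach that callback as well. The new `onCreate` test later in this diff exercises exactly this path.

```scala
// Illustrative registry built on ConfigBuilder.onCreate; names are assumptions.
import java.util.concurrent.ConcurrentHashMap

val registeredEntries = new ConcurrentHashMap[String, ConfigEntry[_]]()

def register(entry: ConfigEntry[_]): Unit =
  registeredEntries.put(entry.key, entry)

def buildConf(key: String): ConfigBuilder =
  ConfigBuilder(key).onCreate(register)

val base = buildConf("example.base").intConf.createWithDefault(1)
// Before this fix the fallback entry skipped the callback; now it is registered too.
val derived = buildConf("example.derived").fallbackConf(base)
```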
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 7bfb4d53c1834..4d75063fbf1c5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -95,6 +95,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// The num of current max ExecutorId used to re-register appMaster
@volatile protected var currentExecutorIdCounter = 0
+ private val reviveThread =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread")
+
class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
extends ThreadSafeRpcEndpoint with Logging {
@@ -103,9 +106,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
protected val addressToExecutorId = new HashMap[RpcAddress, String]
- private val reviveThread =
- ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread")
-
override def onStart() {
// Periodically revive offers to allow delay scheduling to work
val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s")
@@ -154,6 +154,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
executorDataMap.values.foreach { ed =>
ed.executorEndpoint.send(UpdateDelegationTokens(newDelegationTokens))
}
+
+ case RemoveExecutor(executorId, reason) =>
+ // We will remove the executor's state and cannot restore it. However, the connection
+      // between the driver and the executor may still be alive, so the executor won't exit
+      // automatically. Try to tell the executor to stop itself. See SPARK-13519.
+ executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor))
+ removeExecutor(executorId, reason)
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
@@ -215,14 +222,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
context.reply(true)
- case RemoveExecutor(executorId, reason) =>
- // We will remove the executor's state and cannot restore it. However, the connection
- // between the driver and the executor may be still alive so that the executor won't exit
- // automatically, so try to tell the executor to stop itself. See SPARK-13519.
- executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor))
- removeExecutor(executorId, reason)
- context.reply(true)
-
case RemoveWorker(workerId, host, message) =>
removeWorker(workerId, host, message)
context.reply(true)
@@ -288,13 +287,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
for (task <- tasks.flatten) {
val serializedTask = TaskDescription.encode(task)
- if (serializedTask.limit >= maxRpcMessageSize) {
+ if (serializedTask.limit() >= maxRpcMessageSize) {
scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.rpc.message.maxSize (%d bytes). Consider increasing " +
"spark.rpc.message.maxSize or using broadcast variables for large values."
- msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize)
+ msg = msg.format(task.taskId, task.index, serializedTask.limit(), maxRpcMessageSize)
taskSetMgr.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
@@ -373,10 +372,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
shouldDisable
}
-
- override def onStop() {
- reviveThread.shutdownNow()
- }
}
var driverEndpoint: RpcEndpointRef = null
@@ -417,6 +412,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
override def stop() {
+ reviveThread.shutdownNow()
stopExecutors()
try {
if (driverEndpoint != null) {
@@ -465,9 +461,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
* at once.
*/
protected def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = {
- // Only log the failure since we don't care about the result.
- driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).failed.foreach(t =>
- logError(t.getMessage, t))(ThreadUtils.sameThread)
+ driverEndpoint.send(RemoveExecutor(executorId, reason))
}
protected def removeWorker(workerId: String, host: String, message: String): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 56d0266b8edad..89a6a71a589a1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -17,6 +17,7 @@
package org.apache.spark.storage
+import java.io.IOException
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConverters._
@@ -159,11 +160,16 @@ class BlockManagerMasterEndpoint(
// Ask the slaves to remove the RDD, and put the result in a sequence of Futures.
// The dispatcher is used as an implicit argument into the Future sequence construction.
val removeMsg = RemoveRdd(rddId)
- Future.sequence(
- blockManagerInfo.values.map { bm =>
- bm.slaveEndpoint.ask[Int](removeMsg)
- }.toSeq
- )
+
+ val futures = blockManagerInfo.values.map { bm =>
+ bm.slaveEndpoint.ask[Int](removeMsg).recover {
+ case e: IOException =>
+ logWarning(s"Error trying to remove RDD $rddId", e)
+ 0 // zero blocks were removed
+ }
+ }.toSeq
+
+ Future.sequence(futures)
}
private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
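A self-contained sketch of the `Future` behaviour this change leans on: `Future.sequence` fails as soon as any member future fails, so wrapping each slave `ask` in `recover` lets a single lost slave count as zero removed blocks instead of failing the whole RDD removal.

```scala
import java.io.IOException

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

val replies: Seq[Future[Int]] = Seq(
  Future.successful(3),                                    // a healthy slave removed 3 blocks
  Future.failed(new IOException("connection to slave lost"))
)

// Without recover, the failed ask poisons the whole sequence:
// Await.result(Future.sequence(replies), 1.second)        // throws IOException

// With recover, the failure becomes "zero blocks removed" and the sequence completes:
val tolerant = replies.map(_.recover { case _: IOException => 0 })
Await.result(Future.sequence(tolerant), 1.second)          // Seq(3, 0)
```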
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
index a938cb07724c7..a5ee0ff16b5df 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
@@ -54,7 +54,7 @@ class ByteBufferInputStream(private var buffer: ByteBuffer)
override def skip(bytes: Long): Long = {
if (buffer != null) {
val amountToSkip = math.min(bytes, buffer.remaining).toInt
- buffer.position(buffer.position + amountToSkip)
+ buffer.position(buffer.position() + amountToSkip)
if (buffer.remaining() == 0) {
cleanUp()
}
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index c28570fb24560..7367af7888bd8 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -65,7 +65,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
for (bytes <- getChunks()) {
while (bytes.remaining() > 0) {
val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize)
- bytes.limit(bytes.position + ioSize)
+ bytes.limit(bytes.position() + ioSize)
channel.write(bytes)
}
}
@@ -206,7 +206,7 @@ private[spark] class ChunkedByteBufferInputStream(
override def skip(bytes: Long): Long = {
if (currentChunk != null) {
val amountToSkip = math.min(bytes, currentChunk.remaining).toInt
- currentChunk.position(currentChunk.position + amountToSkip)
+ currentChunk.position(currentChunk.position() + amountToSkip)
if (currentChunk.remaining() == 0) {
if (chunks.hasNext) {
currentChunk = chunks.next()
diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
index bf08276dbf971..02514dc7daef4 100644
--- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
+++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
@@ -288,4 +288,24 @@ class ConfigEntrySuite extends SparkFunSuite {
conf.remove(testKey("b"))
assert(conf.get(iConf) === 3)
}
+
+ test("onCreate") {
+ var onCreateCalled = false
+ ConfigBuilder(testKey("oc1")).onCreate(_ => onCreateCalled = true).intConf.createWithDefault(1)
+ assert(onCreateCalled)
+
+ onCreateCalled = false
+ ConfigBuilder(testKey("oc2")).onCreate(_ => onCreateCalled = true).intConf.createOptional
+ assert(onCreateCalled)
+
+ onCreateCalled = false
+ ConfigBuilder(testKey("oc3")).onCreate(_ => onCreateCalled = true).intConf
+ .createWithDefaultString("1.0")
+ assert(onCreateCalled)
+
+ val fallback = ConfigBuilder(testKey("oc4")).intConf.createWithDefault(1)
+ onCreateCalled = false
+ ConfigBuilder(testKey("oc5")).onCreate(_ => onCreateCalled = true).fallbackConf(fallback)
+ assert(onCreateCalled)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index eaec098b8d785..fc78655bf52ec 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -199,7 +199,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
// Check that very long ranges don't get written one element at a time
- assert(ser.serialize(t).limit < 100)
+ assert(ser.serialize(t).limit() < 100)
}
check(1 to 1000000)
check(1 to 1000000 by 2)
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
index 7258fdf5efc0d..efdd02fff7871 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
@@ -118,7 +118,7 @@ class DiskStoreSuite extends SparkFunSuite {
val chunks = chunkedByteBuffer.chunks
assert(chunks.size === 2)
for (chunk <- chunks) {
- assert(chunk.limit === 10 * 1024)
+ assert(chunk.limit() === 10 * 1024)
}
val e = intercept[IllegalArgumentException]{
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 2c68b73095c4d..1831f3378e852 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -35,7 +35,7 @@ commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
commons-codec-1.10.jar
commons-collections-3.2.2.jar
-commons-compiler-3.0.7.jar
+commons-compiler-3.0.8.jar
commons-compress-1.4.1.jar
commons-configuration-1.6.jar
commons-crypto-1.0.0.jar
@@ -96,7 +96,7 @@ jackson-mapper-asl-1.9.13.jar
jackson-module-paranamer-2.7.9.jar
jackson-module-scala_2.11-2.6.7.1.jar
jackson-xc-1.9.13.jar
-janino-3.0.7.jar
+janino-3.0.8.jar
java-xmlbuilder-1.1.jar
javassist-3.18.1-GA.jar
javax.annotation-api-1.2.jar
@@ -180,7 +180,7 @@ stax-api-1.0.1.jar
stream-2.7.0.jar
stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-2.5.4.jar
+univocity-parsers-2.5.9.jar
validation-api-1.1.0.Final.jar
xbean-asm5-shaded-4.4.jar
xercesImpl-2.9.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 2aaac600b3ec3..fe14c05987327 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -35,7 +35,7 @@ commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
commons-codec-1.10.jar
commons-collections-3.2.2.jar
-commons-compiler-3.0.7.jar
+commons-compiler-3.0.8.jar
commons-compress-1.4.1.jar
commons-configuration-1.6.jar
commons-crypto-1.0.0.jar
@@ -96,7 +96,7 @@ jackson-mapper-asl-1.9.13.jar
jackson-module-paranamer-2.7.9.jar
jackson-module-scala_2.11-2.6.7.1.jar
jackson-xc-1.9.13.jar
-janino-3.0.7.jar
+janino-3.0.8.jar
java-xmlbuilder-1.1.jar
javassist-3.18.1-GA.jar
javax.annotation-api-1.2.jar
@@ -181,7 +181,7 @@ stax-api-1.0.1.jar
stream-2.7.0.jar
stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-2.5.4.jar
+univocity-parsers-2.5.9.jar
validation-api-1.1.0.Final.jar
xbean-asm5-shaded-4.4.jar
xercesImpl-2.9.1.jar
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
index 2df8352f48660..08db4d827e400 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
@@ -296,7 +296,9 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L
props.put("replica.socket.timeout.ms", "1500")
props.put("delete.topic.enable", "true")
props.put("offsets.topic.num.partitions", "1")
- props.putAll(withBrokerProps.asJava)
+    // Cannot use props.putAll(withBrokerProps.asJava) in Scala 2.12
+ // See https://github.com/scala/bug/issues/10418
+ withBrokerProps.foreach { case (k, v) => props.put(k, v) }
props
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 4663f16b5f5dc..730ee9fc08db8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.HasInputCols
+import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
/**
* Params for [[Imputer]] and [[ImputerModel]].
*/
-private[feature] trait ImputerParams extends Params with HasInputCols {
+private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCols {
/**
* The imputation strategy. Currently only "mean" and "median" are supported.
@@ -63,16 +63,6 @@ private[feature] trait ImputerParams extends Params with HasInputCols {
/** @group getParam */
def getMissingValue: Double = $(missingValue)
- /**
- * Param for output column names.
- * @group param
- */
- final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
- "output column names")
-
- /** @group getParam */
- final def getOutputCols: Array[String] = $(outputCols)
-
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" +
diff --git a/pom.xml b/pom.xml
index 07bca9d267da0..52db79eaf036b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -170,7 +170,7 @@
3.5
3.2.10
- 3.0.7
+ 3.0.8
2.22.2
2.9.3
3.5.2
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
index 3febb2f47cfd4..13c09033a50ee 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
@@ -282,7 +282,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn
// No more deletion attempts of the executors.
// This is graceful termination and should not be detected as a failure.
verify(podOperations, times(1)).delete(resolvedPod)
- verify(driverEndpointRef, times(1)).ask[Boolean](
+ verify(driverEndpointRef, times(1)).send(
RemoveExecutor("1", ExecutorExited(
0,
exitCausedByApp = false,
@@ -318,7 +318,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn
requestExecutorRunnable.getValue.run()
allocatorRunnable.getAllValues.asScala.last.run()
verify(podOperations, never()).delete(firstResolvedPod)
- verify(driverEndpointRef).ask[Boolean](
+ verify(driverEndpointRef).send(
RemoveExecutor("1", ExecutorExited(
1,
exitCausedByApp = true,
@@ -356,7 +356,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn
val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
allocatorRunnable.getValue.run()
verify(podOperations).delete(firstResolvedPod)
- verify(driverEndpointRef).ask[Boolean](
+ verify(driverEndpointRef).send(
RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons.")))
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 415a29fd887e8..bb615c36cd97f 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster
import java.util.concurrent.atomic.{AtomicBoolean}
import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}
import scala.util.control.NonFatal
@@ -245,14 +246,7 @@ private[spark] abstract class YarnSchedulerBackend(
Future.successful(RemoveExecutor(executorId, SlaveLost("AM is not yet registered.")))
}
- removeExecutorMessage
- .flatMap { message =>
- driverEndpoint.ask[Boolean](message)
- }(ThreadUtils.sameThread)
- .onFailure {
- case NonFatal(e) => logError(
- s"Error requesting driver to remove executor $executorId after disconnection.", e)
- }(ThreadUtils.sameThread)
+ removeExecutorMessage.foreach { message => driverEndpoint.send(message) }
}
override def receive: PartialFunction[Any, Unit] = {
@@ -265,12 +259,10 @@ private[spark] abstract class YarnSchedulerBackend(
addWebUIFilter(filterName, filterParams, proxyBase)
case r @ RemoveExecutor(executorId, reason) =>
- logWarning(reason.toString)
- driverEndpoint.ask[Boolean](r).onFailure {
- case e =>
- logError("Error requesting driver to remove executor" +
- s" $executorId for reason $reason", e)
- }(ThreadUtils.sameThread)
+ if (!stopped.get) {
+ logWarning(s"Requesting driver to remove executor $executorId for reason $reason")
+ driverEndpoint.send(r)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 179853032035e..4d26d9819321b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -982,35 +982,28 @@ case class ScalaUDF(
// scalastyle:on line.size.limit
- // Generate codes used to convert the arguments to Scala type for user-defined functions
- private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = {
- val converterClassName = classOf[Any => Any].getName
- val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
- val expressionClassName = classOf[Expression].getName
- val scalaUDFClassName = classOf[ScalaUDF].getName
+ private val converterClassName = classOf[Any => Any].getName
+ private val scalaUDFClassName = classOf[ScalaUDF].getName
+ private val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
+  // Generate code used to convert the arguments to Scala types for the user-defined function
+ private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): (String, String) = {
val converterTerm = ctx.freshName("converter")
val expressionIdx = ctx.references.size - 1
- ctx.addMutableState(converterClassName, converterTerm,
- s"$converterTerm = ($converterClassName)$typeConvertersClassName" +
- s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" +
- s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
- converterTerm
+ (converterTerm,
+ s"$converterClassName $converterTerm = ($converterClassName)$typeConvertersClassName" +
+ s".createToScalaConverter(((Expression)((($scalaUDFClassName)" +
+ s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
}
override def doGenCode(
ctx: CodegenContext,
ev: ExprCode): ExprCode = {
+ val scalaUDF = ctx.freshName("scalaUDF")
+ val scalaUDFRef = ctx.addReferenceMinorObj(this, scalaUDFClassName)
- val scalaUDF = ctx.addReferenceObj("scalaUDF", this)
- val converterClassName = classOf[Any => Any].getName
- val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
-
- // Generate codes used to convert the returned value of user-defined functions to Catalyst type
+ // Object to convert the returned value of user-defined functions to Catalyst type
val catalystConverterTerm = ctx.freshName("catalystConverter")
- ctx.addMutableState(converterClassName, catalystConverterTerm,
- s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
- s".createToCatalystConverter($scalaUDF.dataType());")
val resultTerm = ctx.freshName("result")
@@ -1022,8 +1015,6 @@ case class ScalaUDF(
val funcClassName = s"scala.Function${children.size}"
val funcTerm = ctx.freshName("udf")
- ctx.addMutableState(funcClassName, funcTerm,
- s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")
// codegen for children expressions
val evals = children.map(_.genCode(ctx))
@@ -1033,34 +1024,45 @@ case class ScalaUDF(
// such as IntegerType, its javaType is `int` and the returned type of user-defined
// function is Object. Trying to convert an Object to `int` will cause casting exception.
val evalCode = evals.map(_.code).mkString
- val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) =>
- val eval = evals(i)
- val argTerm = ctx.freshName("arg")
- val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});"
- (convert, argTerm)
+ val (converters, funcArguments) = converterTerms.zipWithIndex.map {
+ case ((convName, convInit), i) =>
+ val eval = evals(i)
+ val argTerm = ctx.freshName("arg")
+ val convert =
+ s"""
+ |$convInit
+ |Object $argTerm = ${eval.isNull} ? null : $convName.apply(${eval.value});
+ """.stripMargin
+ (convert, argTerm)
}.unzip
val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})"
val callFunc =
s"""
- ${ctx.boxedType(dataType)} $resultTerm = null;
- try {
- $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
- } catch (Exception e) {
- throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
- }
- """
+ |${ctx.boxedType(dataType)} $resultTerm = null;
+ |$scalaUDFClassName $scalaUDF = $scalaUDFRef;
+ |try {
+ | $funcClassName $funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();
+ | $converterClassName $catalystConverterTerm = ($converterClassName)
+ | $typeConvertersClassName.createToCatalystConverter($scalaUDF.dataType());
+ | $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
+ |} catch (Exception e) {
+ | throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
+ |}
+ """.stripMargin
- ev.copy(code = s"""
- $evalCode
- ${converters.mkString("\n")}
- $callFunc
-
- boolean ${ev.isNull} = $resultTerm == null;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${ev.value} = $resultTerm;
- }""")
+ ev.copy(code =
+ s"""
+ |$evalCode
+ |${converters.mkString("\n")}
+ |$callFunc
+ |
+ |boolean ${ev.isNull} = $resultTerm == null;
+ |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |if (!${ev.isNull}) {
+ | ${ev.value} = $resultTerm;
+ |}
+ """.stripMargin)
}
private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 739bd13c5078d..1893eec22b65d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -602,23 +602,38 @@ case class Least(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
- ctx.addMutableState(ctx.javaType(dataType), ev.value)
- def updateEval(eval: ExprCode): String = {
+ val tmpIsNull = ctx.freshName("leastTmpIsNull")
+ ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)
+ val evals = evalChildren.map(eval =>
s"""
- ${eval.code}
- if (!${eval.isNull} && (${ev.isNull} ||
- ${ctx.genGreater(dataType, ev.value, eval.value)})) {
- ${ev.isNull} = false;
- ${ev.value} = ${eval.value};
- }
- """
- }
- val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval))
- ev.copy(code = s"""
- ${ev.isNull} = true;
- ${ev.value} = ${ctx.defaultValue(dataType)};
- $codes""")
+ |${eval.code}
+ |if (!${eval.isNull} && ($tmpIsNull ||
+ | ${ctx.genGreater(dataType, ev.value, eval.value)})) {
+ | $tmpIsNull = false;
+ | ${ev.value} = ${eval.value};
+ |}
+ """.stripMargin
+ )
+
+ val resultType = ctx.javaType(dataType)
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = evals,
+ funcName = "least",
+ extraArguments = Seq(resultType -> ev.value),
+ returnType = resultType,
+ makeSplitFunction = body =>
+ s"""
+ |$body
+ |return ${ev.value};
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
+ ev.copy(code =
+ s"""
+ |$tmpIsNull = true;
+ |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |$codes
+ |final boolean ${ev.isNull} = $tmpIsNull;
+ """.stripMargin)
}
}
@@ -668,22 +683,37 @@ case class Greatest(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
- ctx.addMutableState(ctx.javaType(dataType), ev.value)
- def updateEval(eval: ExprCode): String = {
+ val tmpIsNull = ctx.freshName("greatestTmpIsNull")
+ ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)
+ val evals = evalChildren.map(eval =>
s"""
- ${eval.code}
- if (!${eval.isNull} && (${ev.isNull} ||
- ${ctx.genGreater(dataType, eval.value, ev.value)})) {
- ${ev.isNull} = false;
- ${ev.value} = ${eval.value};
- }
- """
- }
- val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval))
- ev.copy(code = s"""
- ${ev.isNull} = true;
- ${ev.value} = ${ctx.defaultValue(dataType)};
- $codes""")
+ |${eval.code}
+ |if (!${eval.isNull} && ($tmpIsNull ||
+ | ${ctx.genGreater(dataType, eval.value, ev.value)})) {
+ | $tmpIsNull = false;
+ | ${ev.value} = ${eval.value};
+ |}
+ """.stripMargin
+ )
+
+ val resultType = ctx.javaType(dataType)
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = evals,
+ funcName = "greatest",
+ extraArguments = Seq(resultType -> ev.value),
+ returnType = resultType,
+ makeSplitFunction = body =>
+ s"""
+ |$body
+ |return ${ev.value};
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
+ ev.copy(code =
+ s"""
+ |$tmpIsNull = true;
+ |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |$codes
+ |final boolean ${ev.isNull} = $tmpIsNull;
+ """.stripMargin)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 670c82eff9286..5c9e604a8d293 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -29,7 +29,7 @@ import scala.util.control.NonFatal
import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
import org.codehaus.commons.compiler.CompileException
-import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler}
+import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, InternalCompilerException, SimpleCompiler}
import org.codehaus.janino.util.ClassFile
import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
@@ -1240,12 +1240,12 @@ object CodeGenerator extends Logging {
evaluator.cook("generated.java", code.body)
updateAndGetCompilationStats(evaluator)
} catch {
- case e: JaninoRuntimeException =>
+ case e: InternalCompilerException =>
val msg = s"failed to compile: $e"
logError(msg, e)
val maxLines = SQLConf.get.loggingMaxLinesForCodegen
logInfo(s"\n${CodeFormatter.format(code, maxLines)}")
- throw new JaninoRuntimeException(msg, e)
+ throw new InternalCompilerException(msg, e)
case e: CompileException =>
val msg = s"failed to compile: $e"
logError(msg, e)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 44e7148e5d98f..3dcbb518ba42a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -49,8 +49,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val tmpInput = ctx.freshName("tmpInput")
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
- // These expressions could be split into multiple functions
- ctx.addMutableState("Object[]", values, s"$values = null;")
val rowClass = classOf[GenericInternalRow].getName
@@ -66,15 +64,15 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val allFields = ctx.splitExpressions(
expressions = fieldWriters,
funcName = "writeFields",
- arguments = Seq("InternalRow" -> tmpInput)
+ arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values)
)
- val code = s"""
- final InternalRow $tmpInput = $input;
- $values = new Object[${schema.length}];
- $allFields
- final InternalRow $output = new $rowClass($values);
- $values = null;
- """
+ val code =
+ s"""
+ |final InternalRow $tmpInput = $input;
+ |final Object[] $values = new Object[${schema.length}];
+ |$allFields
+ |final InternalRow $output = new $rowClass($values);
+ """.stripMargin
ExprCode(code, "false", output)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 087b21043b309..3dc2ee03a86e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -356,22 +356,25 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericInternalRow].getName
val values = ctx.freshName("values")
- ctx.addMutableState("Object[]", values, s"$values = null;")
+ val valCodes = valExprs.zipWithIndex.map { case (e, i) =>
+ val eval = e.genCode(ctx)
+ s"""
+ |${eval.code}
+ |if (${eval.isNull}) {
+ | $values[$i] = null;
+ |} else {
+ | $values[$i] = ${eval.value};
+ |}
+ """.stripMargin
+ }
val valuesCode = ctx.splitExpressionsWithCurrentInputs(
- valExprs.zipWithIndex.map { case (e, i) =>
- val eval = e.genCode(ctx)
- s"""
- ${eval.code}
- if (${eval.isNull}) {
- $values[$i] = null;
- } else {
- $values[$i] = ${eval.value};
- }"""
- })
+ expressions = valCodes,
+ funcName = "createNamedStruct",
+ extraArguments = "Object[]" -> values :: Nil)
ev.copy(code =
s"""
- |$values = new Object[${valExprs.size}];
+ |Object[] $values = new Object[${valExprs.size}];
|$valuesCode
|final InternalRow ${ev.value} = new $rowClass($values);
|$values = null;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index ae5f7140847db..53c3b226895ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -180,13 +180,18 @@ case class CaseWhen(
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- // This variable represents whether the first successful condition is met or not.
- // It is initialized to `false` and it is set to `true` when the first condition which
- // evaluates to `true` is met and therefore is not needed to go on anymore on the computation
- // of the following conditions.
- val conditionMet = ctx.freshName("caseWhenConditionMet")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
- ctx.addMutableState(ctx.javaType(dataType), ev.value)
+    // These constants encode the state of the result:
+    // -1 means no condition has been met yet and the result is unknown.
+    val NOT_MATCHED = -1
+    // 0 means a condition has been met and the result is not null.
+    val HAS_NONNULL = 0
+    // 1 means a condition has been met and the result is null.
+    val HAS_NULL = 1
+    // The state starts as `NOT_MATCHED`; once it is set to `HAS_NULL` or `HAS_NONNULL`,
+    // no further conditions are evaluated.
+ val resultState = ctx.freshName("caseWhenResultState")
+ val tmpResult = ctx.freshName("caseWhenTmpResult")
+ ctx.addMutableState(ctx.javaType(dataType), tmpResult)
// these blocks are meant to be inside a
// do {
@@ -200,9 +205,8 @@ case class CaseWhen(
|${cond.code}
|if (!${cond.isNull} && ${cond.value}) {
| ${res.code}
- | ${ev.isNull} = ${res.isNull};
- | ${ev.value} = ${res.value};
- | $conditionMet = true;
+ | $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
+ | $tmpResult = ${res.value};
| continue;
|}
""".stripMargin
@@ -212,59 +216,63 @@ case class CaseWhen(
val res = elseExpr.genCode(ctx)
s"""
|${res.code}
- |${ev.isNull} = ${res.isNull};
- |${ev.value} = ${res.value};
+ |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
+ |$tmpResult = ${res.value};
""".stripMargin
}
val allConditions = cases ++ elseCode
// This generates code like:
- // conditionMet = caseWhen_1(i);
- // if(conditionMet) {
+ // caseWhenResultState = caseWhen_1(i);
+ // if(caseWhenResultState != -1) {
// continue;
// }
- // conditionMet = caseWhen_2(i);
- // if(conditionMet) {
+ // caseWhenResultState = caseWhen_2(i);
+ // if(caseWhenResultState != -1) {
// continue;
// }
// ...
// and the declared methods are:
- // private boolean caseWhen_1234() {
- // boolean conditionMet = false;
+ // private byte caseWhen_1234() {
+ // byte caseWhenResultState = -1;
// do {
// // here the evaluation of the conditions
// } while (false);
- // return conditionMet;
+ // return caseWhenResultState;
// }
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = allConditions,
funcName = "caseWhen",
- returnType = ctx.JAVA_BOOLEAN,
+ returnType = ctx.JAVA_BYTE,
makeSplitFunction = func =>
s"""
- |${ctx.JAVA_BOOLEAN} $conditionMet = false;
+ |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $func
|} while (false);
- |return $conditionMet;
+ |return $resultState;
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
- |$conditionMet = $funcCall;
- |if ($conditionMet) {
+ |$resultState = $funcCall;
+ |if ($resultState != $NOT_MATCHED) {
| continue;
|}
""".stripMargin
}.mkString)
- ev.copy(code = s"""
- ${ev.isNull} = true;
- ${ev.value} = ${ctx.defaultValue(dataType)};
- ${ctx.JAVA_BOOLEAN} $conditionMet = false;
- do {
- $codes
- } while (false);""")
+ ev.copy(code =
+ s"""
+ |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
+ |$tmpResult = ${ctx.defaultValue(dataType)};
+ |do {
+ | $codes
+ |} while (false);
+         |// TRUE if a condition was met and its result is null, or if no condition was met.
+ |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL);
+ |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult;
+ """.stripMargin)
}
}
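A small sketch of the tri-state byte used above, written as ordinary Scala rather than generated code: folding "matched with a non-null result / matched with a null result / not matched" into one local byte is what lets `isNull` and `value` be derived at the end instead of living in class-level mutable state.

```scala
// Plain-Scala analogue of the generated caseWhenResultState handling.
val NOT_MATCHED: Byte = -1
val HAS_NONNULL: Byte = 0
val HAS_NULL: Byte = 1

def resultStateOf(matched: Boolean, resultIsNull: Boolean): Byte =
  if (!matched) NOT_MATCHED
  else if (resultIsNull) HAS_NULL
  else HAS_NONNULL

// The final isNull mirrors the generated `($resultState != $HAS_NONNULL)` check:
// null when a matched branch produced null, or when nothing matched at all.
def isNullFrom(state: Byte): Boolean = state != HAS_NONNULL
```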
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 26c9a41efc9f9..294cdcb2e9546 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -72,8 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
- ctx.addMutableState(ctx.javaType(dataType), ev.value)
+ val tmpIsNull = ctx.freshName("coalesceTmpIsNull")
+ ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)
// all the evals are meant to be in a do { ... } while (false); loop
val evals = children.map { e =>
@@ -81,26 +81,30 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
s"""
|${eval.code}
|if (!${eval.isNull}) {
- | ${ev.isNull} = false;
+ | $tmpIsNull = false;
| ${ev.value} = ${eval.value};
| continue;
|}
""".stripMargin
}
+ val resultType = ctx.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "coalesce",
+ returnType = resultType,
makeSplitFunction = func =>
s"""
+ |$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|do {
| $func
|} while (false);
+ |return ${ev.value};
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
- |$funcCall;
- |if (!${ev.isNull}) {
+ |${ev.value} = $funcCall;
+ |if (!$tmpIsNull) {
| continue;
|}
""".stripMargin
@@ -109,11 +113,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
ev.copy(code =
s"""
- |${ev.isNull} = true;
- |${ev.value} = ${ctx.defaultValue(dataType)};
+ |$tmpIsNull = true;
+ |$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|do {
| $codes
|} while (false);
+ |final boolean ${ev.isNull} = $tmpIsNull;
""".stripMargin)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 730b2ff96da6c..349afece84d5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1106,27 +1106,31 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericRowWithSchema].getName
val values = ctx.freshName("values")
- ctx.addMutableState("Object[]", values)
val childrenCodes = children.zipWithIndex.map { case (e, i) =>
val eval = e.genCode(ctx)
- eval.code + s"""
- if (${eval.isNull}) {
- $values[$i] = null;
- } else {
- $values[$i] = ${eval.value};
- }
- """
+ s"""
+ |${eval.code}
+ |if (${eval.isNull}) {
+ | $values[$i] = null;
+ |} else {
+ | $values[$i] = ${eval.value};
+ |}
+ """.stripMargin
}
- val childrenCode = ctx.splitExpressionsWithCurrentInputs(childrenCodes)
- val schemaField = ctx.addReferenceObj("schema", schema)
+ val childrenCode = ctx.splitExpressionsWithCurrentInputs(
+ expressions = childrenCodes,
+ funcName = "createExternalRow",
+ extraArguments = "Object[]" -> values :: Nil)
+ val schemaField = ctx.addReferenceMinorObj(schema)
- val code = s"""
- $values = new Object[${children.size}];
- $childrenCode
- final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
- """
+ val code =
+ s"""
+ |Object[] $values = new Object[${children.size}];
+ |$childrenCode
+ |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
+ """.stripMargin
ev.copy(code = code, isNull = "false")
}
}
@@ -1244,25 +1248,28 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
val javaBeanInstance = ctx.freshName("javaBean")
val beanInstanceJavaType = ctx.javaType(beanInstance.dataType)
- ctx.addMutableState(beanInstanceJavaType, javaBeanInstance)
val initialize = setters.map {
case (setterMethod, fieldValue) =>
val fieldGen = fieldValue.genCode(ctx)
s"""
- ${fieldGen.code}
- ${javaBeanInstance}.$setterMethod(${fieldGen.value});
- """
+ |${fieldGen.code}
+ |$javaBeanInstance.$setterMethod(${fieldGen.value});
+ """.stripMargin
}
- val initializeCode = ctx.splitExpressionsWithCurrentInputs(initialize.toSeq)
+ val initializeCode = ctx.splitExpressionsWithCurrentInputs(
+ expressions = initialize.toSeq,
+ funcName = "initializeJavaBean",
+ extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil)
- val code = s"""
- ${instanceGen.code}
- ${javaBeanInstance} = ${instanceGen.value};
- if (!${instanceGen.isNull}) {
- $initializeCode
- }
- """
+ val code =
+ s"""
+ |${instanceGen.code}
+ |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value};
+ |if (!${instanceGen.isNull}) {
+ | $initializeCode
+ |}
+ """.stripMargin
ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 04e669492ec6d..8eb41addaf689 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -237,8 +237,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
val javaDataType = ctx.javaType(value.dataType)
val valueGen = value.genCode(ctx)
val listGen = list.map(_.genCode(ctx))
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value)
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+    // inTmpResult has 3 possible values:
+    // -1 means no match was found and at least one value in the list evaluated to null
+    val HAS_NULL = -1
+    // 0 means no match was found and none of the values in the list evaluated to null
+    val NOT_MATCHED = 0
+    // 1 means a value in the list matched
+    val MATCHED = 1
+ val tmpResult = ctx.freshName("inTmpResult")
val valueArg = ctx.freshName("valueArg")
// All the blocks are meant to be inside a do { ... } while (false); loop.
// The evaluation of variables can be stopped when we find a matching value.
@@ -246,10 +252,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
s"""
|${x.code}
|if (${x.isNull}) {
- | ${ev.isNull} = true;
+ | $tmpResult = $HAS_NULL; // ${ev.isNull} = true;
|} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
- | ${ev.isNull} = false;
- | ${ev.value} = true;
+ | $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true;
| continue;
|}
""".stripMargin)
@@ -257,17 +262,19 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = listCode,
funcName = "valueIn",
- extraArguments = (javaDataType, valueArg) :: Nil,
+ extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil,
+ returnType = ctx.JAVA_BYTE,
makeSplitFunction = body =>
s"""
|do {
| $body
|} while (false);
+ |return $tmpResult;
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
- |$funcCall;
- |if (${ev.value}) {
+ |$tmpResult = $funcCall;
+ |if ($tmpResult == $MATCHED) {
| continue;
|}
""".stripMargin
@@ -276,14 +283,16 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
ev.copy(code =
s"""
|${valueGen.code}
- |${ev.value} = false;
- |${ev.isNull} = ${valueGen.isNull};
- |if (!${ev.isNull}) {
+ |byte $tmpResult = $HAS_NULL;
+ |if (!${valueGen.isNull}) {
+ | $tmpResult = 0;
| $javaDataType $valueArg = ${valueGen.value};
| do {
| $codes
| } while (false);
|}
+ |final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL);
+ |final boolean ${ev.value} = ($tmpResult == $MATCHED);
""".stripMargin)
}
@@ -344,17 +353,17 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
} else {
""
}
- ctx.addMutableState(setName, setTerm,
- s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();")
- ev.copy(code = s"""
- ${childGen.code}
- boolean ${ev.isNull} = ${childGen.isNull};
- boolean ${ev.value} = false;
- if (!${ev.isNull}) {
- ${ev.value} = $setTerm.contains(${childGen.value});
- $setNull
- }
- """)
+ ev.copy(code =
+ s"""
+ |${childGen.code}
+ |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
+ |${ctx.JAVA_BOOLEAN} ${ev.value} = false;
+ |if (!${ev.isNull}) {
+ | $setName $setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();
+ | ${ev.value} = $setTerm.contains(${childGen.value});
+ | $setNull
+ |}
+ """.stripMargin)
}
override def sql: String = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 785e815b41185..6305b6c84bae3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -64,49 +64,89 @@ object ConstantFolding extends Rule[LogicalPlan] {
* }}}
*
* Approach used:
- * - Start from AND operator as the root
- * - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they
- * don't have a `NOT` or `OR` operator in them
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
* in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
- private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find {
- case _: Not | _: Or => true
- case _ => false
- }.isDefined
-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case f: Filter => f transformExpressionsUp {
- case and: And =>
- val conjunctivePredicates =
- splitConjunctivePredicates(and)
- .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe])
- .filterNot(expr => containsNonConjunctionPredicates(expr))
-
- val equalityPredicates = conjunctivePredicates.collect {
- case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e)
- case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e)
- case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e)
- case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e)
- }
+ case f: Filter =>
+ val (newCondition, _) = traverse(f.condition, replaceChildren = true)
+ if (newCondition.isDefined) {
+ f.copy(condition = newCondition.get)
+ } else {
+ f
+ }
+ }
- val constantsMap = AttributeMap(equalityPredicates.map(_._1))
- val predicates = equalityPredicates.map(_._2).toSet
+ type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]
- def replaceConstants(expression: Expression) = expression transform {
- case a: AttributeReference =>
- constantsMap.get(a) match {
- case Some(literal) => literal
- case None => a
- }
+ /**
+ * Traverse a condition as a tree and replace attributes with constant values.
+ * - On matching [[And]], recursively traverse each children and get propagated mappings.
+   * - On matching [[And]], recursively traverse each child and get the propagated mappings.
+   *   If the current node is not a child of another [[And]], replace all occurrences of the
+   *   attributes with the corresponding constant values.
+   * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
+   *   of attribute => constant.
+   * - On matching [[Or]] or [[Not]], recursively traverse each child, propagating an empty mapping.
+   * - Otherwise, stop traversal and propagate an empty mapping.
+ * @param replaceChildren whether to replace attributes with constant values in children
+ * @return A tuple including:
+ * 1. Option[Expression]: optional changed condition after traversal
+ * 2. EqualityPredicates: propagated mapping of attribute => constant
+ */
+ private def traverse(condition: Expression, replaceChildren: Boolean)
+ : (Option[Expression], EqualityPredicates) =
+ condition match {
+ case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e)))
+ case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e)))
+ case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
+ (None, Seq(((left, right), e)))
+ case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
+ (None, Seq(((right, left), e)))
+ case a: And =>
+ val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false)
+ val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false)
+ val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
+ val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
+ Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
+ replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
+ } else {
+ if (newLeft.isDefined || newRight.isDefined) {
+ Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
+ } else {
+ None
+ }
}
-
- and transform {
- case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e)
- case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e)
+ (newSelf, equalityPredicates)
+ case o: Or =>
+ // Ignore the EqualityPredicates from children since they are only propagated through And.
+ val (newLeft, _) = traverse(o.left, replaceChildren = true)
+ val (newRight, _) = traverse(o.right, replaceChildren = true)
+ val newSelf = if (newLeft.isDefined || newRight.isDefined) {
+          Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse(o.right)))
+ } else {
+ None
}
+ (newSelf, Seq.empty)
+ case n: Not =>
+ // Ignore the EqualityPredicates from children since they are only propagated through And.
+ val (newChild, _) = traverse(n.child, replaceChildren = true)
+ (newChild.map(Not), Seq.empty)
+ case _ => (None, Seq.empty)
+ }
+
+ private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
+ : Expression = {
+ val constantsMap = AttributeMap(equalityPredicates.map(_._1))
+ val predicates = equalityPredicates.map(_._2).toSet
+ def replaceConstants0(expression: Expression) = expression transform {
+ case a: AttributeReference => constantsMap.getOrElse(a, a)
+ }
+ condition transform {
+ case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
+ case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
}
}
}
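A standalone sketch (plain Scala, not Catalyst) of the propagation scheme the scaladoc above describes: equality conjuncts contribute attribute => constant bindings, those bindings flow only across `And`, and the defining equalities themselves are left untouched. The real rule additionally handles `Or`/`Not` sub-trees and null-safe equality, which this toy version omits.

```scala
// Toy expression tree used only to illustrate the idea; not Spark's Expression API.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Lit(value: Int) extends Expr
case class Eq(left: Expr, right: Expr) extends Expr
case class Gt(left: Expr, right: Expr) extends Expr
case class And(left: Expr, right: Expr) extends Expr

// Collect attribute -> constant bindings from equality conjuncts under And.
def constants(e: Expr): Map[String, Int] = e match {
  case Eq(Attr(a), Lit(v)) => Map(a -> v)
  case Eq(Lit(v), Attr(a)) => Map(a -> v)
  case And(l, r)           => constants(l) ++ constants(r)
  case _                   => Map.empty
}

// Replace attributes with known constants, keeping the defining equalities as-is.
def substitute(e: Expr, env: Map[String, Int]): Expr = e match {
  case Attr(a)   => env.get(a).map(Lit).getOrElse(e)
  case Gt(l, r)  => Gt(substitute(l, env), substitute(r, env))
  case And(l, r) => And(substitute(l, env), substitute(r, env))
  case eq: Eq    => eq   // simplified: the real rule also rewrites non-defining equalities
  case lit: Lit  => lit
}

val cond = And(Eq(Attr("i"), Lit(5)), Gt(Attr("j"), Attr("i")))
substitute(cond, constants(cond))
// And(Eq(Attr(i), Lit(5)), Gt(Attr(j), Lit(5)))  -- i.e. j > i becomes j > 5
```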
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
index 06196b5afb031..7a927e1e083b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
@@ -38,7 +38,7 @@ object EventTimeWatermark {
case class EventTimeWatermark(
eventTime: Attribute,
delay: CalendarInterval,
- child: LogicalPlan) extends LogicalPlan {
+ child: LogicalPlan) extends UnaryNode {
// Update the metadata on the eventTime column to include the desired delay.
override val output: Seq[Attribute] = child.output.map { a =>
@@ -60,6 +60,4 @@ case class EventTimeWatermark(
a
}
}
-
- override val children: Seq[LogicalPlan] = child :: Nil
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index fb759eba6a9e2..be638d80e45d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -343,4 +344,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Least(inputsExpr), "s" * 1, EmptyRow)
checkEvaluation(Greatest(inputsExpr), "s" * N, EmptyRow)
}
+
+ test("SPARK-22704: Least and greatest use less global variables") {
+ val ctx1 = new CodegenContext()
+ Least(Seq(Literal(1), Literal(1))).genCode(ctx1)
+ assert(ctx1.mutableStates.size == 1)
+
+ val ctx2 = new CodegenContext()
+ Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
+ assert(ctx2.mutableStates.size == 1)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index a4198f826cedb..40bf29bb3b573 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType}
+import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -380,4 +380,18 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd")
}
}
+
+ test("SPARK-22696: CreateExternalRow should not use global variables") {
+ val ctx = new CodegenContext
+ val schema = new StructType().add("a", IntegerType).add("b", StringType)
+ CreateExternalRow(Seq(Literal(1), Literal("x")), schema).genCode(ctx)
+ assert(ctx.mutableStates.isEmpty)
+ }
+
+ test("SPARK-22696: InitializeJavaBean should not use global variables") {
+ val ctx = new CodegenContext
+ InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]),
+ Map("add" -> Literal(1))).genCode(ctx)
+ assert(ctx.mutableStates.isEmpty)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index b0eaad1c80f89..6dfca7d73a3df 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -299,4 +300,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
new StringToMap(Literal("a=1_b=2_c=3"), Literal("_"), NonFoldableLiteral("="))
.checkInputDataTypes().isFailure)
}
+
+ test("SPARK-22693: CreateNamedStruct should not use global variables") {
+ val ctx = new CodegenContext
+ CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx)
+ assert(ctx.mutableStates.isEmpty)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 3e11c3d2d4fe3..60d84aae1fa3f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._
class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -145,4 +146,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
IndexedSeq((Literal(12) === Literal(1), Literal(42)),
(Literal(12) === Literal(42), Literal(1))))
}
+
+ test("SPARK-22705: case when should use less global variables") {
+ val ctx = new CodegenContext()
+ CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx)
+ assert(ctx.mutableStates.size == 1)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index 40ef7770da33f..a23cd95632770 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types._
@@ -155,6 +156,12 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Coalesce(inputs), "x_1")
}
+ test("SPARK-22705: Coalesce should use less global variables") {
+ val ctx = new CodegenContext()
+ Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx)
+ assert(ctx.mutableStates.size == 1)
+ }
+
test("AtLeastNNonNulls should not throw 64kb exception") {
val inputs = (1 to 4000).map(x => Literal(s"x_$x"))
checkEvaluation(AtLeastNNonNulls(1, inputs), true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 0079e4e8d6f74..15cb0bea08f17 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
@@ -245,6 +246,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(In(Literal(1.0D), sets), true)
}
+ test("SPARK-22705: In should use less global variables") {
+ val ctx = new CodegenContext()
+ In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx)
+ assert(ctx.mutableStates.isEmpty)
+ }
+
test("INSET") {
val hS = HashSet[Any]() + 1 + 2
val nS = HashSet[Any]() + 1 + 2 + null
@@ -429,4 +436,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
val infinity = Literal(Double.PositiveInfinity)
checkEvaluation(EqualTo(infinity, infinity), true)
}
+
+ test("SPARK-22693: InSet should not use global variables") {
+ val ctx = new CodegenContext
+ InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx)
+ assert(ctx.mutableStates.isEmpty)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index 13bd363c8b692..70dea4b39d55d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Locale
import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -47,4 +48,9 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(e2.getMessage.contains("Failed to execute user defined function"))
}
+ test("SPARK-22695: ScalaUDF should not use global variables") {
+ val ctx = new CodegenContext
+ ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
+ assert(ctx.mutableStates.isEmpty)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
index 0cd0d8859145f..6031bdf19e957 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -208,4 +208,15 @@ class GeneratedProjectionSuite extends SparkFunSuite {
unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b"))))
assert(row.getStruct(0, 1).getString(0).toString == "a")
}
+
+ test("SPARK-22699: GenerateSafeProjection should not use global variables for struct") {
+ val safeProj = GenerateSafeProjection.generate(
+ Seq(BoundReference(0, new StructType().add("i", IntegerType), true)))
+ val globalVariables = safeProj.getClass.getDeclaredFields
+ // We always need 3 variables:
+ // - one is a reference to this
+ // - one is the references object
+ // - one is the mutableRow
+ assert(globalVariables.length == 3)
+ }
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 4db3fea008ee9..93010c606cf45 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -38,7 +38,7 @@
<dependency>
<groupId>com.univocity</groupId>
<artifactId>univocity-parsers</artifactId>
- <version>2.5.4</version>
+ <version>2.5.9</version>
<type>jar</type>
</dependency>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
index 9467435435d1f..24260b05194a7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
@@ -41,7 +41,7 @@
public class AggregateHashMap {
private OnHeapColumnVector[] columnVectors;
- private ColumnarBatch batch;
+ private MutableColumnarRow aggBufferRow;
private int[] buckets;
private int numBuckets;
private int numRows = 0;
@@ -63,7 +63,7 @@ public AggregateHashMap(StructType schema, int capacity, double loadFactor, int
this.maxSteps = maxSteps;
numBuckets = (int) (capacity / loadFactor);
columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema);
- batch = new ColumnarBatch(schema, columnVectors, capacity);
+ aggBufferRow = new MutableColumnarRow(columnVectors);
buckets = new int[numBuckets];
Arrays.fill(buckets, -1);
}
@@ -72,14 +72,15 @@ public AggregateHashMap(StructType schema) {
this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS);
}
- public ColumnarRow findOrInsert(long key) {
+ public MutableColumnarRow findOrInsert(long key) {
int idx = find(key);
if (idx != -1 && buckets[idx] == -1) {
columnVectors[0].putLong(numRows, key);
columnVectors[1].putLong(numRows, 0);
buckets[idx] = numRows++;
}
- return batch.getRow(buckets[idx]);
+ aggBufferRow.rowId = buckets[idx];
+ return aggBufferRow;
}
@VisibleForTesting
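
A rough usage sketch of the new return type (the schema and key below are made up; the map expects two long columns): findOrInsert now positions one shared MutableColumnarRow at the key's bucket, so callers update the aggregation buffer in place instead of holding on to the returned row.

  import org.apache.spark.sql.execution.vectorized.AggregateHashMap
  import org.apache.spark.sql.types.{LongType, StructType}

  val schema = new StructType().add("key", LongType).add("count", LongType)
  val map = new AggregateHashMap(schema)

  // The second column is the aggregation buffer; it starts at 0 for a new key.
  val buffer = map.findOrInsert(42L)
  buffer.setLong(1, buffer.getLong(1) + 1L)
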
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
index 0071bd66760be..1f1347ccd315e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
@@ -323,7 +323,6 @@ public ArrowColumnVector(ValueVector vector) {
for (int i = 0; i < childColumns.length; ++i) {
childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i));
}
- resultStruct = new ColumnarRow(childColumns);
} else {
throw new UnsupportedOperationException();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index cca14911fbb28..e6b87519239dd 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -157,18 +157,16 @@ public abstract class ColumnVector implements AutoCloseable {
/**
* Returns a utility object to get structs.
*/
- public ColumnarRow getStruct(int rowId) {
- resultStruct.rowId = rowId;
- return resultStruct;
+ public final ColumnarRow getStruct(int rowId) {
+ return new ColumnarRow(this, rowId);
}
/**
* Returns a utility object to get structs.
* provided to keep API compatibility with InternalRow for code generation
*/
- public ColumnarRow getStruct(int rowId, int size) {
- resultStruct.rowId = rowId;
- return resultStruct;
+ public final ColumnarRow getStruct(int rowId, int size) {
+ return getStruct(rowId);
}
/**
@@ -216,11 +214,6 @@ public MapData getMap(int ordinal) {
*/
protected DataType type;
- /**
- * Reusable Struct holder for getStruct().
- */
- protected ColumnarRow resultStruct;
-
/**
* The Dictionary for this column.
*
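
Since getStruct now allocates a fresh ColumnarRow instead of repositioning a shared holder, each returned row is an independent read-only view. A small sketch, assuming `vector` is a struct-typed ColumnVector whose first field is an int:

  import org.apache.spark.sql.execution.vectorized.ColumnVector

  // Rows obtained for different rowIds can now be held at the same time
  // without one overwriting the other.
  def firstField(vector: ColumnVector, rowId: Int): Int =
    vector.getStruct(rowId).getInt(0)
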
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 2f5fb360b226f..a9d09aa679726 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -18,6 +18,7 @@
import java.util.*;
+import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructType;
/**
@@ -40,10 +41,10 @@ public final class ColumnarBatch {
private final StructType schema;
private final int capacity;
private int numRows;
- final ColumnVector[] columns;
+ private final ColumnVector[] columns;
- // Staging row returned from getRow.
- final ColumnarRow row;
+ // Staging row returned from `getRow`.
+ private final MutableColumnarRow row;
/**
* Called to close all the columns in this batch. It is not valid to access the data after
@@ -58,10 +59,10 @@ public void close() {
/**
* Returns an iterator over the rows in this batch. This skips rows that are filtered out.
*/
- public Iterator<ColumnarRow> rowIterator() {
+ public Iterator<InternalRow> rowIterator() {
final int maxRows = numRows;
- final ColumnarRow row = new ColumnarRow(columns);
- return new Iterator<ColumnarRow>() {
+ final MutableColumnarRow row = new MutableColumnarRow(columns);
+ return new Iterator<InternalRow>() {
int rowId = 0;
@Override
@@ -70,7 +71,7 @@ public boolean hasNext() {
}
@Override
- public ColumnarRow next() {
+ public InternalRow next() {
if (rowId >= maxRows) {
throw new NoSuchElementException();
}
@@ -133,9 +134,8 @@ public void setNumRows(int numRows) {
/**
* Returns the row in this batch at `rowId`. Returned row is reused across calls.
*/
- public ColumnarRow getRow(int rowId) {
- assert(rowId >= 0);
- assert(rowId < numRows);
+ public InternalRow getRow(int rowId) {
+ assert(rowId >= 0 && rowId < numRows);
row.rowId = rowId;
return row;
}
@@ -144,6 +144,6 @@ public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) {
this.schema = schema;
this.columns = columns;
this.capacity = capacity;
- this.row = new ColumnarRow(columns);
+ this.row = new MutableColumnarRow(columns);
}
}
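
Because getRow and rowIterator now hand out a single reused MutableColumnarRow typed as InternalRow, callers that buffer rows have to copy them first. A minimal sketch under that assumption:

  import scala.collection.JavaConverters._
  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.execution.vectorized.ColumnarBatch

  // Materialize the rows of a populated batch; copy() detaches each row from
  // the shared staging object before it is collected.
  def collectRows(batch: ColumnarBatch): Seq[InternalRow] =
    batch.rowIterator().asScala.map(_.copy()).toSeq
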
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
index cabb7479525d9..95c0d09873d67 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
@@ -28,30 +28,32 @@
* to be reused, callers should copy the data out if it needs to be stored.
*/
public final class ColumnarRow extends InternalRow {
- protected int rowId;
- private final ColumnVector[] columns;
-
- // Ctor used if this is a struct.
- ColumnarRow(ColumnVector[] columns) {
- this.columns = columns;
+ // The data for this row. E.g. the value of the 3rd int field is `data.getChildColumn(3).getInt(rowId)`.
+ private final ColumnVector data;
+ private final int rowId;
+ private final int numFields;
+
+ ColumnarRow(ColumnVector data, int rowId) {
+ assert (data.dataType() instanceof StructType);
+ this.data = data;
+ this.rowId = rowId;
+ this.numFields = ((StructType) data.dataType()).size();
}
- public ColumnVector[] columns() { return columns; }
-
@Override
- public int numFields() { return columns.length; }
+ public int numFields() { return numFields; }
/**
* Revisit this. This is expensive. This is currently only used in test paths.
*/
@Override
public InternalRow copy() {
- GenericInternalRow row = new GenericInternalRow(columns.length);
+ GenericInternalRow row = new GenericInternalRow(numFields);
for (int i = 0; i < numFields(); i++) {
if (isNullAt(i)) {
row.setNullAt(i);
} else {
- DataType dt = columns[i].dataType();
+ DataType dt = data.getChildColumn(i).dataType();
if (dt instanceof BooleanType) {
row.setBoolean(i, getBoolean(i));
} else if (dt instanceof ByteType) {
@@ -91,65 +93,65 @@ public boolean anyNull() {
}
@Override
- public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
+ public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); }
@Override
- public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
+ public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); }
@Override
- public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
+ public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); }
@Override
- public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
+ public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); }
@Override
- public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
+ public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); }
@Override
- public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
+ public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); }
@Override
- public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
+ public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); }
@Override
- public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
+ public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); }
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
- if (columns[ordinal].isNullAt(rowId)) return null;
- return columns[ordinal].getDecimal(rowId, precision, scale);
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale);
}
@Override
public UTF8String getUTF8String(int ordinal) {
- if (columns[ordinal].isNullAt(rowId)) return null;
- return columns[ordinal].getUTF8String(rowId);
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ return data.getChildColumn(ordinal).getUTF8String(rowId);
}
@Override
public byte[] getBinary(int ordinal) {
- if (columns[ordinal].isNullAt(rowId)) return null;
- return columns[ordinal].getBinary(rowId);
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ return data.getChildColumn(ordinal).getBinary(rowId);
}
@Override
public CalendarInterval getInterval(int ordinal) {
- if (columns[ordinal].isNullAt(rowId)) return null;
- final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
- final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId);
+ final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId);
return new CalendarInterval(months, microseconds);
}
@Override
public ColumnarRow getStruct(int ordinal, int numFields) {
- if (columns[ordinal].isNullAt(rowId)) return null;
- return columns[ordinal].getStruct(rowId);
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ return data.getChildColumn(ordinal).getStruct(rowId);
}
@Override
public ColumnarArray getArray(int ordinal) {
- if (columns[ordinal].isNullAt(rowId)) return null;
- return columns[ordinal].getArray(rowId);
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ return data.getChildColumn(ordinal).getArray(rowId);
}
@Override
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index f272cc163611b..06602c147dfe9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -28,17 +28,24 @@
/**
* A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash
- * aggregate.
+ * aggregate, and {@link ColumnarBatch} to save object creation.
*
* Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to
* avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes.
*/
public final class MutableColumnarRow extends InternalRow {
public int rowId;
- private final WritableColumnVector[] columns;
+ private final ColumnVector[] columns;
+ private final WritableColumnVector[] writableColumns;
- public MutableColumnarRow(WritableColumnVector[] columns) {
+ public MutableColumnarRow(ColumnVector[] columns) {
this.columns = columns;
+ this.writableColumns = null;
+ }
+
+ public MutableColumnarRow(WritableColumnVector[] writableColumns) {
+ this.columns = writableColumns;
+ this.writableColumns = writableColumns;
}
@Override
@@ -225,54 +232,54 @@ public void update(int ordinal, Object value) {
@Override
public void setNullAt(int ordinal) {
- columns[ordinal].putNull(rowId);
+ writableColumns[ordinal].putNull(rowId);
}
@Override
public void setBoolean(int ordinal, boolean value) {
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putBoolean(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putBoolean(rowId, value);
}
@Override
public void setByte(int ordinal, byte value) {
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putByte(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putByte(rowId, value);
}
@Override
public void setShort(int ordinal, short value) {
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putShort(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putShort(rowId, value);
}
@Override
public void setInt(int ordinal, int value) {
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putInt(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putInt(rowId, value);
}
@Override
public void setLong(int ordinal, long value) {
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putLong(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putLong(rowId, value);
}
@Override
public void setFloat(int ordinal, float value) {
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putFloat(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putFloat(rowId, value);
}
@Override
public void setDouble(int ordinal, double value) {
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putDouble(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putDouble(rowId, value);
}
@Override
public void setDecimal(int ordinal, Decimal value, int precision) {
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putDecimal(rowId, value, precision);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putDecimal(rowId, value, precision);
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 806d0291a6c49..5f1b9885334b7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -547,7 +547,7 @@ protected void reserveInternal(int newCapacity) {
} else if (type instanceof LongType || type instanceof DoubleType ||
DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) {
this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8);
- } else if (resultStruct != null) {
+ } else if (childColumns != null) {
// Nothing to store.
} else {
throw new RuntimeException("Unhandled " + type);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 6e7f74ce12f16..f12772ede575d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -558,7 +558,7 @@ protected void reserveInternal(int newCapacity) {
if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity);
doubleData = newData;
}
- } else if (resultStruct != null) {
+ } else if (childColumns != null) {
// Nothing to store.
} else {
throw new RuntimeException("Unhandled " + type);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 0bea4cc97142d..7c053b579442c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -74,7 +74,6 @@ public void close() {
dictionaryIds = null;
}
dictionary = null;
- resultStruct = null;
}
public void reserve(int requiredCapacity) {
@@ -673,23 +672,19 @@ protected WritableColumnVector(int capacity, DataType type) {
}
this.childColumns = new WritableColumnVector[1];
this.childColumns[0] = reserveNewColumn(childCapacity, childType);
- this.resultStruct = null;
} else if (type instanceof StructType) {
StructType st = (StructType)type;
this.childColumns = new WritableColumnVector[st.fields().length];
for (int i = 0; i < childColumns.length; ++i) {
this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType());
}
- this.resultStruct = new ColumnarRow(this.childColumns);
} else if (type instanceof CalendarIntervalType) {
// Two columns. Months as int. Microseconds as Long.
this.childColumns = new WritableColumnVector[2];
this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType);
this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType);
- this.resultStruct = new ColumnarRow(this.childColumns);
} else {
this.childColumns = null;
- this.resultStruct = null;
}
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java
index 9a89c8193dd6e..b2c908dc73a61 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java
@@ -49,4 +49,35 @@ public DataSourceV2Options(Map<String, String> originalMap) {
public Optional<String> get(String key) {
return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key)));
}
+
+ /**
+ * Returns the boolean value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public boolean getBoolean(String key, boolean defaultValue) {
+ String lcaseKey = toLowerCase(key);
+ return keyLowerCasedMap.containsKey(lcaseKey) ?
+ Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
+ }
+
+ /**
+ * Returns the integer value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public int getInt(String key, int defaultValue) {
+ String lcaseKey = toLowerCase(key);
+ return keyLowerCasedMap.containsKey(lcaseKey) ?
+ Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
+ }
+
+ /**
+ * Returns the long value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public long getLong(String key, long defaultValue) {
+ String lcaseKey = toLowerCase(key);
+ return keyLowerCasedMap.containsKey(lcaseKey) ?
+ Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
+ }
+
}
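
A short usage sketch of the new typed getters (the option names here are made up): lookups ignore case and fall back to the supplied default when no mapping exists.

  import scala.collection.JavaConverters._
  import org.apache.spark.sql.sources.v2.DataSourceV2Options

  val options = new DataSourceV2Options(Map("path" -> "/tmp/data", "batchSize" -> "4096").asJava)
  options.getInt("batchsize", 1024)                // 4096: the key match is case-insensitive
  options.getBoolean("failFast", false)            // false: no mapping, the default is returned
  options.getLong("maxBytes", 128L * 1024 * 1024)  // 134217728: default again
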
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 657b265260135..787c1cfbfb3d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext
import org.codehaus.commons.compiler.CompileException
-import org.codehaus.janino.JaninoRuntimeException
+import org.codehaus.janino.InternalCompilerException
import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
@@ -385,7 +385,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
try {
GeneratePredicate.generate(expression, inputSchema)
} catch {
- case _ @ (_: JaninoRuntimeException | _: CompileException) if codeGenFallBack =>
+ case _ @ (_: InternalCompilerException | _: CompileException) if codeGenFallBack =>
genInterpretedPredicate(expression, inputSchema)
}
}
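
The catch clause changes because recent Janino releases report compilation failures as org.codehaus.janino.InternalCompilerException, which supersedes the older JaninoRuntimeException. A sketch of the same compile-or-interpret fallback pattern, with a hypothetical helper name:

  import org.codehaus.commons.compiler.CompileException
  import org.codehaus.janino.InternalCompilerException

  // Try the compiled path first; fall back to the interpreted one when Janino
  // rejects the generated code.
  def withCodegenFallback[T](compiled: => T)(interpreted: => T): T =
    try compiled catch {
      case _: InternalCompilerException | _: CompileException => interpreted
    }
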
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 26d8cd7278353..9cadd13999e72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -595,9 +595,7 @@ case class HashAggregateExec(
ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
s"$fastHashMapTerm = new $fastHashMapClassName();")
- ctx.addMutableState(
- s"java.util.Iterator<${classOf[ColumnarRow].getName}>",
- iterTermForFastHashMap)
+ ctx.addMutableState(s"java.util.Iterator", iterTermForFastHashMap)
} else {
val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
@@ -674,7 +672,7 @@ case class HashAggregateExec(
""".stripMargin
}
- // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow
+ // Iterate over the aggregate rows and convert them from InternalRow to UnsafeRow
def outputFromVectorizedMap: String = {
val row = ctx.freshName("fastHashMapRow")
ctx.currentVars = null
@@ -687,10 +685,9 @@ case class HashAggregateExec(
bufferSchema.toAttributes.zipWithIndex.map { case (attr, i) =>
BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
})
- val columnarRowCls = classOf[ColumnarRow].getName
s"""
|while ($iterTermForFastHashMap.hasNext()) {
- | $columnarRowCls $row = ($columnarRowCls) $iterTermForFastHashMap.next();
+ | InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
| ${generateKeyRow.code}
| ${generateBufferRow.code}
| $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 44ba539ebf7c2..f04cd48072f17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
-import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector}
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector}
import org.apache.spark.sql.types._
/**
@@ -231,7 +232,7 @@ class VectorizedHashMapGenerator(
protected def generateRowIterator(): String = {
s"""
- |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() {
+ |public java.util.Iterator<${classOf[InternalRow].getName}> rowIterator() {
| batch.setNumRows(numRows);
| return batch.rowIterator();
|}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala
index 2f09757aa341c..341ade1a5c613 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala
@@ -35,7 +35,7 @@ private[columnar] trait NullableColumnAccessor extends ColumnAccessor {
nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1
pos = 0
- underlyingBuffer.position(underlyingBuffer.position + 4 + nullCount * 4)
+ underlyingBuffer.position(underlyingBuffer.position() + 4 + nullCount * 4)
super.initialize()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
index bf00ad997c76e..79dcf3a6105ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
@@ -112,7 +112,7 @@ private[columnar] case object PassThrough extends CompressionScheme {
var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else capacity
var pos = 0
var seenNulls = 0
- var bufferPos = buffer.position
+ var bufferPos = buffer.position()
while (pos < capacity) {
if (pos != nextNullIndex) {
val len = nextNullIndex - pos
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 75c42213db3c8..f7471cd7debce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -124,7 +125,7 @@ class OrcFileFormat
true
}
- override def buildReader(
+ override def buildReaderWithPartitionValues(
sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
@@ -167,9 +168,17 @@ class OrcFileFormat
val iter = new RecordReaderIterator[OrcStruct](orcRecordReader)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
- val unsafeProjection = UnsafeProjection.create(requiredSchema)
+ val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+ val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds)
- iter.map(value => unsafeProjection(deserializer.deserialize(value)))
+
+ if (partitionSchema.length == 0) {
+ iter.map(value => unsafeProjection(deserializer.deserialize(value)))
+ } else {
+ val joinedRow = new JoinedRow()
+ iter.map(value =>
+ unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues)))
+ }
}
}
}
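
A minimal sketch of the JoinedRow idea used above: the row produced by the ORC deserializer is concatenated with the file's partition values and then projected to an UnsafeRow over the combined schema. The schemas and values below are invented for illustration.

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection}
  import org.apache.spark.sql.types.{LongType, StringType, StructType}
  import org.apache.spark.unsafe.types.UTF8String

  val requiredSchema  = new StructType().add("id", LongType)
  val partitionSchema = new StructType().add("dt", StringType)
  val fullSchema      = requiredSchema.toAttributes ++ partitionSchema.toAttributes
  val toUnsafe        = UnsafeProjection.create(fullSchema, fullSchema)

  val dataRow         = InternalRow(1L)
  val partitionValues = InternalRow(UTF8String.fromString("2017-12-01"))
  val unsafeRow       = toUnsafe(new JoinedRow(dataRow, partitionValues))
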
diff --git a/sql/core/src/test/resources/test-data/comments.csv b/sql/core/src/test/resources/test-data/comments.csv
index 6275be7285b36..c0ace46db8c00 100644
--- a/sql/core/src/test/resources/test-data/comments.csv
+++ b/sql/core/src/test/resources/test-data/comments.csv
@@ -4,3 +4,4 @@
6,7,8,9,0,2015-08-21 16:58:01
~0,9,8,7,6,2015-08-22 17:59:02
1,2,3,4,5,2015-08-23 18:00:42
+~ comment in last line to test SPARK-22516 - do not add empty line at the end of this file!
\ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 8ddddbeee598f..5e077285ade55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2775,32 +2775,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
}
-
- test("SPARK-21791 ORC should support column names with dot") {
- val orc = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName
- withTempDir { dir =>
- val path = new File(dir, "orc").getCanonicalPath
- Seq(Some(1), None).toDF("col.dots").write.format(orc).save(path)
- assert(spark.read.format(orc).load(path).collect().length == 2)
- }
- }
-
- test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") {
- withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "hive") {
- val e = intercept[AnalysisException] {
- sql("CREATE TABLE spark_20728(a INT) USING ORC")
- }
- assert(e.message.contains("Hive built-in ORC data source must be used with Hive support"))
- }
-
- withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "native") {
- withTable("spark_20728") {
- sql("CREATE TABLE spark_20728(a INT) USING ORC")
- val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst {
- case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass
- }
- assert(fileFormat == Some(classOf[OrcFileFormat]))
- }
- }
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index e439699605abb..4fe45420b4e77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -483,18 +483,21 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("commented lines in CSV data") {
- val results = spark.read
- .format("csv")
- .options(Map("comment" -> "~", "header" -> "false"))
- .load(testFile(commentsFile))
- .collect()
+ Seq("false", "true").foreach { multiLine =>
- val expected =
- Seq(Seq("1", "2", "3", "4", "5.01", "2015-08-20 15:57:00"),
- Seq("6", "7", "8", "9", "0", "2015-08-21 16:58:01"),
- Seq("1", "2", "3", "4", "5", "2015-08-23 18:00:42"))
+ val results = spark.read
+ .format("csv")
+ .options(Map("comment" -> "~", "header" -> "false", "multiLine" -> multiLine))
+ .load(testFile(commentsFile))
+ .collect()
- assert(results.toSeq.map(_.toSeq) === expected)
+ val expected =
+ Seq(Seq("1", "2", "3", "4", "5.01", "2015-08-20 15:57:00"),
+ Seq("6", "7", "8", "9", "0", "2015-08-21 16:58:01"),
+ Seq("1", "2", "3", "4", "5", "2015-08-23 18:00:42"))
+
+ assert(results.toSeq.map(_.toSeq) === expected)
+ }
}
test("inferring schema with commented lines in CSV data") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
similarity index 87%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index de6f0d67f1734..a5f6b68ee862e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -15,25 +15,32 @@
* limitations under the License.
*/
-package org.apache.spark.sql.hive.orc
+package org.apache.spark.sql.execution.datasources.orc
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import scala.collection.JavaConverters._
-import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
-import org.apache.spark.sql.{Column, DataFrame, QueryTest}
+import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
/**
- * A test suite that tests ORC filter API based filter pushdown optimization.
+ * A test suite that tests Apache ORC filter API based filter pushdown optimization.
+ * OrcFilterSuite and HiveOrcFilterSuite are logically duplicated to provide the same test coverage.
+ * The only difference is the package containing the 'Predicate' and 'SearchArgument' classes.
+ * - OrcFilterSuite uses 'org.apache.orc.storage.ql.io.sarg' package.
+ * - HiveOrcFilterSuite uses 'org.apache.hadoop.hive.ql.io.sarg' package.
*/
-class OrcFilterSuite extends QueryTest with OrcTest {
+class OrcFilterSuite extends OrcTest with SharedSQLContext {
+
private def checkFilterPredicate(
df: DataFrame,
predicate: Predicate,
@@ -55,7 +62,7 @@ class OrcFilterSuite extends QueryTest with OrcTest {
DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq)
assert(selectedFilters.nonEmpty, "No filter is pushed down")
- val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray)
+ val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters)
assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters")
checker(maybeFilter.get)
}
@@ -99,7 +106,7 @@ class OrcFilterSuite extends QueryTest with OrcTest {
DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq)
assert(selectedFilters.nonEmpty, "No filter is pushed down")
- val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray)
+ val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters)
assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters")
}
@@ -284,40 +291,27 @@ class OrcFilterSuite extends QueryTest with OrcTest {
test("filter pushdown - combinations with logical operators") {
withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df =>
- // Because `ExpressionTree` is not accessible at Hive 1.2.x, this should be checked
- // in string form in order to check filter creation including logical operators
- // such as `and`, `or` or `not`. So, this function uses `SearchArgument.toString()`
- // to produce string expression and then compare it to given string expression below.
- // This might have to be changed after Hive version is upgraded.
checkFilterPredicate(
'_1.isNotNull,
- """leaf-0 = (IS_NULL _1)
- |expr = (not leaf-0)""".stripMargin.trim
+ "leaf-0 = (IS_NULL _1), expr = (not leaf-0)"
)
checkFilterPredicate(
'_1 =!= 1,
- """leaf-0 = (IS_NULL _1)
- |leaf-1 = (EQUALS _1 1)
- |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim
+ "leaf-0 = (IS_NULL _1), leaf-1 = (EQUALS _1 1), expr = (and (not leaf-0) (not leaf-1))"
)
checkFilterPredicate(
!('_1 < 4),
- """leaf-0 = (IS_NULL _1)
- |leaf-1 = (LESS_THAN _1 4)
- |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim
+ "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 4), expr = (and (not leaf-0) (not leaf-1))"
)
checkFilterPredicate(
'_1 < 2 || '_1 > 3,
- """leaf-0 = (LESS_THAN _1 2)
- |leaf-1 = (LESS_THAN_EQUALS _1 3)
- |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim
+ "leaf-0 = (LESS_THAN _1 2), leaf-1 = (LESS_THAN_EQUALS _1 3), " +
+ "expr = (or leaf-0 (not leaf-1))"
)
checkFilterPredicate(
'_1 < 2 && '_1 > 3,
- """leaf-0 = (IS_NULL _1)
- |leaf-1 = (LESS_THAN _1 2)
- |leaf-2 = (LESS_THAN_EQUALS _1 3)
- |expr = (and (not leaf-0) leaf-1 (not leaf-2))""".stripMargin.trim
+ "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 2), leaf-2 = (LESS_THAN_EQUALS _1 3), " +
+ "expr = (and (not leaf-0) leaf-1 (not leaf-2))"
)
}
}
@@ -344,4 +338,30 @@ class OrcFilterSuite extends QueryTest with OrcTest {
checkNoFilterPredicate('_1.isNotNull)
}
}
+
+ test("SPARK-12218 Converting conjunctions into ORC SearchArguments") {
+ import org.apache.spark.sql.sources._
+ // The `LessThan` should be converted while the `StringContains` shouldn't
+ val schema = new StructType(
+ Array(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", StringType, nullable = true)))
+ assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") {
+ OrcFilters.createFilter(schema, Array(
+ LessThan("a", 10),
+ StringContains("b", "prefix")
+ )).get.toString
+ }
+
+ // The `LessThan` should be converted while the whole inner `And` shouldn't
+ assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") {
+ OrcFilters.createFilter(schema, Array(
+ LessThan("a", 10),
+ Not(And(
+ GreaterThan("a", 1),
+ StringContains("b", "prefix")
+ ))
+ )).get.toString
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala
similarity index 82%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala
index d1ce3f1e2f058..d1911ea7f32a9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala
@@ -15,19 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.sql.hive.orc
+package org.apache.spark.sql.execution.datasources.orc
import java.io.File
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.TypeTag
-
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.scalatest.BeforeAndAfterAll
-
import org.apache.spark.sql._
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.test.SharedSQLContext
// The data where the partitioning key exists only in the directory structure.
case class OrcParData(intField: Int, stringField: String)
@@ -35,28 +28,8 @@ case class OrcParData(intField: Int, stringField: String)
// The data that also includes the partitioning key
case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
-// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot
-class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
- import spark._
- import spark.implicits._
-
- val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal
-
- def withTempDir(f: File => Unit): Unit = {
- val dir = Utils.createTempDir().getCanonicalFile
- try f(dir) finally Utils.deleteRecursively(dir)
- }
-
- def makeOrcFile[T <: Product: ClassTag: TypeTag](
- data: Seq[T], path: File): Unit = {
- data.toDF().write.mode("overwrite").orc(path.getCanonicalPath)
- }
-
-
- def makeOrcFile[T <: Product: ClassTag: TypeTag](
- df: DataFrame, path: File): Unit = {
- df.write.mode("overwrite").orc(path.getCanonicalPath)
- }
+abstract class OrcPartitionDiscoveryTest extends OrcTest {
+ val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"
protected def withTempTable(tableName: String)(f: => Unit): Unit = {
try f finally spark.catalog.dropTempView(tableName)
@@ -90,7 +63,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- read.orc(base.getCanonicalPath).createOrReplaceTempView("t")
+ spark.read.orc(base.getCanonicalPath).createOrReplaceTempView("t")
withTempTable("t") {
checkAnswer(
@@ -137,7 +110,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- read.orc(base.getCanonicalPath).createOrReplaceTempView("t")
+ spark.read.orc(base.getCanonicalPath).createOrReplaceTempView("t")
withTempTable("t") {
checkAnswer(
@@ -186,8 +159,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- read
- .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName)
+ spark.read
+ .option("hive.exec.default.partition.name", defaultPartitionName)
.orc(base.getCanonicalPath)
.createOrReplaceTempView("t")
@@ -228,8 +201,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- read
- .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName)
+ spark.read
+ .option("hive.exec.default.partition.name", defaultPartitionName)
.orc(base.getCanonicalPath)
.createOrReplaceTempView("t")
@@ -253,3 +226,4 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B
}
}
+class OrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQLContext
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
similarity index 68%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
index 1ffaf30311037..e00e057a18cc6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
@@ -15,24 +15,27 @@
* limitations under the License.
*/
-package org.apache.spark.sql.hive.orc
+package org.apache.spark.sql.execution.datasources.orc
+import java.io.File
import java.nio.charset.StandardCharsets
import java.sql.Timestamp
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader}
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.orc.{OrcConf, OrcFile}
import org.apache.orc.OrcConf.COMPRESS
-import org.scalatest.BeforeAndAfterAll
+import org.apache.orc.mapred.OrcStruct
+import org.apache.orc.mapreduce.OrcInputFormat
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator}
-import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.util.Utils
@@ -57,7 +60,8 @@ case class Contact(name: String, phone: String)
case class Person(name: String, age: Int, contacts: Seq[Contact])
-class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
+abstract class OrcQueryTest extends OrcTest {
+ import testImplicits._
test("Read/write All Types") {
val data = (0 to 255).map { i =>
@@ -73,7 +77,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
test("Read/write binary data") {
withOrcFile(BinaryData("test".getBytes(StandardCharsets.UTF_8)) :: Nil) { file =>
- val bytes = read.orc(file).head().getAs[Array[Byte]](0)
+ val bytes = spark.read.orc(file).head().getAs[Array[Byte]](0)
assert(new String(bytes, StandardCharsets.UTF_8) === "test")
}
}
@@ -91,7 +95,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withOrcFile(data) { file =>
checkAnswer(
- read.orc(file),
+ spark.read.orc(file),
data.toDF().collect())
}
}
@@ -172,7 +176,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withOrcFile(data) { file =>
checkAnswer(
- read.orc(file),
+ spark.read.orc(file),
Row(Seq.fill(5)(null): _*))
}
}
@@ -183,9 +187,13 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
spark.range(0, 10).write
.option(COMPRESS.getAttribute, "ZLIB")
.orc(file.getCanonicalPath)
- val expectedCompressionKind =
- OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression
- assert("ZLIB" === expectedCompressionKind.name())
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
}
// `compression` overrides `orc.compress`.
@@ -194,9 +202,13 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
.option("compression", "ZLIB")
.option(COMPRESS.getAttribute, "SNAPPY")
.orc(file.getCanonicalPath)
- val expectedCompressionKind =
- OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression
- assert("ZLIB" === expectedCompressionKind.name())
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
}
}
@@ -206,39 +218,39 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
spark.range(0, 10).write
.option("compression", "ZLIB")
.orc(file.getCanonicalPath)
- val expectedCompressionKind =
- OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression
- assert("ZLIB" === expectedCompressionKind.name())
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
}
withTempPath { file =>
spark.range(0, 10).write
.option("compression", "SNAPPY")
.orc(file.getCanonicalPath)
- val expectedCompressionKind =
- OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression
- assert("SNAPPY" === expectedCompressionKind.name())
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".snappy.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("SNAPPY" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
}
withTempPath { file =>
spark.range(0, 10).write
.option("compression", "NONE")
.orc(file.getCanonicalPath)
- val expectedCompressionKind =
- OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression
- assert("NONE" === expectedCompressionKind.name())
- }
- }
- // Following codec is not supported in Hive 1.2.1, ignore it now
- ignore("LZO compression options for writing to an ORC file not supported in Hive 1.2.1") {
- withTempPath { file =>
- spark.range(0, 10).write
- .option("compression", "LZO")
- .orc(file.getCanonicalPath)
- val expectedCompressionKind =
- OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression
- assert("LZO" === expectedCompressionKind.name())
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("NONE" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
}
}
@@ -256,22 +268,28 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
- createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp")
+ spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp")
withOrcTable(data, "t") {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
- checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
+ checkAnswer(spark.table("t"), (data ++ data).map(Row.fromTuple))
}
- sessionState.catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = true, purge = false)
+ spark.sessionState.catalog.dropTable(
+ TableIdentifier("tmp"),
+ ignoreIfNotExists = true,
+ purge = false)
}
test("overwriting") {
val data = (0 until 10).map(i => (i, i.toString))
- createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp")
+ spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp")
withOrcTable(data, "t") {
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
- checkAnswer(table("t"), data.map(Row.fromTuple))
+ checkAnswer(spark.table("t"), data.map(Row.fromTuple))
}
- sessionState.catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = true, purge = false)
+ spark.sessionState.catalog.dropTable(
+ TableIdentifier("tmp"),
+ ignoreIfNotExists = true,
+ purge = false)
}
test("self-join") {
@@ -334,60 +352,16 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withTempPath { dir =>
val path = dir.getCanonicalPath
- spark.range(0, 10).select('id as "Acol").write.format("orc").save(path)
- spark.read.format("orc").load(path).schema("Acol")
+ spark.range(0, 10).select('id as "Acol").write.orc(path)
+ spark.read.orc(path).schema("Acol")
intercept[IllegalArgumentException] {
- spark.read.format("orc").load(path).schema("acol")
+ spark.read.orc(path).schema("acol")
}
- checkAnswer(spark.read.format("orc").load(path).select("acol").sort("acol"),
+ checkAnswer(spark.read.orc(path).select("acol").sort("acol"),
(0 until 10).map(Row(_)))
}
}
- test("SPARK-8501: Avoids discovery schema from empty ORC files") {
- withTempPath { dir =>
- val path = dir.getCanonicalPath
-
- withTable("empty_orc") {
- withTempView("empty", "single") {
- spark.sql(
- s"""CREATE TABLE empty_orc(key INT, value STRING)
- |STORED AS ORC
- |LOCATION '${dir.toURI}'
- """.stripMargin)
-
- val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1)
- emptyDF.createOrReplaceTempView("empty")
-
- // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because
- // Spark SQL ORC data source always avoids write empty ORC files.
- spark.sql(
- s"""INSERT INTO TABLE empty_orc
- |SELECT key, value FROM empty
- """.stripMargin)
-
- val errorMessage = intercept[AnalysisException] {
- spark.read.orc(path)
- }.getMessage
-
- assert(errorMessage.contains("Unable to infer schema for ORC"))
-
- val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1)
- singleRowDF.createOrReplaceTempView("single")
-
- spark.sql(
- s"""INSERT INTO TABLE empty_orc
- |SELECT key, value FROM single
- """.stripMargin)
-
- val df = spark.read.orc(path)
- assert(df.schema === singleRowDF.schema.asNullable)
- checkAnswer(df, singleRowDF)
- }
- }
- }
- }
-
test("SPARK-10623 Enable ORC PPD") {
withTempPath { dir =>
withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
@@ -405,7 +379,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
}
// It needs to repartition data so that we can have several ORC files
// in order to skip stripes in ORC.
- createDataFrame(data).toDF("a", "b").repartition(10).write.orc(path)
+ spark.createDataFrame(data).toDF("a", "b").repartition(10).write.orc(path)
val df = spark.read.orc(path)
def checkPredicate(pred: Column, answer: Seq[Row]): Unit = {
@@ -440,77 +414,6 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
}
}
- test("Verify the ORC conversion parameter: CONVERT_METASTORE_ORC") {
- withTempView("single") {
- val singleRowDF = Seq((0, "foo")).toDF("key", "value")
- singleRowDF.createOrReplaceTempView("single")
-
- Seq("true", "false").foreach { orcConversion =>
- withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> orcConversion) {
- withTable("dummy_orc") {
- withTempPath { dir =>
- val path = dir.getCanonicalPath
- spark.sql(
- s"""
- |CREATE TABLE dummy_orc(key INT, value STRING)
- |STORED AS ORC
- |LOCATION '${dir.toURI}'
- """.stripMargin)
-
- spark.sql(
- s"""
- |INSERT INTO TABLE dummy_orc
- |SELECT key, value FROM single
- """.stripMargin)
-
- val df = spark.sql("SELECT * FROM dummy_orc WHERE key=0")
- checkAnswer(df, singleRowDF)
-
- val queryExecution = df.queryExecution
- if (orcConversion == "true") {
- queryExecution.analyzed.collectFirst {
- case _: LogicalRelation => ()
- }.getOrElse {
- fail(s"Expecting the query plan to convert orc to data sources, " +
- s"but got:\n$queryExecution")
- }
- } else {
- queryExecution.analyzed.collectFirst {
- case _: HiveTableRelation => ()
- }.getOrElse {
- fail(s"Expecting no conversion from orc to data sources, " +
- s"but got:\n$queryExecution")
- }
- }
- }
- }
- }
- }
- }
- }
-
- test("converted ORC table supports resolving mixed case field") {
- withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") {
- withTable("dummy_orc") {
- withTempPath { dir =>
- val df = spark.range(5).selectExpr("id", "id as valueField", "id as partitionValue")
- df.write
- .partitionBy("partitionValue")
- .mode("overwrite")
- .orc(dir.getAbsolutePath)
-
- spark.sql(s"""
- |create external table dummy_orc (id long, valueField long)
- |partitioned by (partitionValue int)
- |stored as orc
- |location "${dir.toURI}"""".stripMargin)
- spark.sql(s"msck repair table dummy_orc")
- checkAnswer(spark.sql("select * from dummy_orc"), df)
- }
- }
- }
- }
-
test("SPARK-14962 Produce correct results on array type with isnotnull") {
withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
val data = (0 until 10).map(i => Tuple1(Array(i)))
@@ -544,7 +447,8 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withTempPath { file =>
// It needs to repartition data so that we can have several ORC files
// in order to skip stripes in ORC.
- createDataFrame(data).toDF("a").repartition(10).write.orc(file.getCanonicalPath)
+ spark.createDataFrame(data).toDF("a").repartition(10)
+ .write.orc(file.getCanonicalPath)
val df = spark.read.orc(file.getCanonicalPath).where("a == 2")
val actual = stripSparkFilter(df).count()
@@ -563,7 +467,8 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withTempPath { file =>
// It needs to repartition data so that we can have several ORC files
// in order to skip stripes in ORC.
- createDataFrame(data).toDF("a").repartition(10).write.orc(file.getCanonicalPath)
+ spark.createDataFrame(data).toDF("a").repartition(10)
+ .write.orc(file.getCanonicalPath)
val df = spark.read.orc(file.getCanonicalPath).where(s"a == '$timeString'")
val actual = stripSparkFilter(df).count()
@@ -596,14 +501,18 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
test("Empty schema does not read data from ORC file") {
val data = Seq((1, 1), (2, 2))
withOrcFile(data) { path =>
- val requestedSchema = StructType(Nil)
val conf = new Configuration()
- val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get
- OrcFileFormat.setRequiredColumns(conf, physicalSchema, requestedSchema)
- val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf))
- assert(maybeOrcReader.isDefined)
- val orcRecordReader = new SparkOrcNewRecordReader(
- maybeOrcReader.get, conf, 0, maybeOrcReader.get.getContentLength)
+ conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, "")
+ conf.setBoolean("hive.io.file.read.all.columns", false)
+
+ val orcRecordReader = {
+ val file = new File(path).listFiles().find(_.getName.endsWith(".snappy.orc")).head
+ val split = new FileSplit(new Path(file.toURI), 0, file.length, Array.empty[String])
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+ val oif = new OrcInputFormat[OrcStruct]
+ oif.createRecordReader(split, hadoopAttemptContext)
+ }
val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader)
try {
@@ -614,27 +523,88 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
}
}
- test("read from multiple orc input paths") {
- val path1 = Utils.createTempDir()
- val path2 = Utils.createTempDir()
- makeOrcFile((1 to 10).map(Tuple1.apply), path1)
- makeOrcFile((1 to 10).map(Tuple1.apply), path2)
- assertResult(20)(read.orc(path1.getCanonicalPath, path2.getCanonicalPath).count())
- }
+ test("read from multiple orc input paths") {
+ val path1 = Utils.createTempDir()
+ val path2 = Utils.createTempDir()
+ makeOrcFile((1 to 10).map(Tuple1.apply), path1)
+ makeOrcFile((1 to 10).map(Tuple1.apply), path2)
+ val df = spark.read.orc(path1.getCanonicalPath, path2.getCanonicalPath)
+ assert(df.count() == 20)
+ }
+}
+
+class OrcQuerySuite extends OrcQueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("LZO compression options for writing to an ORC file") {
+ withTempPath { file =>
+ spark.range(0, 10).write
+ .option("compression", "LZO")
+ .orc(file.getCanonicalPath)
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".lzo.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("LZO" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
+ }
+ }
+
+ test("Schema discovery on empty ORC files") {
+ // SPARK-8501 is fixed.
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ withTable("empty_orc") {
+ withTempView("empty", "single") {
+ spark.sql(
+ s"""CREATE TABLE empty_orc(key INT, value STRING)
+ |USING ORC
+ |LOCATION '${dir.toURI}'
+ """.stripMargin)
+
+ val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1)
+ emptyDF.createOrReplaceTempView("empty")
+
+ // This creates 1 empty ORC file with ORC SerDe. We are using this trick because
+ // Spark SQL ORC data source always avoids writing empty ORC files.
+ spark.sql(
+ s"""INSERT INTO TABLE empty_orc
+ |SELECT key, value FROM empty
+ """.stripMargin)
+
+ val df = spark.read.orc(path)
+ assert(df.schema === emptyDF.schema.asNullable)
+ checkAnswer(df, emptyDF)
+ }
+ }
+ }
+ }
+
+ test("SPARK-21791 ORC should support column names with dot") {
+ withTempDir { dir =>
+ val path = new File(dir, "orc").getCanonicalPath
+ Seq(Some(1), None).toDF("col.dots").write.orc(path)
+ assert(spark.read.orc(path).collect().length == 2)
+ }
+ }
test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") {
- Seq(
- ("native", classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat]),
- ("hive", classOf[org.apache.spark.sql.hive.orc.OrcFileFormat])).foreach { case (i, format) =>
-
- withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> i) {
- withTable("spark_20728") {
- sql("CREATE TABLE spark_20728(a INT) USING ORC")
- val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst {
- case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass
- }
- assert(fileFormat == Some(format))
+ withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "hive") {
+ val e = intercept[AnalysisException] {
+ sql("CREATE TABLE spark_20728(a INT) USING ORC")
+ }
+ assert(e.message.contains("Hive built-in ORC data source must be used with Hive support"))
+ }
+
+ withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "native") {
+ withTable("spark_20728") {
+ sql("CREATE TABLE spark_20728(a INT) USING ORC")
+ val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst {
+ case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass
}
+ assert(fileFormat == Some(classOf[OrcFileFormat]))
}
}
}
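For readers following the rewritten assertions above: rather than going through the Hive-based OrcFileOperator, the tests now open the written files directly with the Apache ORC reader API and inspect the compression kind recorded in the file footer. A minimal, self-contained sketch of that pattern (the helper name and the directory handling are illustrative assumptions, not part of the patch):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.orc.OrcFile

    // Reads the compression codec recorded in the footer of the first ORC file found
    // under `dir`. Spark's ORC writer also encodes the codec in the file name
    // (e.g. ".zlib.orc", ".snappy.orc"), which is how the tests above locate files.
    def compressionKindOf(dir: java.io.File): String = {
      val orcFile = dir.listFiles().find(_.getName.endsWith(".orc")).get
      val reader = OrcFile.createReader(
        new Path(orcFile.getAbsolutePath), OrcFile.readerOptions(new Configuration()))
      reader.getCompressionKind.name
    }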
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
similarity index 63%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
index 2a086be57f517..6f5f2fd795f74 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.hive.orc
+package org.apache.spark.sql.execution.datasources.orc
import java.io.File
import java.util.Locale
@@ -23,50 +23,30 @@ import java.util.Locale
import org.apache.orc.OrcConf.COMPRESS
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.{QueryTest, Row}
-import org.apache.spark.sql.execution.datasources.orc.OrcOptions
-import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.Row
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
case class OrcData(intField: Int, stringField: String)
-abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
- import spark._
+abstract class OrcSuite extends OrcTest with BeforeAndAfterAll {
+ import testImplicits._
var orcTableDir: File = null
var orcTableAsDir: File = null
- override def beforeAll(): Unit = {
+ protected override def beforeAll(): Unit = {
super.beforeAll()
orcTableAsDir = Utils.createTempDir("orctests", "sparksql")
-
- // Hack: to prepare orc data files using hive external tables
orcTableDir = Utils.createTempDir("orctests", "sparksql")
- import org.apache.spark.sql.hive.test.TestHive.implicits._
sparkContext
.makeRDD(1 to 10)
.map(i => OrcData(i, s"part-$i"))
.toDF()
- .createOrReplaceTempView(s"orc_temp_table")
-
- sql(
- s"""CREATE EXTERNAL TABLE normal_orc(
- | intField INT,
- | stringField STRING
- |)
- |STORED AS ORC
- |LOCATION '${orcTableAsDir.toURI}'
- """.stripMargin)
-
- sql(
- s"""INSERT INTO TABLE normal_orc
- |SELECT intField, stringField FROM orc_temp_table
- """.stripMargin)
+ .createOrReplaceTempView("orc_temp_table")
}
test("create temporary orc table") {
@@ -152,56 +132,13 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
}
test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
- val conf = sqlContext.sessionState.conf
+ val conf = spark.sessionState.conf
val option = new OrcOptions(Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> "NONE"), conf)
assert(option.compressionCodec == "NONE")
}
- test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") {
- val location = Utils.createTempDir()
- val uri = location.toURI
- try {
- hiveClient.runSqlHive("USE default")
- hiveClient.runSqlHive(
- """
- |CREATE EXTERNAL TABLE hive_orc(
- | a STRING,
- | b CHAR(10),
- | c VARCHAR(10),
- | d ARRAY<CHAR(3)>)
- |STORED AS orc""".stripMargin)
- // Hive throws an exception if I assign the location in the create table statement.
- hiveClient.runSqlHive(
- s"ALTER TABLE hive_orc SET LOCATION '$uri'")
- hiveClient.runSqlHive(
- """
- |INSERT INTO TABLE hive_orc
- |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3)))
- |FROM (SELECT 1) t""".stripMargin)
-
- // We create a different table in Spark using the same schema which points to
- // the same location.
- spark.sql(
- s"""
- |CREATE EXTERNAL TABLE spark_orc(
- | a STRING,
- | b CHAR(10),
- | c VARCHAR(10),
- | d ARRAY<CHAR(3)>)
- |STORED AS orc
- |LOCATION '$uri'""".stripMargin)
- val result = Row("a", "b         ", "c", Seq("d  "))
- checkAnswer(spark.table("hive_orc"), result)
- checkAnswer(spark.table("spark_orc"), result)
- } finally {
- hiveClient.runSqlHive("DROP TABLE IF EXISTS hive_orc")
- hiveClient.runSqlHive("DROP TABLE IF EXISTS spark_orc")
- Utils.deleteRecursively(location)
- }
- }
-
test("SPARK-21839: Add SQL config for ORC compression") {
- val conf = sqlContext.sessionState.conf
+ val conf = spark.sessionState.conf
// Test if the default of spark.sql.orc.compression.codec is snappy
assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "SNAPPY")
@@ -225,13 +162,28 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
}
}
-class OrcSourceSuite extends OrcSuite {
- override def beforeAll(): Unit = {
+class OrcSourceSuite extends OrcSuite with SharedSQLContext {
+
+ protected override def beforeAll(): Unit = {
super.beforeAll()
+ sql(
+ s"""CREATE TABLE normal_orc(
+ | intField INT,
+ | stringField STRING
+ |)
+ |USING ORC
+ |LOCATION '${orcTableAsDir.toURI}'
+ """.stripMargin)
+
+ sql(
+ s"""INSERT INTO TABLE normal_orc
+ |SELECT intField, stringField FROM orc_temp_table
+ """.stripMargin)
+
spark.sql(
s"""CREATE TEMPORARY VIEW normal_orc_source
- |USING org.apache.spark.sql.hive.orc
+ |USING ORC
|OPTIONS (
| PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}'
|)
@@ -239,43 +191,10 @@ class OrcSourceSuite extends OrcSuite {
spark.sql(
s"""CREATE TEMPORARY VIEW normal_orc_as_source
- |USING org.apache.spark.sql.hive.orc
+ |USING ORC
|OPTIONS (
| PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}'
|)
""".stripMargin)
}
-
- test("SPARK-12218 Converting conjunctions into ORC SearchArguments") {
- // The `LessThan` should be converted while the `StringContains` shouldn't
- val schema = new StructType(
- Array(
- StructField("a", IntegerType, nullable = true),
- StructField("b", StringType, nullable = true)))
- assertResult(
- """leaf-0 = (LESS_THAN a 10)
- |expr = leaf-0
- """.stripMargin.trim
- ) {
- OrcFilters.createFilter(schema, Array(
- LessThan("a", 10),
- StringContains("b", "prefix")
- )).get.toString
- }
-
- // The `LessThan` should be converted while the whole inner `And` shouldn't
- assertResult(
- """leaf-0 = (LESS_THAN a 10)
- |expr = leaf-0
- """.stripMargin.trim
- ) {
- OrcFilters.createFilter(schema, Array(
- LessThan("a", 10),
- Not(And(
- GreaterThan("a", 1),
- StringContains("b", "prefix")
- ))
- )).get.toString
- }
- }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
similarity index 72%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
index a2f08c5ba72c6..d94cb850ed2a2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
@@ -15,20 +15,51 @@
* limitations under the License.
*/
-package org.apache.spark.sql.hive.orc
+package org.apache.spark.sql.execution.datasources.orc
import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
+import org.scalatest.BeforeAndAfterAll
+
import org.apache.spark.sql._
-import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION
import org.apache.spark.sql.test.SQLTestUtils
-private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton {
+/**
+ * OrcTest
+ *   -> OrcSuite
+ *       -> OrcSourceSuite
+ *       -> HiveOrcSourceSuite
+ *   -> OrcQueryTests
+ *       -> OrcQuerySuite
+ *       -> HiveOrcQuerySuite
+ *   -> OrcPartitionDiscoveryTest
+ *       -> OrcPartitionDiscoverySuite
+ *       -> HiveOrcPartitionDiscoverySuite
+ *   -> OrcFilterSuite
+ *   -> HiveOrcFilterSuite
+ */
+abstract class OrcTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
import testImplicits._
+ val orcImp: String = "native"
+
+ private var originalConfORCImplementation = "native"
+
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ originalConfORCImplementation = conf.getConf(ORC_IMPLEMENTATION)
+ conf.setConf(ORC_IMPLEMENTATION, orcImp)
+ }
+
+ protected override def afterAll(): Unit = {
+ conf.setConf(ORC_IMPLEMENTATION, originalConfORCImplementation)
+ super.afterAll()
+ }
+
/**
* Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f`
* returns.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 0ae4f2d117609..c9c6bee513b53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -751,11 +751,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
c2.putDouble(1, 5.67)
val s = column.getStruct(0)
- assert(s.columns()(0).getInt(0) == 123)
- assert(s.columns()(0).getInt(1) == 456)
- assert(s.columns()(1).getDouble(0) == 3.45)
- assert(s.columns()(1).getDouble(1) == 5.67)
-
assert(s.getInt(0) == 123)
assert(s.getDouble(1) == 3.45)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala
index 933f4075bcc8a..752d3c193cc74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala
@@ -37,4 +37,35 @@ class DataSourceV2OptionsSuite extends SparkFunSuite {
val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava)
assert(options.get("foo").get == "bAr")
}
+
+ test("getInt") {
+ val options = new DataSourceV2Options(Map("numFOo" -> "1", "foo" -> "bar").asJava)
+ assert(options.getInt("numFOO", 10) == 1)
+ assert(options.getInt("numFOO2", 10) == 10)
+
+ intercept[NumberFormatException]{
+ options.getInt("foo", 1)
+ }
+ }
+
+ test("getBoolean") {
+ val options = new DataSourceV2Options(
+ Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava)
+ assert(options.getBoolean("isFoo", false))
+ assert(!options.getBoolean("isFoo2", true))
+ assert(options.getBoolean("isBar", true))
+ assert(!options.getBoolean("isBar", false))
+ assert(!options.getBoolean("FOO", true))
+ }
+
+ test("getLong") {
+ val options = new DataSourceV2Options(Map("numFoo" -> "9223372036854775807",
+ "foo" -> "bar").asJava)
+ assert(options.getLong("numFOO", 0L) == 9223372036854775807L)
+ assert(options.getLong("numFoo2", -1L) == -1L)
+
+ intercept[NumberFormatException]{
+ options.getLong("foo", 0L)
+ }
+ }
}
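The new tests above document the typed accessors on DataSourceV2Options: keys are matched case-insensitively, the second argument is the default returned when the key is absent, and values that cannot be parsed raise NumberFormatException. A small usage sketch (the option names here are illustrative, not taken from the patch):

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.sources.v2.DataSourceV2Options

    val options = new DataSourceV2Options(
      Map("numPartitions" -> "8", "isStreaming" -> "true").asJava)
    val numPartitions = options.getInt("NUMPARTITIONS", 1)       // 8, key match is case-insensitive
    val streaming = options.getBoolean("isStreaming", false)     // true
    val maxBytes = options.getLong("maxBytes", -1L)               // -1L, absent key falls back to the default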
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
index d786a610f1535..3328400b214fb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
@@ -412,7 +412,9 @@ case class HiveScriptIOSchema (
propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
val properties = new Properties()
- properties.putAll(propsMap.asJava)
+ // Cannot use properties.putAll(propsMap.asJava) in Scala 2.12
+ // See https://github.com/scala/bug/issues/10418
+ propsMap.foreach { case (k, v) => properties.put(k, v) }
serde.initialize(null, properties)
serde
@@ -424,7 +426,9 @@ case class HiveScriptIOSchema (
recordReaderClass.map { klass =>
val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader]
val props = new Properties()
- props.putAll(outputSerdeProps.toMap.asJava)
+ // Cannot use props.putAll(outputSerdeProps.toMap.asJava) in Scala 2.12
+ // See https://github.com/scala/bug/issues/10418
+ outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) }
instance.initialize(inputStream, conf, props)
instance
}
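The two hunks above work around the overload-resolution problem tracked in scala/bug#10418: under Scala 2.12 the `properties.putAll(map.asJava)` call no longer compiles cleanly, so the entries are copied one at a time. A stripped-down sketch of the same workaround (the map contents are illustrative):

    import java.util.Properties
    import scala.collection.JavaConverters._

    val settings = Map("columns" -> "a,b", "columns.types" -> "int,string")
    val props = new Properties()
    // props.putAll(settings.asJava)                       // breaks on Scala 2.12 (scala/bug#10418)
    settings.foreach { case (k, v) => props.put(k, v) }    // works on both 2.11 and 2.12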
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala
new file mode 100644
index 0000000000000..283037caf4a9b
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala
@@ -0,0 +1,387 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument}
+
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.orc.OrcTest
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.types._
+
+/**
+ * A test suite that tests Hive ORC filter API based filter pushdown optimization.
+ */
+class HiveOrcFilterSuite extends OrcTest with TestHiveSingleton {
+
+ override val orcImp: String = "hive"
+
+ private def checkFilterPredicate(
+ df: DataFrame,
+ predicate: Predicate,
+ checker: (SearchArgument) => Unit): Unit = {
+ val output = predicate.collect { case a: Attribute => a }.distinct
+ val query = df
+ .select(output.map(e => Column(e)): _*)
+ .where(Column(predicate))
+
+ var maybeRelation: Option[HadoopFsRelation] = None
+ val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
+ case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) =>
+ maybeRelation = Some(orcRelation)
+ filters
+ }.flatten.reduceLeftOption(_ && _)
+ assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query")
+
+ val (_, selectedFilters, _) =
+ DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq)
+ assert(selectedFilters.nonEmpty, "No filter is pushed down")
+
+ val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray)
+ assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters")
+ checker(maybeFilter.get)
+ }
+
+ private def checkFilterPredicate
+ (predicate: Predicate, filterOperator: PredicateLeaf.Operator)
+ (implicit df: DataFrame): Unit = {
+ def checkComparisonOperator(filter: SearchArgument) = {
+ val operator = filter.getLeaves.asScala
+ assert(operator.map(_.getOperator).contains(filterOperator))
+ }
+ checkFilterPredicate(df, predicate, checkComparisonOperator)
+ }
+
+ private def checkFilterPredicate
+ (predicate: Predicate, stringExpr: String)
+ (implicit df: DataFrame): Unit = {
+ def checkLogicalOperator(filter: SearchArgument) = {
+ assert(filter.toString == stringExpr)
+ }
+ checkFilterPredicate(df, predicate, checkLogicalOperator)
+ }
+
+ private def checkNoFilterPredicate
+ (predicate: Predicate)
+ (implicit df: DataFrame): Unit = {
+ val output = predicate.collect { case a: Attribute => a }.distinct
+ val query = df
+ .select(output.map(e => Column(e)): _*)
+ .where(Column(predicate))
+
+ var maybeRelation: Option[HadoopFsRelation] = None
+ val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
+ case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) =>
+ maybeRelation = Some(orcRelation)
+ filters
+ }.flatten.reduceLeftOption(_ && _)
+ assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query")
+
+ val (_, selectedFilters, _) =
+ DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq)
+ assert(selectedFilters.nonEmpty, "No filter is pushed down")
+
+ val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray)
+ assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters")
+ }
+
+ test("filter pushdown - integer") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - long") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - float") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - double") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - string") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === "1", PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < "2", PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= "4", PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal("1") === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal("1") <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal("2") > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal("3") < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal("1") >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal("4") <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - boolean") {
+ withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === true, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < true, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > false, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= false, PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(false) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(false) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(false) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(true) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(true) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(true) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - decimal") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> BigDecimal.valueOf(1), PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > BigDecimal.valueOf(3), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= BigDecimal.valueOf(1), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(1)) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(1)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(2)) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(3)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(1)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(4)) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - timestamp") {
+ val timeString = "2015-08-20 14:57:00"
+ val timestamps = (1 to 4).map { i =>
+ val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600
+ new Timestamp(milliseconds)
+ }
+ withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === timestamps(0), PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < timestamps(1), PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= timestamps(3), PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(timestamps(0)) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(timestamps(0)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(timestamps(1)) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(timestamps(2)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(timestamps(0)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(timestamps(3)) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - combinations with logical operators") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df =>
+ // Because `ExpressionTree` is not accessible in Hive 1.2.x, filter creation involving
+ // logical operators such as `and`, `or`, and `not` is checked in string form: this
+ // function uses `SearchArgument.toString()` to produce a string expression and then
+ // compares it to the given string expression below.
+ // This might have to be changed after the Hive version is upgraded.
+ checkFilterPredicate(
+ '_1.isNotNull,
+ """leaf-0 = (IS_NULL _1)
+ |expr = (not leaf-0)""".stripMargin.trim
+ )
+ checkFilterPredicate(
+ '_1 =!= 1,
+ """leaf-0 = (IS_NULL _1)
+ |leaf-1 = (EQUALS _1 1)
+ |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim
+ )
+ checkFilterPredicate(
+ !('_1 < 4),
+ """leaf-0 = (IS_NULL _1)
+ |leaf-1 = (LESS_THAN _1 4)
+ |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim
+ )
+ checkFilterPredicate(
+ '_1 < 2 || '_1 > 3,
+ """leaf-0 = (LESS_THAN _1 2)
+ |leaf-1 = (LESS_THAN_EQUALS _1 3)
+ |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim
+ )
+ checkFilterPredicate(
+ '_1 < 2 && '_1 > 3,
+ """leaf-0 = (IS_NULL _1)
+ |leaf-1 = (LESS_THAN _1 2)
+ |leaf-2 = (LESS_THAN_EQUALS _1 3)
+ |expr = (and (not leaf-0) leaf-1 (not leaf-2))""".stripMargin.trim
+ )
+ }
+ }
+
+ test("no filter pushdown - non-supported types") {
+ implicit class IntToBinary(int: Int) {
+ def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8)
+ }
+ // ArrayType
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Array(i)))) { implicit df =>
+ checkNoFilterPredicate('_1.isNull)
+ }
+ // BinaryType
+ withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df =>
+ checkNoFilterPredicate('_1 <=> 1.b)
+ }
+ // DateType
+ val stringDate = "2015-01-01"
+ withOrcDataFrame(Seq(Tuple1(Date.valueOf(stringDate)))) { implicit df =>
+ checkNoFilterPredicate('_1 === Date.valueOf(stringDate))
+ }
+ // MapType
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Map(i -> i)))) { implicit df =>
+ checkNoFilterPredicate('_1.isNotNull)
+ }
+ }
+
+ test("SPARK-12218 Converting conjunctions into ORC SearchArguments") {
+ import org.apache.spark.sql.sources._
+ // The `LessThan` should be converted while the `StringContains` shouldn't
+ val schema = new StructType(
+ Array(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", StringType, nullable = true)))
+ assertResult(
+ """leaf-0 = (LESS_THAN a 10)
+ |expr = leaf-0
+ """.stripMargin.trim
+ ) {
+ OrcFilters.createFilter(schema, Array(
+ LessThan("a", 10),
+ StringContains("b", "prefix")
+ )).get.toString
+ }
+
+ // The `LessThan` should be converted while the whole inner `And` shouldn't
+ assertResult(
+ """leaf-0 = (LESS_THAN a 10)
+ |expr = leaf-0
+ """.stripMargin.trim
+ ) {
+ OrcFilters.createFilter(schema, Array(
+ LessThan("a", 10),
+ Not(And(
+ GreaterThan("a", 1),
+ StringContains("b", "prefix")
+ ))
+ )).get.toString
+ }
+ }
+}
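The assertions in this suite compare the string rendering of the Hive SearchArgument built by OrcFilters.createFilter. A standalone sketch of that mechanism, using the same LESS_THAN example as the SPARK-12218 test above (the expected rendering is taken from that test):

    import org.apache.spark.sql.hive.orc.OrcFilters
    import org.apache.spark.sql.sources.LessThan
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    val schema = StructType(Seq(StructField("a", IntegerType, nullable = true)))
    // Only convertible predicates become leaves of the SearchArgument.
    val sarg = OrcFilters.createFilter(schema, Array(LessThan("a", 10)))
    // sarg.get.toString is expected to render as:
    //   leaf-0 = (LESS_THAN a 10)
    //   expr = leaf-0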
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcPartitionDiscoverySuite.scala
new file mode 100644
index 0000000000000..ab9b492f347c6
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcPartitionDiscoverySuite.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import org.apache.spark.sql.execution.datasources.orc.OrcPartitionDiscoveryTest
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+
+class HiveOrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with TestHiveSingleton {
+ override val orcImp: String = "hive"
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala
new file mode 100644
index 0000000000000..7244c369bd3f4
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.orc.OrcQueryTest
+import org.apache.spark.sql.hive.HiveUtils
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
+
+class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton {
+ import testImplicits._
+
+ override val orcImp: String = "hive"
+
+ test("SPARK-8501: Avoids discovery schema from empty ORC files") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ withTable("empty_orc") {
+ withTempView("empty", "single") {
+ spark.sql(
+ s"""CREATE TABLE empty_orc(key INT, value STRING)
+ |STORED AS ORC
+ |LOCATION '${dir.toURI}'
+ """.stripMargin)
+
+ val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1)
+ emptyDF.createOrReplaceTempView("empty")
+
+ // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because
+ // Spark SQL ORC data source always avoids writing empty ORC files.
+ spark.sql(
+ s"""INSERT INTO TABLE empty_orc
+ |SELECT key, value FROM empty
+ """.stripMargin)
+
+ val errorMessage = intercept[AnalysisException] {
+ spark.read.orc(path)
+ }.getMessage
+
+ assert(errorMessage.contains("Unable to infer schema for ORC"))
+
+ val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1)
+ singleRowDF.createOrReplaceTempView("single")
+
+ spark.sql(
+ s"""INSERT INTO TABLE empty_orc
+ |SELECT key, value FROM single
+ """.stripMargin)
+
+ val df = spark.read.orc(path)
+ assert(df.schema === singleRowDF.schema.asNullable)
+ checkAnswer(df, singleRowDF)
+ }
+ }
+ }
+ }
+
+ test("Verify the ORC conversion parameter: CONVERT_METASTORE_ORC") {
+ withTempView("single") {
+ val singleRowDF = Seq((0, "foo")).toDF("key", "value")
+ singleRowDF.createOrReplaceTempView("single")
+
+ Seq("true", "false").foreach { orcConversion =>
+ withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> orcConversion) {
+ withTable("dummy_orc") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ spark.sql(
+ s"""
+ |CREATE TABLE dummy_orc(key INT, value STRING)
+ |STORED AS ORC
+ |LOCATION '${dir.toURI}'
+ """.stripMargin)
+
+ spark.sql(
+ s"""
+ |INSERT INTO TABLE dummy_orc
+ |SELECT key, value FROM single
+ """.stripMargin)
+
+ val df = spark.sql("SELECT * FROM dummy_orc WHERE key=0")
+ checkAnswer(df, singleRowDF)
+
+ val queryExecution = df.queryExecution
+ if (orcConversion == "true") {
+ queryExecution.analyzed.collectFirst {
+ case _: LogicalRelation => ()
+ }.getOrElse {
+ fail(s"Expecting the query plan to convert orc to data sources, " +
+ s"but got:\n$queryExecution")
+ }
+ } else {
+ queryExecution.analyzed.collectFirst {
+ case _: HiveTableRelation => ()
+ }.getOrElse {
+ fail(s"Expecting no conversion from orc to data sources, " +
+ s"but got:\n$queryExecution")
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ test("converted ORC table supports resolving mixed case field") {
+ withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") {
+ withTable("dummy_orc") {
+ withTempPath { dir =>
+ val df = spark.range(5).selectExpr("id", "id as valueField", "id as partitionValue")
+ df.write
+ .partitionBy("partitionValue")
+ .mode("overwrite")
+ .orc(dir.getAbsolutePath)
+
+ spark.sql(s"""
+ |create external table dummy_orc (id long, valueField long)
+ |partitioned by (partitionValue int)
+ |stored as orc
+ |location "${dir.toURI}"""".stripMargin)
+ spark.sql(s"msck repair table dummy_orc")
+ checkAnswer(spark.sql("select * from dummy_orc"), df)
+ }
+ }
+ }
+ }
+
+ test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") {
+ Seq(
+ ("native", classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat]),
+ ("hive", classOf[org.apache.spark.sql.hive.orc.OrcFileFormat])).foreach {
+ case (orcImpl, format) =>
+ withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> orcImpl) {
+ withTable("spark_20728") {
+ sql("CREATE TABLE spark_20728(a INT) USING ORC")
+ val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst {
+ case l: LogicalRelation =>
+ l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass
+ }
+ assert(fileFormat == Some(format))
+ }
+ }
+ }
+ }
+}
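The SPARK-20728 test above shows that the implementation behind `USING ORC` tables follows SQLConf.ORC_IMPLEMENTATION. A hedged usage sketch, assuming a SparkSession named `spark` and that SQLConf.ORC_IMPLEMENTATION maps to the key "spark.sql.orc.impl"; with Hive support enabled both settings succeed, while without it the "hive" setting raises the AnalysisException checked in the sql/core suite:

    // "native" selects org.apache.spark.sql.execution.datasources.orc.OrcFileFormat (sql/core),
    // "hive" selects org.apache.spark.sql.hive.orc.OrcFileFormat (sql/hive).
    spark.conf.set("spark.sql.orc.impl", "native")
    spark.sql("CREATE TABLE t_native(a INT) USING ORC")

    spark.conf.set("spark.sql.orc.impl", "hive")
    spark.sql("CREATE TABLE t_hive(a INT) USING ORC")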
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala
new file mode 100644
index 0000000000000..17b7d8cfe127e
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.io.File
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.datasources.orc.OrcSuite
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.util.Utils
+
+class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton {
+
+ override val orcImp: String = "hive"
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ sql(
+ s"""CREATE EXTERNAL TABLE normal_orc(
+ | intField INT,
+ | stringField STRING
+ |)
+ |STORED AS ORC
+ |LOCATION '${orcTableAsDir.toURI}'
+ """.stripMargin)
+
+ sql(
+ s"""INSERT INTO TABLE normal_orc
+ |SELECT intField, stringField FROM orc_temp_table
+ """.stripMargin)
+
+ spark.sql(
+ s"""CREATE TEMPORARY VIEW normal_orc_source
+ |USING org.apache.spark.sql.hive.orc
+ |OPTIONS (
+ | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}'
+ |)
+ """.stripMargin)
+
+ spark.sql(
+ s"""CREATE TEMPORARY VIEW normal_orc_as_source
+ |USING org.apache.spark.sql.hive.orc
+ |OPTIONS (
+ | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}'
+ |)
+ """.stripMargin)
+ }
+
+ test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") {
+ val location = Utils.createTempDir()
+ val uri = location.toURI
+ try {
+ hiveClient.runSqlHive("USE default")
+ hiveClient.runSqlHive(
+ """
+ |CREATE EXTERNAL TABLE hive_orc(
+ | a STRING,
+ | b CHAR(10),
+ | c VARCHAR(10),
+ | d ARRAY<CHAR(3)>)
+ |STORED AS orc""".stripMargin)
+ // Hive throws an exception if the location is assigned in the CREATE TABLE statement.
+ hiveClient.runSqlHive(
+ s"ALTER TABLE hive_orc SET LOCATION '$uri'")
+ hiveClient.runSqlHive(
+ """
+ |INSERT INTO TABLE hive_orc
+ |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3)))
+ |FROM (SELECT 1) t""".stripMargin)
+
+ // We create a different table in Spark using the same schema which points to
+ // the same location.
+ spark.sql(
+ s"""
+ |CREATE EXTERNAL TABLE spark_orc(
+ | a STRING,
+ | b CHAR(10),
+ | c VARCHAR(10),
+ | d ARRAY<CHAR(3)>)
+ |STORED AS orc
+ |LOCATION '$uri'""".stripMargin)
+ val result = Row("a", "b         ", "c", Seq("d  "))
+ checkAnswer(spark.table("hive_orc"), result)
+ checkAnswer(spark.table("spark_orc"), result)
+ } finally {
+ hiveClient.runSqlHive("DROP TABLE IF EXISTS hive_orc")
+ hiveClient.runSqlHive("DROP TABLE IF EXISTS spark_orc")
+ Utils.deleteRecursively(location)
+ }
+ }
+}
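The trailing blanks in the expected row of the char/varchar test come from Hive's fixed-length CHAR semantics: CHAR(10) pads 'b' to ten characters and CHAR(3) pads 'd' to three, while VARCHAR(10) leaves 'c' untouched. A small sketch of the expected values with the padding written out explicitly, so it survives whitespace collapsing:

    import org.apache.spark.sql.Row

    // Equivalent to Row("a", "b         ", "c", Seq("d  ")) in the test above.
    val expected = Row("a", "b" + " " * 9, "c", Seq("d" + " " * 2))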
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
index ba0a7605da71c..f87162f94c01a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
@@ -30,7 +30,8 @@ import org.apache.spark.sql.types._
class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
import testImplicits._
- override val dataSourceName: String = classOf[OrcFileFormat].getCanonicalName
+ override val dataSourceName: String =
+ classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName
// ORC does not play well with NullType and UDT.
override protected def supportsDataType(dataType: DataType): Boolean = dataType match {
@@ -116,3 +117,8 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
}
}
}
+
+class HiveOrcHadoopFsRelationSuite extends OrcHadoopFsRelationSuite {
+ override val dataSourceName: String =
+ classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
index b2ec33e82ddaa..b22bbb79a5cc9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
@@ -98,7 +98,7 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel)
/** Read a buffer fully from a given Channel */
private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) {
- while (dest.position < dest.limit) {
+ while (dest.position() < dest.limit()) {
if (channel.read(dest) == -1) {
throw new EOFException("End of channel")
}