diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index 14caaeaedbe2b..f18d00359c90c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -287,7 +287,7 @@ public final CalendarInterval getInterval(int rowId) {
/**
* @return child [[ColumnVector]] at the given ordinal.
*/
- protected abstract ColumnVector getChild(int ordinal);
+ public abstract ColumnVector getChild(int ordinal);
/**
* Data type for this column.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
index 9f917ea11d72a..a2feac869ece6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
@@ -31,7 +31,7 @@
* the entire data loading process.
*/
@Evolving
-public final class ColumnarBatch {
+public final class ColumnarBatch implements AutoCloseable {
private int numRows;
private final ColumnVector[] columns;
@@ -42,6 +42,7 @@ public final class ColumnarBatch {
* Called to close all the columns in this batch. It is not valid to access the data after
* calling this. This must be called at the end to clean up memory allocations.
*/
+ @Override
public void close() {
for (ColumnVector c: columns) {
c.close();
@@ -110,7 +111,17 @@ public InternalRow getRow(int rowId) {
}
public ColumnarBatch(ColumnVector[] columns) {
+ this(columns, 0);
+ }
+
+ /**
+ * Create a new batch from existing column vectors.
+ * @param columns The columns of this batch
+ * @param numRows The number of rows in this batch
+ */
+ public ColumnarBatch(ColumnVector[] columns, int numRows) {
this.columns = columns;
+ this.numRows = numRows;
this.row = new ColumnarBatchRow(columns);
}
}
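The new two-argument constructor and the AutoCloseable change can be exercised as in the
following sketch (illustrative only, not part of the patch; it assumes a single LongType
column allocated with the existing OnHeapColumnVector helper):

  import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
  import org.apache.spark.sql.types.{LongType, StructField, StructType}
  import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

  val schema = StructType(Seq(StructField("id", LongType)))
  // Allocate writable vectors, fill them, then hand them to the batch with an explicit row count.
  val vectors = OnHeapColumnVector.allocateColumns(4, schema)
  (0 until 4).foreach(i => vectors(0).putLong(i, i.toLong))
  val batch = new ColumnarBatch(vectors.toArray[ColumnVector], 4)
  try {
    assert(batch.numRows() == 4)
    assert(batch.column(0).getLong(2) == 2L)
  } finally {
    batch.close() // closes every ColumnVector owned by the batch
  }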
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 4f5e72c1326ac..14fac72847af2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -604,7 +604,10 @@ public final int appendArray(int length) {
*/
public final int appendStruct(boolean isNull) {
if (isNull) {
- appendNull();
+ // This is the same as appendNull but without the assertion for struct types
+ reserve(elementsAppended + 1);
+ putNull(elementsAppended);
+ elementsAppended++;
for (WritableColumnVector c: childColumns) {
if (c.type instanceof StructType) {
c.appendStruct(true);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index 66becf37cb6a4..1c2bf9e7c2a57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.ColumnarRule
/**
* :: Experimental ::
@@ -42,6 +43,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
*
Planning Strategies.
* Customized Parser.
* (External) Catalog listeners.
+ * Columnar Rules.
*
*
* The extensions can be used by calling `withExtensions` on the [[SparkSession.Builder]], for
@@ -93,6 +95,23 @@ class SparkSessionExtensions {
type StrategyBuilder = SparkSession => Strategy
type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
+ type ColumnarRuleBuilder = SparkSession => ColumnarRule
+
+ private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
+
+ /**
+ * Build the override rules for columnar execution.
+ */
+ private[sql] def buildColumnarRules(session: SparkSession): Seq[ColumnarRule] = {
+ columnarRuleBuilders.map(_.apply(session))
+ }
+
+ /**
+ * Inject a rule that can override the columnar execution of operators in the physical plan.
+ */
+ def injectColumnar(builder: ColumnarRuleBuilder): Unit = {
+ columnarRuleBuilders += builder
+ }
private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
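As a usage sketch (not part of the patch), a columnar rule is registered through the session
builder; MyColumarRule, PreRuleReplaceAddWithBrokenVersion and MyPostRule are the helpers added
to SparkSessionExtensionSuite later in this patch:

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder()
    .master("local[1]")
    .withExtensions(_.injectColumnar(session =>
      MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
    .getOrCreate()
  // Every query planned by this session now runs the pre and post columnar rules.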
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
new file mode 100644
index 0000000000000..315eba6635aac
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -0,0 +1,534 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{broadcast, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, SpecializedGetters, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+
+/**
+ * Holds a user defined rule that can be used to inject columnar implementations of various
+ * operators in the plan. The [[preColumnarTransitions]] [[Rule]] can be used to replace
+ * [[SparkPlan]] instances with versions that support a columnar implementation. After this,
+ * Spark will insert any transitions that are needed, from row to columnar
+ * ([[RowToColumnarExec]]) and from columnar to row ([[ColumnarToRowExec]]). At that point the
+ * [[postColumnarTransitions]] [[Rule]] is called to allow replacing any of the inserted
+ * transitions or cleaning up the plan, for example by inserting stages that build larger
+ * batches for more efficient processing, or stages that move the data to/from an accelerator's
+ * memory.
+ */
+class ColumnarRule {
+ def preColumnarTransitions: Rule[SparkPlan] = plan => plan
+ def postColumnarTransitions: Rule[SparkPlan] = plan => plan
+}
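For illustration (a sketch, not part of the patch), a rule that only inspects the plan in
preColumnarTransitions and keeps the default identity postColumnarTransitions could look like:

  import org.apache.spark.internal.Logging
  import org.apache.spark.sql.catalyst.rules.Rule

  class LoggingColumnarRule extends ColumnarRule with Logging {
    override def preColumnarTransitions: Rule[SparkPlan] = plan => {
      // Only observe which operators already claim columnar support; leave the plan unchanged.
      plan.foreach(p => if (p.supportsColumnar) logInfo(s"columnar operator: ${p.nodeName}"))
      plan
    }
    // postColumnarTransitions keeps the identity rule inherited from ColumnarRule.
  }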
+
+/**
+ * Provides a common executor to translate an [[RDD]] of [[ColumnarBatch]] into an [[RDD]] of
+ * [[InternalRow]]. This is inserted whenever such a transition is determined to be needed.
+ *
+ * The implementation is based on similar implementations in [[ColumnarBatchScan]],
+ * [[org.apache.spark.sql.execution.python.ArrowEvalPythonExec]], and
+ * [[MapPartitionsInRWithArrowExec]]. Eventually this should replace those implementations.
+ */
+case class ColumnarToRowExec(child: SparkPlan)
+ extends UnaryExecNode with CodegenSupport {
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+ override lazy val metrics: Map[String, SQLMetric] = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+ "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"),
+ "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")
+ )
+
+ override def doExecute(): RDD[InternalRow] = {
+ val numOutputRows = longMetric("numOutputRows")
+ val numInputBatches = longMetric("numInputBatches")
+ val scanTime = longMetric("scanTime")
+ // UnsafeProjection is not serializable, so create it on the executor side; that is why it is lazy
+ @transient lazy val outputProject = UnsafeProjection.create(output, output)
+ val batches = child.executeColumnar()
+ batches.flatMap(batch => {
+ val batchStartNs = System.nanoTime()
+ numInputBatches += 1
+ // To match the numOutputRows metric in the generated code we update numOutputRows once
+ // per batch. This is less accurate than counting at the output because it overcounts
+ // the rows produced when a limit cuts a batch short, but it is more efficient.
+ numOutputRows += batch.numRows()
+ val ret = batch.rowIterator().asScala
+ scanTime += ((System.nanoTime() - batchStartNs) / (1000 * 1000))
+ ret.map(outputProject)
+ })
+ }
+
+ /**
+ * Generate [[ColumnVector]] expressions for our parent to consume as rows.
+ * This is called once per [[ColumnVector]] in the batch.
+ *
+ * This code was copied unchanged from [[ColumnarBatchScan]], which it should eventually
+ * replace.
+ */
+ private def genCodeColumnVector(
+ ctx: CodegenContext,
+ columnVar: String,
+ ordinal: String,
+ dataType: DataType,
+ nullable: Boolean): ExprCode = {
+ val javaType = CodeGenerator.javaType(dataType)
+ val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal)
+ val isNullVar = if (nullable) {
+ JavaCode.isNullVariable(ctx.freshName("isNull"))
+ } else {
+ FalseLiteral
+ }
+ val valueVar = ctx.freshName("value")
+ val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
+ val code = code"${ctx.registerComment(str)}" + (if (nullable) {
+ code"""
+ boolean $isNullVar = $columnVar.isNullAt($ordinal);
+ $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
+ """
+ } else {
+ code"$javaType $valueVar = $value;"
+ })
+ ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
+ }
+
+ /**
+ * Produce code to process the input iterator as [[ColumnarBatch]]es.
+ * This produces an [[org.apache.spark.sql.catalyst.expressions.UnsafeRow]] for each row in
+ * each batch.
+ *
+ * This code was copied almost completely unchanged from [[ColumnarBatchScan]], which it
+ * should eventually replace.
+ */
+ override protected def doProduce(ctx: CodegenContext): String = {
+ // PhysicalRDD always just has one input
+ val input = ctx.addMutableState("scala.collection.Iterator", "input",
+ v => s"$v = inputs[0];")
+
+ // metrics
+ val numOutputRows = metricTerm(ctx, "numOutputRows")
+ val numInputBatches = metricTerm(ctx, "numInputBatches")
+ val scanTimeMetric = metricTerm(ctx, "scanTime")
+ val scanTimeTotalNs =
+ ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0
+
+ val columnarBatchClz = classOf[ColumnarBatch].getName
+ val batch = ctx.addMutableState(columnarBatchClz, "batch")
+
+ val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
+ val columnVectorClzs = child.vectorTypes.getOrElse(
+ Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
+ val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
+ case (columnVectorClz, i) =>
+ val name = ctx.addMutableState(columnVectorClz, s"colInstance$i")
+ (name, s"$name = ($columnVectorClz) $batch.column($i);")
+ }.unzip
+
+ val nextBatch = ctx.freshName("nextBatch")
+ val nextBatchFuncName = ctx.addNewFunction(nextBatch,
+ s"""
+ |private void $nextBatch() throws java.io.IOException {
+ | long getBatchStart = System.nanoTime();
+ | if ($input.hasNext()) {
+ | $batch = ($columnarBatchClz)$input.next();
+ | $numOutputRows.add($batch.numRows());
+ | $idx = 0;
+ | ${columnAssigns.mkString("", "\n", "\n")}
+ | ${numInputBatches}.add(1);
+ | }
+ | $scanTimeTotalNs += System.nanoTime() - getBatchStart;
+ |}""".stripMargin)
+
+ ctx.currentVars = null
+ val rowidx = ctx.freshName("rowIdx")
+ val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
+ genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
+ }
+ val localIdx = ctx.freshName("localIdx")
+ val localEnd = ctx.freshName("localEnd")
+ val numRows = ctx.freshName("numRows")
+ val shouldStop = if (parent.needStopCheck) {
+ s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
+ } else {
+ "// shouldStop check is eliminated"
+ }
+ s"""
+ |if ($batch == null) {
+ | $nextBatchFuncName();
+ |}
+ |while ($batch != null) {
+ | int $numRows = $batch.numRows();
+ | int $localEnd = $numRows - $idx;
+ | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+ | int $rowidx = $idx + $localIdx;
+ | ${consume(ctx, columnsBatchInput).trim}
+ | $shouldStop
+ | }
+ | $idx = $numRows;
+ | $batch = null;
+ | $nextBatchFuncName();
+ |}
+ |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
+ |$scanTimeTotalNs = 0;
+ """.stripMargin
+ }
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].inputRDDs()
+ }
+}
+
+/**
+ * Provides an optimized set of APIs to append row based data to an array of
+ * [[WritableColumnVector]].
+ */
+private[execution] class RowToColumnConverter(schema: StructType) extends Serializable {
+ private val converters = schema.fields.map {
+ f => RowToColumnConverter.getConverterForType(f.dataType, f.nullable)
+ }
+
+ final def convert(row: InternalRow, vectors: Array[WritableColumnVector]): Unit = {
+ var idx = 0
+ while (idx < row.numFields) {
+ converters(idx).append(row, idx, vectors(idx))
+ idx += 1
+ }
+ }
+}
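A small usage sketch (illustrative, not part of the patch; it mirrors the logic of
RowToColumnarExec.doExecuteColumnar below and assumes code living in the
org.apache.spark.sql.execution package, since the converter is private[execution]):

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector}
  import org.apache.spark.sql.types.StructType
  import org.apache.spark.sql.vectorized.ColumnarBatch

  def rowsToBatch(rows: Iterator[InternalRow], schema: StructType, capacity: Int): ColumnarBatch = {
    val converter = new RowToColumnConverter(schema)
    val vectors: Array[WritableColumnVector] =
      OnHeapColumnVector.allocateColumns(capacity, schema).toArray
    var rowCount = 0
    while (rowCount < capacity && rows.hasNext) {
      converter.convert(rows.next(), vectors) // appends one value per top-level field
      rowCount += 1
    }
    new ColumnarBatch(vectors.toArray, rowCount) // the caller is responsible for closing the batch
  }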
+
+/**
+ * Provides an optimized set of APIs to extract a column from a row and append it to a
+ * [[WritableColumnVector]].
+ */
+private object RowToColumnConverter {
+ private abstract class TypeConverter extends Serializable {
+ def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit
+ }
+
+ private final case class BasicNullableTypeConverter(base: TypeConverter) extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+ if (row.isNullAt(column)) {
+ cv.appendNull
+ } else {
+ base.append(row, column, cv)
+ }
+ }
+ }
+
+ private final case class StructNullableTypeConverter(base: TypeConverter) extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+ if (row.isNullAt(column)) {
+ cv.appendStruct(true)
+ } else {
+ base.append(row, column, cv)
+ }
+ }
+ }
+
+ private def getConverterForType(dataType: DataType, nullable: Boolean): TypeConverter = {
+ val core = dataType match {
+ case BooleanType => BooleanConverter
+ case ByteType => ByteConverter
+ case ShortType => ShortConverter
+ case IntegerType | DateType => IntConverter
+ case FloatType => FloatConverter
+ case LongType | TimestampType => LongConverter
+ case DoubleType => DoubleConverter
+ case StringType => StringConverter
+ case CalendarIntervalType => CalendarConverter
+ case at: ArrayType => new ArrayConverter(getConverterForType(at.elementType, nullable))
+ case st: StructType => new StructConverter(st.fields.map(
+ (f) => getConverterForType(f.dataType, f.nullable)))
+ case dt: DecimalType => new DecimalConverter(dt)
+ case mt: MapType => new MapConverter(getConverterForType(mt.keyType, nullable),
+ getConverterForType(mt.valueType, nullable))
+ case unknown => throw new UnsupportedOperationException(
+ s"Type $unknown not supported")
+ }
+
+ if (nullable) {
+ dataType match {
+ case CalendarIntervalType => new StructNullableTypeConverter(core)
+ case st: StructType => new StructNullableTypeConverter(core)
+ case _ => new BasicNullableTypeConverter(core)
+ }
+ } else {
+ core
+ }
+ }
+
+ private object BooleanConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit =
+ cv.appendBoolean(row.getBoolean(column))
+ }
+
+ private object ByteConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit =
+ cv.appendByte(row.getByte(column))
+ }
+
+ private object ShortConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit =
+ cv.appendShort(row.getShort(column))
+ }
+
+ private object IntConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit =
+ cv.appendInt(row.getInt(column))
+ }
+
+ private object FloatConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit =
+ cv.appendFloat(row.getFloat(column))
+ }
+
+ private object LongConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit =
+ cv.appendLong(row.getLong(column))
+ }
+
+ private object DoubleConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit =
+ cv.appendDouble(row.getDouble(column))
+ }
+
+ private object StringConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+ val data = row.getUTF8String(column).getBytes
+ cv.appendByteArray(data, 0, data.length)
+ }
+ }
+
+ private object CalendarConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+ val c = row.getInterval(column)
+ cv.appendStruct(false)
+ cv.getChild(0).appendInt(c.months)
+ cv.getChild(1).appendLong(c.microseconds)
+ }
+ }
+
+ private case class ArrayConverter(childConverter: TypeConverter) extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+ val values = row.getArray(column)
+ val numElements = values.numElements()
+ cv.appendArray(numElements)
+ val arrData = cv.arrayData()
+ for (i <- 0 until numElements) {
+ childConverter.append(values, i, arrData)
+ }
+ }
+ }
+
+ private case class StructConverter(childConverters: Array[TypeConverter]) extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+ cv.appendStruct(false)
+ val data = row.getStruct(column, childConverters.length)
+ for (i <- 0 until childConverters.length) {
+ childConverters(i).append(data, i, cv.getChild(i))
+ }
+ }
+ }
+
+ private case class DecimalConverter(dt: DecimalType) extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+ val d = row.getDecimal(column, dt.precision, dt.scale)
+ if (dt.precision <= Decimal.MAX_INT_DIGITS) {
+ cv.appendInt(d.toUnscaledLong.toInt)
+ } else if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
+ cv.appendLong(d.toUnscaledLong)
+ } else {
+ val integer = d.toJavaBigDecimal.unscaledValue
+ val bytes = integer.toByteArray
+ cv.appendByteArray(bytes, 0, bytes.length)
+ }
+ }
+ }
+
+ private case class MapConverter(keyConverter: TypeConverter, valueConverter: TypeConverter)
+ extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+ val m = row.getMap(column)
+ val keys = cv.getChild(0)
+ val values = cv.getChild(1)
+ val numElements = m.numElements()
+ cv.appendArray(numElements)
+
+ val srcKeys = m.keyArray()
+ val srcValues = m.valueArray()
+
+ for (i <- 0 until numElements) {
+ keyConverter.append(srcKeys, i, keys)
+ valueConverter.append(srcValues, i, values)
+ }
+ }
+ }
+}
+
+/**
+ * Provides a common executor to translate an [[RDD]] of [[InternalRow]] into an [[RDD]] of
+ * [[ColumnarBatch]]. This is inserted whenever such a transition is determined to be needed.
+ *
+ * This is similar to some of the code in ArrowConverters.scala and
+ * [[org.apache.spark.sql.execution.arrow.ArrowWriter]]. That code is more specialized
+ * to convert [[InternalRow]] to Arrow formatted data, but in the future if we make
+ * [[OffHeapColumnVector]] internally Arrow formatted we may be able to replace much of that code.
+ *
+ * This is also similar to
+ * [[org.apache.spark.sql.execution.vectorized.ColumnVectorUtils.populate()]] and
+ * [[org.apache.spark.sql.execution.vectorized.ColumnVectorUtils.toBatch()]]. `toBatch` is only ever
+ * called from tests and can probably be removed, but populate is used by both Orc and Parquet
+ * to initialize partition and missing columns. There is some chance that we could replace
+ * populate with [[RowToColumnConverter]], but the performance requirements are different and it
+ * would only be to reduce code.
+ */
+case class RowToColumnarExec(child: SparkPlan) extends UnaryExecNode {
+ override def output: Seq[Attribute] = child.output
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+ override def doExecute(): RDD[InternalRow] = {
+ child.execute()
+ }
+
+ override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ child.doExecuteBroadcast()
+ }
+
+ override def supportsColumnar: Boolean = true
+
+ override lazy val metrics: Map[String, SQLMetric] = Map(
+ "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"),
+ "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches")
+ )
+
+ override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ val enableOffHeapColumnVector = sqlContext.conf.offHeapColumnVectorEnabled
+ val numInputRows = longMetric("numInputRows")
+ val numOutputBatches = longMetric("numOutputBatches")
+ // Instead of creating a new config we are reusing columnBatchSize. In the future if we do
+ // combine with some of the Arrow conversion tools we will need to unify some of the configs.
+ val numRows = conf.columnBatchSize
+ val converters = new RowToColumnConverter(schema)
+ val rowBased = child.execute()
+ rowBased.mapPartitions(rowIterator => {
+ new Iterator[ColumnarBatch] {
+ var cb: ColumnarBatch = null
+
+ TaskContext.get().addTaskCompletionListener[Unit] { _ =>
+ if (cb != null) {
+ cb.close()
+ cb = null
+ }
+ }
+
+ override def hasNext: Boolean = {
+ rowIterator.hasNext
+ }
+
+ override def next(): ColumnarBatch = {
+ if (cb != null) {
+ cb.close()
+ cb = null
+ }
+ val columnVectors : Array[WritableColumnVector] =
+ if (enableOffHeapColumnVector) {
+ OffHeapColumnVector.allocateColumns(numRows, schema).toArray
+ } else {
+ OnHeapColumnVector.allocateColumns(numRows, schema).toArray
+ }
+ var rowCount = 0
+ while (rowCount < numRows && rowIterator.hasNext) {
+ val row = rowIterator.next()
+ converters.convert(row, columnVectors)
+ rowCount += 1
+ }
+ cb = new ColumnarBatch(columnVectors.toArray, rowCount)
+ numInputRows += rowCount
+ numOutputBatches += 1
+ cb
+ }
+ }
+ })
+ }
+}
+
+/**
+ * Apply any user defined [[ColumnarRule]]s and find the correct place to insert transitions
+ * to/from columnar formatted data.
+ */
+case class ApplyColumnarRulesAndInsertTransitions(conf: SQLConf, columnarRules: Seq[ColumnarRule])
+ extends Rule[SparkPlan] {
+
+ /**
+ * Inserts a transition to columnar formatted data.
+ */
+ private def insertRowToColumnar(plan: SparkPlan): SparkPlan = {
+ if (!plan.supportsColumnar) {
+ // This node produces rows but its parent consumes columnar data,
+ // so columnar processing starts here with a row-to-columnar transition.
+ RowToColumnarExec(insertTransitions(plan))
+ } else {
+ plan.withNewChildren(plan.children.map(insertRowToColumnar))
+ }
+ }
+
+ /**
+ * Inserts RowToColumnarExecs and ColumnarToRowExecs where needed.
+ */
+ private def insertTransitions(plan: SparkPlan): SparkPlan = {
+ if (plan.supportsColumnar) {
+ // This node produces columnar data but its consumer expects rows,
+ // so columnar processing ends here with a columnar-to-row transition.
+ ColumnarToRowExec(insertRowToColumnar(plan))
+ } else {
+ plan.withNewChildren(plan.children.map(insertTransitions))
+ }
+ }
+
+ def apply(plan: SparkPlan): SparkPlan = {
+ var preInsertPlan: SparkPlan = plan
+ columnarRules.foreach((r : ColumnarRule) =>
+ preInsertPlan = r.preColumnarTransitions(preInsertPlan))
+ var postInsertPlan = insertTransitions(preInsertPlan)
+ columnarRules.reverse.foreach((r : ColumnarRule) =>
+ postInsertPlan = r.postColumnarTransitions(postInsertPlan))
+ postInsertPlan
+ }
+}
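For readers following the rule ordering in apply() above, the same logic written with folds (a
sketch, not the patch code; it assumes access to the private insertTransitions helper) is:

  def applyColumnarRules(plan: SparkPlan, rules: Seq[ColumnarRule]): SparkPlan = {
    // 1. Let every rule replace operators with columnar versions, in registration order.
    val preProcessed = rules.foldLeft(plan)((p, r) => r.preColumnarTransitions(p))
    // 2. Insert RowToColumnarExec / ColumnarToRowExec wherever the columnar boundary changes.
    val withTransitions = insertTransitions(preProcessed)
    // 3. Let every rule post-process the inserted transitions, in reverse registration order.
    rules.reverse.foldLeft(withTransitions)((p, r) => r.postColumnarTransitions(p))
  }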
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 7caff69f23499..b2e9f760d27ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -30,8 +30,6 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
*/
private[sql] trait ColumnarBatchScan extends CodegenSupport {
- def vectorTypes: Option[Seq[String]] = None
-
protected def supportsBatch: Boolean = true
override lazy val metrics = Map(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 6f0b489af2784..9fcffac53c999 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -119,6 +119,8 @@ class QueryExecution(
InsertAdaptiveSparkPlan(sparkSession),
PlanSubqueries(sparkSession),
EnsureRequirements(sparkSession.sessionState.conf),
+ ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf,
+ sparkSession.sessionState.columnarRules),
CollapseCodegenStages(sparkSession.sessionState.conf),
ReuseExchange(sparkSession.sessionState.conf),
ReuseSubquery(sparkSession.sessionState.conf))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index fbe8e5055a25c..6deb90c7e4ff2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.vectorized.ColumnarBatch
object SparkPlan {
/** The original [[LogicalPlan]] from which this [[SparkPlan]] is converted. */
@@ -73,6 +74,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
// whether we should fallback when hitting compilation errors caused by codegen
private val codeGenFallBack = (sqlContext == null) || sqlContext.conf.codegenFallback
+ /**
+ * Return true if this stage of the plan supports columnar execution.
+ */
+ def supportsColumnar: Boolean = false
+
+ /**
+ * The exact Java types of the columns that are output in columnar processing mode. This
+ * is a performance optimization for code generation and is optional.
+ */
+ def vectorTypes: Option[Seq[String]] = None
+
/** Overridden make copy also propagates sqlContext to copied plan. */
override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
if (sqlContext != null) {
@@ -181,6 +193,20 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
doExecuteBroadcast()
}
+ /**
+ * Returns the result of this query as an RDD[ColumnarBatch] by delegating to `doExecuteColumnar`
+ * after preparations.
+ *
+ * Concrete implementations of SparkPlan should override `doExecuteColumnar` if `supportsColumnar`
+ * returns true.
+ */
+ final def executeColumnar(): RDD[ColumnarBatch] = executeQuery {
+ if (isCanonicalizedPlan) {
+ throw new IllegalStateException("A canonicalized plan is not supposed to be executed.")
+ }
+ doExecuteColumnar()
+ }
+
/**
* Executes a query after preparing the query and adding query plan information to created RDDs
* for visualization.
@@ -272,6 +298,16 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
throw new UnsupportedOperationException(s"$nodeName does not implement doExecuteBroadcast")
}
+ /**
+ * Produces the result of the query as an `RDD[ColumnarBatch]` if [[supportsColumnar]] returns
+ * true. By convention the operator that creates a ColumnarBatch is responsible for closing it
+ * when it is no longer needed. This allows input formats to reuse batches if needed.
+ */
+ protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ throw new IllegalStateException(s"Internal Error ${this.getClass} has column support" +
+ s" mismatch:\n${this}")
+ }
+
/**
* Packing the UnsafeRows into byte array for faster serialization.
* The byte arrays are in the following format:
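To illustrate the operator contract added above, a hypothetical pass-through operator (not part
of the patch, assumed to live in org.apache.spark.sql.execution) would advertise supportsColumnar
and implement both execution paths:

  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.Attribute
  import org.apache.spark.sql.vectorized.ColumnarBatch

  // Hypothetical operator used only to show the new SparkPlan hooks.
  case class PassThroughColumnarExec(child: SparkPlan) extends UnaryExecNode {
    override def output: Seq[Attribute] = child.output
    override def supportsColumnar: Boolean = child.supportsColumnar
    override def vectorTypes: Option[Seq[String]] = child.vectorTypes
    override protected def doExecute(): RDD[InternalRow] = child.execute()
    override protected def doExecuteColumnar(): RDD[ColumnarBatch] = child.executeColumnar()
  }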
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 92e80dcf90e58..94a5ede751456 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoi
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils
/**
@@ -490,8 +491,12 @@ trait InputRDDCodegen extends CodegenSupport {
*
* This is the leaf node of a tree with WholeStageCodegen that is used to generate code
* that consumes an RDD iterator of InternalRow.
+ *
+ * @param isChildColumnar true if the inputRDD is really columnar data hidden by type erasure,
+ * false if inputRDD is really an RDD[InternalRow]
*/
-case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCodegen {
+case class InputAdapter(child: SparkPlan, isChildColumnar: Boolean)
+ extends UnaryExecNode with InputRDDCodegen {
override def output: Seq[Attribute] = child.output
@@ -499,6 +504,12 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCod
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ override def vectorTypes: Option[Seq[String]] = child.vectorTypes
+
+ // This is not strictly needed because the codegen transformation happens after the columnar
+ // transformation, but it is kept for consistency.
+ override def supportsColumnar: Boolean = child.supportsColumnar
+
override def doExecute(): RDD[InternalRow] = {
child.execute()
}
@@ -507,7 +518,17 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCod
child.doExecuteBroadcast()
}
- override def inputRDD: RDD[InternalRow] = child.execute()
+ override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ child.executeColumnar()
+ }
+
+ override def inputRDD: RDD[InternalRow] = {
+ if (isChildColumnar) {
+ child.executeColumnar().asInstanceOf[RDD[InternalRow]] // Hack because of type erasure
+ } else {
+ child.execute()
+ }
+ }
// This is a leaf node so the node can produce limit not reached checks.
override protected def canCheckLimitNotReached: Boolean = true
@@ -589,6 +610,10 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ // This is not strictly needed because the codegen transformation happens after the columnar
+ // transformation, but it is kept for consistency.
+ override def supportsColumnar: Boolean = child.supportsColumnar
+
override lazy val metrics = Map(
"pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
WholeStageCodegenExec.PIPELINE_DURATION_METRIC))
@@ -659,6 +684,12 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
(ctx, cleanedSource)
}
+ override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ // Code generation is not currently supported for columnar output, so just fall back to
+ // the interpreted path
+ child.executeColumnar()
+ }
+
override def doExecute(): RDD[InternalRow] = {
val (ctx, cleanedSource) = doCodeGen()
// try to compile and fallback if it failed
@@ -689,6 +720,9 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
val durationMs = longMetric("pipelineTime")
+ // Even though rdds is typed as RDD[InternalRow], it may actually be an RDD[ColumnarBatch]
+ // hidden by type erasure. This allows the input to a codegen stage to be columnar, but the
+ // output must be rows.
val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
assert(rdds.size <= 2, "Up to two input RDDs can be supported")
if (rdds.length == 1) {
@@ -840,34 +874,55 @@ case class CollapseCodegenStages(
/**
* Inserts an InputAdapter on top of those that do not support codegen.
*/
- private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
- case p if !supportCodegen(p) =>
- // collapse them recursively
- InputAdapter(insertWholeStageCodegen(p))
- case j: SortMergeJoinExec =>
- // The children of SortMergeJoin should do codegen separately.
- j.withNewChildren(j.children.map(child => InputAdapter(insertWholeStageCodegen(child))))
- case p =>
- p.withNewChildren(p.children.map(insertInputAdapter))
+ private def insertInputAdapter(plan: SparkPlan, isColumnarInput: Boolean): SparkPlan = {
+ val isColumnar = adjustColumnar(plan, isColumnarInput)
+ plan match {
+ case p if !supportCodegen(p) =>
+ // collapse them recursively
+ InputAdapter(insertWholeStageCodegen(p, isColumnar), isColumnar)
+ case j: SortMergeJoinExec =>
+ // The children of SortMergeJoin should do codegen separately.
+ j.withNewChildren(j.children.map(
+ child => InputAdapter(insertWholeStageCodegen(child, isColumnar), isColumnar)))
+ case p =>
+ p.withNewChildren(p.children.map(insertInputAdapter(_, isColumnar)))
+ }
}
/**
* Inserts a WholeStageCodegen on top of those that support codegen.
*/
- private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
- // For operators that will output domain object, do not insert WholeStageCodegen for it as
- // domain object can not be written into unsafe row.
- case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
- plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
- case plan: CodegenSupport if supportCodegen(plan) =>
- WholeStageCodegenExec(insertInputAdapter(plan))(codegenStageCounter.incrementAndGet())
- case other =>
- other.withNewChildren(other.children.map(insertWholeStageCodegen))
+ private def insertWholeStageCodegen(plan: SparkPlan, isColumnarInput: Boolean): SparkPlan = {
+ val isColumnar = adjustColumnar(plan, isColumnarInput)
+ plan match {
+ // For operators that will output domain object, do not insert WholeStageCodegen for it as
+ // domain object can not be written into unsafe row.
+ case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
+ plan.withNewChildren(plan.children.map(insertWholeStageCodegen(_, isColumnar)))
+ case plan: CodegenSupport if supportCodegen(plan) =>
+ WholeStageCodegenExec(
+ insertInputAdapter(plan, isColumnar))(codegenStageCounter.incrementAndGet())
+ case other =>
+ other.withNewChildren(other.children.map(insertWholeStageCodegen(_, isColumnar)))
+ }
+ }
+
+ /**
+ * Given the current node in the plan and whether processing was columnar before reaching it,
+ * return whether the children of this node produce columnar data.
+ */
+ private def adjustColumnar(plan: SparkPlan, isColumnar: Boolean): Boolean =
+ // We walk the tree from the root toward the leaves, so a ColumnarToRowExec marks where the
+ // columnar section begins (its children are columnar) and a RowToColumnarExec marks where
+ // it ends (its children are row-based).
+ plan match {
+ case c2r: ColumnarToRowExec => true
+ case r2c: RowToColumnarExec => false
+ case _ => isColumnar
}
def apply(plan: SparkPlan): SparkPlan = {
if (conf.wholeStageEnabled) {
- insertWholeStageCodegen(plan)
+ insertWholeStageCodegen(plan, false)
} else {
plan
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 8d73449c3533d..8dc30eaa3a318 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck}
import org.apache.spark.sql.streaming.StreamingQueryManager
@@ -264,6 +264,10 @@ abstract class BaseSessionStateBuilder(
extensions.buildPlannerStrategies(session)
}
+ protected def columnarRules: Seq[ColumnarRule] = {
+ extensions.buildColumnarRules(session)
+ }
+
/**
* Create a query execution object.
*/
@@ -314,7 +318,8 @@ abstract class BaseSessionStateBuilder(
listenerManager,
() => resourceLoader,
createQueryExecution,
- createClone)
+ createClone,
+ columnarRules)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index b34db581ca2c1..b962ab6feabcb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -71,7 +71,8 @@ private[sql] class SessionState(
val listenerManager: ExecutionListenerManager,
resourceLoaderBuilder: () => SessionResourceLoader,
createQueryExecution: LogicalPlan => QueryExecution,
- createClone: (SparkSession, SessionState) => SessionState) {
+ createClone: (SparkSession, SessionState) => SessionState,
+ val columnarRules: Seq[ColumnarRule]) {
// The following fields are lazy to avoid creating the Hive client when creating SessionState.
lazy val catalog: SessionCatalog = catalogBuilder()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 881268440ccd7..2e2e61b438963 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -16,14 +16,19 @@
*/
package org.apache.spark.sql
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
-import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.types.{DataType, Decimal, IntegerType, LongType, Metadata, StructType}
+import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch, ColumnarMap, ColumnVector}
+import org.apache.spark.unsafe.types.UTF8String
/**
* Test cases for the [[SparkSessionExtensions]].
@@ -116,6 +121,34 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
}
+ test("inject columnar") {
+ val extensions = create { extensions =>
+ extensions.injectColumnar(session =>
+ MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
+ }
+ withSession(extensions) { session =>
+ assert(session.sessionState.columnarRules.contains(
+ MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
+ import session.sqlContext.implicits._
+ // repartitioning avoids having the add operation pushed down into the LocalTableScan
+ val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1)
+ val df = data.selectExpr("vals + 1")
+ // Verify that both pre and post processing of the plan worked.
+ val found = df.queryExecution.executedPlan.collect {
+ case rep: ReplacedRowToColumnarExec => 1
+ case proj: ColumnarProjectExec => 10
+ case c2r: ColumnarToRowExec => 100
+ }.sum
+ assert(found == 111)
+
+ // Verify that we get back the expected, wrong, result
+ val result = df.collect()
+ assert(result(0).getLong(0) == 102L) // Check that broken columnar Add was used.
+ assert(result(1).getLong(0) == 202L)
+ assert(result(2).getLong(0) == 302L)
+ }
+ }
+
test("use custom class for extensions") {
val session = SparkSession.builder()
.master("local[1]")
@@ -130,6 +163,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
assert(session.sessionState.sqlParser.isInstanceOf[MyParser])
assert(session.sessionState.functionRegistry
.lookupFunction(MyExtensions.myFunction._1).isDefined)
+ assert(session.sessionState.columnarRules.contains(
+ MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
} finally {
stop(session)
}
@@ -251,6 +286,371 @@ object MyExtensions {
(_: Seq[Expression]) => Literal(5, IntegerType))
}
+case class CloseableColumnBatchIterator(
+ itr: Iterator[ColumnarBatch],
+ f: ColumnarBatch => ColumnarBatch) extends Iterator[ColumnarBatch] {
+ var cb: ColumnarBatch = null
+
+ private def closeCurrentBatch(): Unit = {
+ if (cb != null) {
+ cb.close
+ cb = null
+ }
+ }
+
+ TaskContext.get().addTaskCompletionListener[Unit]((tc: TaskContext) => {
+ closeCurrentBatch()
+ })
+
+ override def hasNext: Boolean = {
+ closeCurrentBatch()
+ itr.hasNext
+ }
+
+ override def next(): ColumnarBatch = {
+ closeCurrentBatch()
+ cb = f(itr.next())
+ cb
+ }
+}
+
+object NoCloseColumnVector extends Logging {
+ def wrapIfNeeded(cv: ColumnVector): NoCloseColumnVector = cv match {
+ case ref: NoCloseColumnVector =>
+ ref
+ case vec => NoCloseColumnVector(vec)
+ }
+}
+
+/**
+ * A ColumnVector wrapper whose close() is a no-op. It lets code consuming the result of a
+ * ColumnarExpression close whatever it receives without having to know whether the vector is
+ * a temporary value or is owned by the input batch.
+ */
+case class NoCloseColumnVector(wrapped: ColumnVector) extends ColumnVector(wrapped.dataType) {
+ private var refCount = 1
+
+ /**
+ * Don't actually close the ColumnVector this wraps. The producer of the vector will take
+ * care of that.
+ */
+ override def close(): Unit = {
+ // Empty
+ }
+
+ override def hasNull: Boolean = wrapped.hasNull
+
+ override def numNulls(): Int = wrapped.numNulls
+
+ override def isNullAt(rowId: Int): Boolean = wrapped.isNullAt(rowId)
+
+ override def getBoolean(rowId: Int): Boolean = wrapped.getBoolean(rowId)
+
+ override def getByte(rowId: Int): Byte = wrapped.getByte(rowId)
+
+ override def getShort(rowId: Int): Short = wrapped.getShort(rowId)
+
+ override def getInt(rowId: Int): Int = wrapped.getInt(rowId)
+
+ override def getLong(rowId: Int): Long = wrapped.getLong(rowId)
+
+ override def getFloat(rowId: Int): Float = wrapped.getFloat(rowId)
+
+ override def getDouble(rowId: Int): Double = wrapped.getDouble(rowId)
+
+ override def getArray(rowId: Int): ColumnarArray = wrapped.getArray(rowId)
+
+ override def getMap(ordinal: Int): ColumnarMap = wrapped.getMap(ordinal)
+
+ override def getDecimal(rowId: Int, precision: Int, scale: Int): Decimal =
+ wrapped.getDecimal(rowId, precision, scale)
+
+ override def getUTF8String(rowId: Int): UTF8String = wrapped.getUTF8String(rowId)
+
+ override def getBinary(rowId: Int): Array[Byte] = wrapped.getBinary(rowId)
+
+ override protected def getChild(ordinal: Int): ColumnVector = wrapped.getChild(ordinal)
+}
+
+trait ColumnarExpression extends Expression with Serializable {
+ /**
+ * Returns true if this expression supports columnar processing through [[columnarEval]].
+ */
+ def supportsColumnar: Boolean = true
+
+ /**
+ * Returns the result of evaluating this expression on the entire
+ * [[org.apache.spark.sql.vectorized.ColumnarBatch]]. The result of
+ * calling this may be a single [[org.apache.spark.sql.vectorized.ColumnVector]] or a scalar
+ * value. Scalar values typically happen if they are part of the expression, e.g. col("a") + 100.
+ * In this case the 100 is a [[org.apache.spark.sql.catalyst.expressions.Literal]] that
+ * [[org.apache.spark.sql.catalyst.expressions.Add]] would have to be able to handle.
+ *
+ * By convention any [[org.apache.spark.sql.vectorized.ColumnVector]] returned by [[columnarEval]]
+ * is owned by the caller and will need to be closed by them. This can happen by putting it into
+ * a [[org.apache.spark.sql.vectorized.ColumnarBatch]] and closing the batch or by closing the
+ * vector directly if it is a temporary value.
+ */
+ def columnarEval(batch: ColumnarBatch): Any = {
+ throw new IllegalStateException(s"Internal Error ${this.getClass} has column support mismatch")
+ }
+
+ // We need to override equals because we are subclassing a case class
+ override def equals(other: Any): Boolean = {
+ if (!super.equals(other)) {
+ return false
+ }
+ return other.isInstanceOf[ColumnarExpression]
+ }
+
+ override def hashCode(): Int = super.hashCode()
+}
+
+object ColumnarBindReferences extends Logging {
+
+ // Mostly copied from BoundAttribute.scala so we can do columnar processing
+ def bindReference[A <: ColumnarExpression](
+ expression: A,
+ input: AttributeSeq,
+ allowFailures: Boolean = false): A = {
+ expression.transform { case a: AttributeReference =>
+ val ordinal = input.indexOf(a.exprId)
+ if (ordinal == -1) {
+ if (allowFailures) {
+ a
+ } else {
+ sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
+ }
+ } else {
+ new ColumnarBoundReference(ordinal, a.dataType, input(ordinal).nullable)
+ }
+ }.asInstanceOf[A]
+ }
+
+ /**
+ * A helper function to bind given expressions to an input schema.
+ */
+ def bindReferences[A <: ColumnarExpression](
+ expressions: Seq[A],
+ input: AttributeSeq): Seq[A] = {
+ expressions.map(ColumnarBindReferences.bindReference(_, input))
+ }
+}
+
+class ColumnarBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
+ extends BoundReference(ordinal, dataType, nullable) with ColumnarExpression {
+
+ override def columnarEval(batch: ColumnarBatch): Any = {
+ // Because of the convention that the returned ColumnVector must be closed by the
+ // caller we wrap this column vector so a close is a NOOP, and let the original source
+ // of the vector close it.
+ NoCloseColumnVector.wrapIfNeeded(batch.column(ordinal))
+ }
+}
+
+class ColumnarAlias(child: ColumnarExpression, name: String)(
+ override val exprId: ExprId = NamedExpression.newExprId,
+ override val qualifier: Seq[String] = Seq.empty,
+ override val explicitMetadata: Option[Metadata] = None)
+ extends Alias(child, name)(exprId, qualifier, explicitMetadata)
+ with ColumnarExpression {
+
+ override def columnarEval(batch: ColumnarBatch): Any = child.columnarEval(batch)
+}
+
+class ColumnarAttributeReference(
+ name: String,
+ dataType: DataType,
+ nullable: Boolean = true,
+ override val metadata: Metadata = Metadata.empty)(
+ override val exprId: ExprId = NamedExpression.newExprId,
+ override val qualifier: Seq[String] = Seq.empty[String])
+ extends AttributeReference(name, dataType, nullable, metadata)(exprId, qualifier)
+ with ColumnarExpression {
+
+ // No columnar eval is needed because this must be bound before it is evaluated
+}
+
+class ColumnarLiteral (value: Any, dataType: DataType) extends Literal(value, dataType)
+ with ColumnarExpression {
+ override def columnarEval(batch: ColumnarBatch): Any = value
+}
+
+/**
+ * A version of ProjectExec that adds in columnar support.
+ */
+class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
+ extends ProjectExec(projectList, child) {
+
+ override def supportsColumnar: Boolean =
+ projectList.forall(_.asInstanceOf[ColumnarExpression].supportsColumnar)
+
+ // Disable code generation
+ override def supportCodegen: Boolean = false
+
+ override def doExecuteColumnar() : RDD[ColumnarBatch] = {
+ val boundProjectList: Seq[Any] =
+ ColumnarBindReferences.bindReferences(
+ projectList.asInstanceOf[Seq[ColumnarExpression]], child.output)
+ val rdd = child.executeColumnar()
+ rdd.mapPartitions((itr) => CloseableColumnBatchIterator(itr,
+ (cb) => {
+ val newColumns = boundProjectList.map(
+ expr => expr.asInstanceOf[ColumnarExpression].columnarEval(cb).asInstanceOf[ColumnVector]
+ ).toArray
+ new ColumnarBatch(newColumns, cb.numRows())
+ })
+ )
+ }
+
+ // We have to override equals because subclassing a case class like ProjectExec is not that clean
+ // One of the issues is that the generated equals will see ColumnarProjectExec and ProjectExec
+ // as being equal and this can result in the withNewChildren method not actually replacing
+ // anything
+ override def equals(other: Any): Boolean = {
+ if (!super.equals(other)) {
+ return false
+ }
+ return other.isInstanceOf[ColumnarProjectExec]
+ }
+
+ override def hashCode(): Int = super.hashCode()
+}
+
+/**
+ * A version of Add that supports columnar processing for longs. It is intentionally broken:
+ * it adds the operands plus 1, so the tests can verify that the original Add was replaced.
+ */
+class BrokenColumnarAdd(left: ColumnarExpression, right: ColumnarExpression)
+ extends Add(left, right) with ColumnarExpression {
+
+ override def supportsColumnar(): Boolean = left.supportsColumnar && right.supportsColumnar
+
+ override def columnarEval(batch: ColumnarBatch): Any = {
+ var lhs: Any = null
+ var rhs: Any = null
+ var ret: Any = null
+ try {
+ lhs = left.columnarEval(batch)
+ rhs = right.columnarEval(batch)
+
+ if (lhs == null || rhs == null) {
+ ret = null
+ } else if (lhs.isInstanceOf[ColumnVector] && rhs.isInstanceOf[ColumnVector]) {
+ val l = lhs.asInstanceOf[ColumnVector]
+ val r = rhs.asInstanceOf[ColumnVector]
+ val result = new OnHeapColumnVector(batch.numRows(), dataType)
+ ret = result
+
+ for (i <- 0 until batch.numRows()) {
+ result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add
+ }
+ } else if (rhs.isInstanceOf[ColumnVector]) {
+ val l = lhs.asInstanceOf[Long]
+ val r = rhs.asInstanceOf[ColumnVector]
+ val result = new OnHeapColumnVector(batch.numRows(), dataType)
+ ret = result
+
+ for (i <- 0 until batch.numRows()) {
+ result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add
+ }
+ } else if (lhs.isInstanceOf[ColumnVector]) {
+ val l = lhs.asInstanceOf[ColumnVector]
+ val r = rhs.asInstanceOf[Long]
+ val result = new OnHeapColumnVector(batch.numRows(), dataType)
+ ret = result
+
+ for (i <- 0 until batch.numRows()) {
+ result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add
+ }
+ } else {
+ ret = nullSafeEval(lhs, rhs)
+ }
+ } finally {
+ if (lhs != null && lhs.isInstanceOf[ColumnVector]) {
+ lhs.asInstanceOf[ColumnVector].close()
+ }
+ if (rhs != null && rhs.isInstanceOf[ColumnVector]) {
+ rhs.asInstanceOf[ColumnVector].close()
+ }
+ }
+ ret
+ }
+}
+
+class CannotReplaceException(str: String) extends RuntimeException(str)
+
+case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
+ def replaceWithColumnarExpression(exp: Expression): ColumnarExpression = exp match {
+ case a: Alias =>
+ new ColumnarAlias(replaceWithColumnarExpression(a.child),
+ a.name)(a.exprId, a.qualifier, a.explicitMetadata)
+ case att: AttributeReference =>
+ new ColumnarAttributeReference(att.name, att.dataType, att.nullable,
+ att.metadata)(att.exprId, att.qualifier)
+ case lit: Literal =>
+ new ColumnarLiteral(lit.value, lit.dataType)
+ case add: Add if (add.dataType == LongType) &&
+ (add.left.dataType == LongType) &&
+ (add.right.dataType == LongType) =>
+ // Add only supports Longs for now.
+ new BrokenColumnarAdd(replaceWithColumnarExpression(add.left),
+ replaceWithColumnarExpression(add.right))
+ case exp =>
+ throw new CannotReplaceException(s"expression " +
+ s"${exp.getClass} ${exp} is not currently supported.")
+ }
+
+ def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan =
+ try {
+ plan match {
+ case plan: ProjectExec =>
+ new ColumnarProjectExec(plan.projectList.map((exp) =>
+ replaceWithColumnarExpression(exp).asInstanceOf[NamedExpression]),
+ replaceWithColumnarPlan(plan.child))
+ case p =>
+ logWarning(s"Columnar processing for ${p.getClass} is not currently supported.")
+ p.withNewChildren(p.children.map(replaceWithColumnarPlan))
+ }
+ } catch {
+ case exp: CannotReplaceException =>
+ logWarning(s"Columnar processing for ${plan.getClass} is not currently supported" +
+ s"because ${exp.getMessage}")
+ plan
+ }
+
+ override def apply(plan: SparkPlan): SparkPlan = replaceWithColumnarPlan(plan)
+}
+
+class ReplacedRowToColumnarExec(override val child: SparkPlan)
+ extends RowToColumnarExec(child) {
+
+ // We have to override equals because subclassing a case class like RowToColumnarExec is not
+ // that clean. One of the issues is that the generated equals will see ReplacedRowToColumnarExec
+ // and RowToColumnarExec as being equal, and this can result in the withNewChildren method not
+ // actually replacing anything
+ override def equals(other: Any): Boolean = {
+ if (!super.equals(other)) {
+ return false
+ }
+ return other.isInstanceOf[ReplacedRowToColumnarExec]
+ }
+
+ override def hashCode(): Int = super.hashCode()
+}
+
+case class MyPostRule() extends Rule[SparkPlan] {
+ override def apply(plan: SparkPlan): SparkPlan = plan match {
+ case rc: RowToColumnarExec => new ReplacedRowToColumnarExec(rc.child)
+ case plan => plan.withNewChildren(plan.children.map(apply))
+ }
+}
+
+case class MyColumarRule(pre: Rule[SparkPlan], post: Rule[SparkPlan]) extends ColumnarRule {
+ override def preColumnarTransitions: Rule[SparkPlan] = pre
+ override def postColumnarTransitions: Rule[SparkPlan] = post
+}
+
class MyExtensions extends (SparkSessionExtensions => Unit) {
def apply(e: SparkSessionExtensions): Unit = {
e.injectPlannerStrategy(MySparkStrategy)
@@ -260,6 +660,7 @@ class MyExtensions extends (SparkSessionExtensions => Unit) {
e.injectOptimizerRule(MyRule)
e.injectParser(MyParser)
e.injectFunction(MyExtensions.myFunction)
+ e.injectColumnar(session => MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index 289cc667a1c66..8a18a1ab5406f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -50,7 +50,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
case f @ FilterExec(
And(_: AttributeReference, _: AttributeReference),
- InputAdapter(_: BatchEvalPythonExec)) => f
+ InputAdapter(_: BatchEvalPythonExec, _)) => f
case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
}
assert(qualifiedPlanNodes.size == 2)
@@ -60,7 +60,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
val df = Seq(("Hello", 4)).toDF("a", "b")
.where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
- case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f
+ case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec, _)) => f
case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
}
assert(qualifiedPlanNodes.size == 2)
@@ -72,7 +72,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
case f @ FilterExec(
And(_: AttributeReference, _: GreaterThan),
- InputAdapter(_: BatchEvalPythonExec)) => f
+ InputAdapter(_: BatchEvalPythonExec, _)) => f
case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
}
assert(qualifiedPlanNodes.size == 2)
@@ -85,7 +85,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
case f @ FilterExec(
And(_: AttributeReference, _: GreaterThan),
- InputAdapter(_: BatchEvalPythonExec)) => f
+ InputAdapter(_: BatchEvalPythonExec, _)) => f
case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
}
assert(qualifiedPlanNodes.size == 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index bee20227ce67d..758780c80b284 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -31,11 +31,14 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.execution.RowToColumnConverter
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class ColumnarBatchSuite extends SparkFunSuite {
@@ -1270,6 +1273,211 @@ class ColumnarBatchSuite extends SparkFunSuite {
allocator.close()
}
+ test("RowToColumnConverter") {
+ val schema = StructType(
+ StructField("str", StringType) ::
+ StructField("bool", BooleanType) ::
+ StructField("byte", ByteType) ::
+ StructField("short", ShortType) ::
+ StructField("int", IntegerType) ::
+ StructField("long", LongType) ::
+ StructField("float", FloatType) ::
+ StructField("double", DoubleType) ::
+ StructField("decimal", DecimalType(25, 5)) ::
+ StructField("date", DateType) ::
+ StructField("ts", TimestampType) ::
+ StructField("cal", CalendarIntervalType) ::
+ StructField("arr_of_int", ArrayType(IntegerType)) ::
+ StructField("int_and_int", StructType(
+ StructField("int1", IntegerType, false) ::
+ StructField("int2", IntegerType) ::
+ Nil
+ )) ::
+ StructField("int_to_int", MapType(IntegerType, IntegerType)) ::
+ Nil)
+ var mapBuilder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
+ mapBuilder.put(1, 10)
+ mapBuilder.put(20, null)
+ val row1 = new GenericInternalRow(Array[Any](
+ UTF8String.fromString("a string"),
+ true,
+ 1.toByte,
+ 2.toShort,
+ 3,
+ Long.MaxValue,
+ 0.25.toFloat,
+ 0.75D,
+ Decimal("1234.23456"),
+ DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("2015-01-01")),
+ DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123")),
+ new CalendarInterval(1, 0),
+ new GenericArrayData(Array(1, 2, 3, 4, null)),
+ new GenericInternalRow(Array[Any](5.asInstanceOf[Any], 10)),
+ mapBuilder.build()
+ ))
+
+ mapBuilder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
+ mapBuilder.put(30, null)
+ mapBuilder.put(40, 50)
+ val row2 = new GenericInternalRow(Array[Any](
+ UTF8String.fromString("second string"),
+ false,
+ -1.toByte,
+ 17.toShort,
+ Int.MinValue,
+ 987654321L,
+ Float.NaN,
+ Double.PositiveInfinity,
+ Decimal("0.01000"),
+ DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("1875-12-12")),
+ DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("1880-01-05 12:45:21.321")),
+ new CalendarInterval(-10, -100),
+ new GenericArrayData(Array(5, 10, -100)),
+ new GenericInternalRow(Array[Any](20.asInstanceOf[Any], null)),
+ mapBuilder.build()
+ ))
+
+ val row3 = new GenericInternalRow(Array[Any](
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null
+ ))
+
+ val converter = new RowToColumnConverter(schema)
+ val columns = OnHeapColumnVector.allocateColumns(3, schema)
+ val batch = new ColumnarBatch(columns.toArray, 3)
+ try {
+ converter.convert(row1, columns.toArray)
+ converter.convert(row2, columns.toArray)
+ converter.convert(row3, columns.toArray)
+
+ assert(columns(0).dataType() == StringType)
+ assert(columns(0).getUTF8String(0).toString == "a string")
+ assert(columns(0).getUTF8String(1).toString == "second string")
+ assert(columns(0).isNullAt(2))
+
+ assert(columns(1).dataType() == BooleanType)
+ assert(columns(1).getBoolean(0) == true)
+ assert(columns(1).getBoolean(1) == false)
+ assert(columns(1).isNullAt(2))
+
+ assert(columns(2).dataType() == ByteType)
+ assert(columns(2).getByte(0) == 1.toByte)
+ assert(columns(2).getByte(1) == -1.toByte)
+ assert(columns(2).isNullAt(2))
+
+ assert(columns(3).dataType() == ShortType)
+ assert(columns(3).getShort(0) == 2.toShort)
+ assert(columns(3).getShort(1) == 17.toShort)
+ assert(columns(3).isNullAt(2))
+
+ assert(columns(4).dataType() == IntegerType)
+ assert(columns(4).getInt(0) == 3)
+ assert(columns(4).getInt(1) == Int.MinValue)
+ assert(columns(4).isNullAt(2))
+
+ assert(columns(5).dataType() == LongType)
+ assert(columns(5).getLong(0) == Long.MaxValue)
+ assert(columns(5).getLong(1) == 987654321L)
+ assert(columns(5).isNullAt(2))
+
+ assert(columns(6).dataType() == FloatType)
+ assert(columns(6).getFloat(0) == 0.25.toFloat)
+ assert(columns(6).getFloat(1).isNaN)
+ assert(columns(6).isNullAt(2))
+
+ assert(columns(7).dataType() == DoubleType)
+ assert(columns(7).getDouble(0) == 0.75D)
+ assert(columns(7).getDouble(1) == Double.PositiveInfinity)
+ assert(columns(7).isNullAt(2))
+
+ assert(columns(8).dataType() == DecimalType(25, 5))
+ assert(columns(8).getDecimal(0, 25, 5) == Decimal("1234.23456"))
+ assert(columns(8).getDecimal(1, 25, 5) == Decimal("0.01000"))
+ assert(columns(8).isNullAt(2))
+
+ assert(columns(9).dataType() == DateType)
+ assert(columns(9).getInt(0) ==
+ DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("2015-01-01")))
+ assert(columns(9).getInt(1) ==
+ DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("1875-12-12")))
+ assert(columns(9).isNullAt(2))
+
+ assert(columns(10).dataType() == TimestampType)
+ assert(columns(10).getLong(0) ==
+ DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123")))
+ assert(columns(10).getLong(1) ==
+ DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("1880-01-05 12:45:21.321")))
+ assert(columns(10).isNullAt(2))
+
+ assert(columns(11).dataType() == CalendarIntervalType)
+ assert(columns(11).getInterval(0) == new CalendarInterval(1, 0))
+ assert(columns(11).getInterval(1) == new CalendarInterval(-10, -100))
+ assert(columns(11).isNullAt(2))
+
+ assert(columns(12).dataType() == ArrayType(IntegerType))
+ val arr1 = columns(12).getArray(0)
+ assert(arr1.numElements() == 5)
+ assert(arr1.getInt(0) == 1)
+ assert(arr1.getInt(1) == 2)
+ assert(arr1.getInt(2) == 3)
+ assert(arr1.getInt(3) == 4)
+ assert(arr1.isNullAt(4))
+
+ val arr2 = columns(12).getArray(1)
+ assert(arr2.numElements() == 3)
+ assert(arr2.getInt(0) == 5)
+ assert(arr2.getInt(1) == 10)
+ assert(arr2.getInt(2) == -100)
+
+ assert(columns(12).isNullAt(2))
+
+ assert(columns(13).dataType() == StructType(
+ StructField("int1", IntegerType, false) ::
+ StructField("int2", IntegerType) ::
+ Nil
+ ))
+ val struct1 = columns(13).getStruct(0)
+ assert(struct1.getInt(0) == 5)
+ assert(struct1.getInt(1) == 10)
+ val struct2 = columns(13).getStruct(1)
+ assert(struct2.getInt(0) == 20)
+ assert(struct2.isNullAt(1))
+ assert(columns(13).isNullAt(2))
+
+ assert(columns(14).dataType() == MapType(IntegerType, IntegerType))
+ val map1 = columns(14).getMap(0)
+ assert(map1.numElements() == 2)
+ assert(map1.keyArray().getInt(0) == 1)
+ assert(map1.valueArray().getInt(0) == 10)
+ assert(map1.keyArray().getInt(1) == 20)
+ assert(map1.valueArray().isNullAt(1))
+
+ val map2 = columns(14).getMap(1)
+ assert(map2.numElements() == 2)
+ assert(map2.keyArray().getInt(0) == 30)
+ assert(map2.valueArray().isNullAt(0))
+ assert(map2.keyArray().getInt(1) == 40)
+ assert(map2.valueArray().getInt(1) == 50)
+
+ assert(columns(14).isNullAt(2))
+ } finally {
+ batch.close()
+ }
+ }
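+
+  // Possible extension of the test above (sketch only, not asserted here): because the batch is
+  // built with numRows = 3, the converted data could also be read back row-wise through the
+  // ColumnarBatch API, e.g. inside the try block:
+  //
+  //   import scala.collection.JavaConverters._
+  //   assert(batch.numRows() == 3)
+  //   val rows = batch.rowIterator().asScala.toSeq
+  //   assert(rows(0).getUTF8String(0).toString == "a string")
+  //   assert(rows(2).isNullAt(0))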
+
testVector("Decimal API", 4, DecimalType.IntDecimal) {
column =>