diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
index 9467435435d1f..24260b05194a7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
@@ -41,7 +41,7 @@ public class AggregateHashMap {
 
   private OnHeapColumnVector[] columnVectors;
-  private ColumnarBatch batch;
+  private MutableColumnarRow aggBufferRow;
   private int[] buckets;
   private int numBuckets;
   private int numRows = 0;
@@ -63,7 +63,7 @@ public AggregateHashMap(StructType schema, int capacity, double loadFactor, int
     this.maxSteps = maxSteps;
     numBuckets = (int) (capacity / loadFactor);
     columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema);
-    batch = new ColumnarBatch(schema, columnVectors, capacity);
+    aggBufferRow = new MutableColumnarRow(columnVectors);
     buckets = new int[numBuckets];
     Arrays.fill(buckets, -1);
   }
@@ -72,14 +72,15 @@ public AggregateHashMap(StructType schema) {
     this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS);
   }
 
-  public ColumnarRow findOrInsert(long key) {
+  public MutableColumnarRow findOrInsert(long key) {
     int idx = find(key);
     if (idx != -1 && buckets[idx] == -1) {
       columnVectors[0].putLong(numRows, key);
      columnVectors[1].putLong(numRows, 0);
       buckets[idx] = numRows++;
     }
-    return batch.getRow(buckets[idx]);
+    aggBufferRow.rowId = buckets[idx];
+    return aggBufferRow;
   }
 
   @VisibleForTesting
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
index 0071bd66760be..1f1347ccd315e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
@@ -323,7 +323,6 @@ public ArrowColumnVector(ValueVector vector) {
       for (int i = 0; i < childColumns.length; ++i) {
         childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i));
       }
-      resultStruct = new ColumnarRow(childColumns);
     } else {
       throw new UnsupportedOperationException();
     }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index cca14911fbb28..e6b87519239dd 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -157,18 +157,16 @@ public abstract class ColumnVector implements AutoCloseable {
   /**
    * Returns a utility object to get structs.
    */
-  public ColumnarRow getStruct(int rowId) {
-    resultStruct.rowId = rowId;
-    return resultStruct;
+  public final ColumnarRow getStruct(int rowId) {
+    return new ColumnarRow(this, rowId);
   }
 
   /**
    * Returns a utility object to get structs.
    * provided to keep API compatibility with InternalRow for code generation
    */
-  public ColumnarRow getStruct(int rowId, int size) {
-    resultStruct.rowId = rowId;
-    return resultStruct;
+  public final ColumnarRow getStruct(int rowId, int size) {
+    return getStruct(rowId);
   }
 
   /**
@@ -216,11 +214,6 @@ public MapData getMap(int ordinal) {
    */
   protected DataType type;
 
-  /**
-   * Reusable Struct holder for getStruct().
-   */
-  protected ColumnarRow resultStruct;
-
   /**
    * The Dictionary for this column.
    *
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 2f5fb360b226f..a9d09aa679726 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -18,6 +18,7 @@
 
 import java.util.*;
 
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.types.StructType;
 
 /**
@@ -40,10 +41,10 @@ public final class ColumnarBatch {
   private final StructType schema;
   private final int capacity;
   private int numRows;
-  final ColumnVector[] columns;
+  private final ColumnVector[] columns;
 
-  // Staging row returned from getRow.
-  final ColumnarRow row;
+  // Staging row returned from `getRow`.
+  private final MutableColumnarRow row;
 
   /**
    * Called to close all the columns in this batch. It is not valid to access the data after
@@ -58,10 +59,10 @@ public void close() {
   /**
    * Returns an iterator over the rows in this batch. This skips rows that are filtered out.
    */
-  public Iterator<ColumnarRow> rowIterator() {
+  public Iterator<InternalRow> rowIterator() {
     final int maxRows = numRows;
-    final ColumnarRow row = new ColumnarRow(columns);
-    return new Iterator<ColumnarRow>() {
+    final MutableColumnarRow row = new MutableColumnarRow(columns);
+    return new Iterator<InternalRow>() {
       int rowId = 0;
 
       @Override
@@ -70,7 +71,7 @@ public boolean hasNext() {
       }
 
       @Override
-      public ColumnarRow next() {
+      public InternalRow next() {
         if (rowId >= maxRows) {
           throw new NoSuchElementException();
         }
@@ -133,9 +134,8 @@ public void setNumRows(int numRows) {
   /**
    * Returns the row in this batch at `rowId`. Returned row is reused across calls.
    */
-  public ColumnarRow getRow(int rowId) {
-    assert(rowId >= 0);
-    assert(rowId < numRows);
+  public InternalRow getRow(int rowId) {
+    assert(rowId >= 0 && rowId < numRows);
     row.rowId = rowId;
     return row;
   }
@@ -144,6 +144,6 @@ public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) {
     this.schema = schema;
     this.columns = columns;
     this.capacity = capacity;
-    this.row = new ColumnarRow(columns);
+    this.row = new MutableColumnarRow(columns);
   }
 }
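A note on the `ColumnarBatch` change above: `getRow` and `rowIterator` now hand out a single staging `MutableColumnarRow` whose `rowId` is re-pointed on each call, typed to callers as `InternalRow`. Anyone retaining rows across calls must copy them first. A minimal caller-side sketch of that contract, assuming `batch` is an already-populated `ColumnarBatch` and the usual `java.util` and `InternalRow` imports:

    // rowIterator() reuses one MutableColumnarRow; retained rows must be copied.
    List<InternalRow> retained = new ArrayList<>();
    Iterator<InternalRow> it = batch.rowIterator();
    while (it.hasNext()) {
      InternalRow row = it.next();  // the same underlying object on every call
      retained.add(row.copy());     // materialize values before the row is re-pointed
    }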
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
index cabb7479525d9..95c0d09873d67 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
@@ -28,30 +28,32 @@
  * to be reused, callers should copy the data out if it needs to be stored.
  */
 public final class ColumnarRow extends InternalRow {
-  protected int rowId;
-  private final ColumnVector[] columns;
-
-  // Ctor used if this is a struct.
-  ColumnarRow(ColumnVector[] columns) {
-    this.columns = columns;
+  // The data for this row. E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`.
+  private final ColumnVector data;
+  private final int rowId;
+  private final int numFields;
+
+  ColumnarRow(ColumnVector data, int rowId) {
+    assert (data.dataType() instanceof StructType);
+    this.data = data;
+    this.rowId = rowId;
+    this.numFields = ((StructType) data.dataType()).size();
   }
 
-  public ColumnVector[] columns() { return columns; }
-
   @Override
-  public int numFields() { return columns.length; }
+  public int numFields() { return numFields; }
 
   /**
    * Revisit this. This is expensive. This is currently only used in test paths.
    */
   @Override
   public InternalRow copy() {
-    GenericInternalRow row = new GenericInternalRow(columns.length);
+    GenericInternalRow row = new GenericInternalRow(numFields);
     for (int i = 0; i < numFields(); i++) {
       if (isNullAt(i)) {
         row.setNullAt(i);
       } else {
-        DataType dt = columns[i].dataType();
+        DataType dt = data.getChildColumn(i).dataType();
         if (dt instanceof BooleanType) {
           row.setBoolean(i, getBoolean(i));
         } else if (dt instanceof ByteType) {
@@ -91,65 +93,65 @@ public boolean anyNull() {
   }
 
   @Override
-  public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
+  public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); }
 
   @Override
-  public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
+  public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); }
 
   @Override
-  public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
+  public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); }
 
   @Override
-  public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
+  public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); }
 
   @Override
-  public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
+  public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); }
 
   @Override
-  public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
+  public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); }
 
   @Override
-  public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
+  public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); }
 
   @Override
-  public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
+  public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); }
 
   @Override
   public Decimal getDecimal(int ordinal, int precision, int scale) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    return columns[ordinal].getDecimal(rowId, precision, scale);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale);
   }
 
   @Override
   public UTF8String getUTF8String(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    return columns[ordinal].getUTF8String(rowId);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    return data.getChildColumn(ordinal).getUTF8String(rowId);
   }
 
   @Override
   public byte[] getBinary(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    return columns[ordinal].getBinary(rowId);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    return data.getChildColumn(ordinal).getBinary(rowId);
   }
 
   @Override
   public CalendarInterval getInterval(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
-    final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId);
+    final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId);
     return new CalendarInterval(months, microseconds);
   }
 
   @Override
   public ColumnarRow getStruct(int ordinal, int numFields) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    return columns[ordinal].getStruct(rowId);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    return data.getChildColumn(ordinal).getStruct(rowId);
  }
 
   @Override
   public ColumnarArray getArray(int ordinal) {
-    if (columns[ordinal].isNullAt(rowId)) return null;
-    return columns[ordinal].getArray(rowId);
+    if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+    return data.getChildColumn(ordinal).getArray(rowId);
   }
 
   @Override
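Because `rowId` is now final, two `getStruct` results from the same vector no longer alias one reusable holder. A small sketch of the behavioral difference this buys, assuming `vector` is a struct-typed `ColumnVector` whose field 0 is an int:

    ColumnarRow first = vector.getStruct(0);   // independent object, pinned to row 0
    ColumnarRow second = vector.getStruct(1);  // no longer re-points `first`
    int a = first.getInt(0);   // row 0's value, still valid after fetching `second`
    int b = second.getInt(0);  // row 1's value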
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index f272cc163611b..06602c147dfe9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -28,17 +28,24 @@
 
 /**
  * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash
- * aggregate.
+ * aggregate, and {@link ColumnarBatch} to save object creation.
  *
  * Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to
  * avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes.
  */
 public final class MutableColumnarRow extends InternalRow {
   public int rowId;
-  private final WritableColumnVector[] columns;
+  private final ColumnVector[] columns;
+  private final WritableColumnVector[] writableColumns;
 
-  public MutableColumnarRow(WritableColumnVector[] columns) {
+  public MutableColumnarRow(ColumnVector[] columns) {
     this.columns = columns;
+    this.writableColumns = null;
+  }
+
+  public MutableColumnarRow(WritableColumnVector[] writableColumns) {
+    this.columns = writableColumns;
+    this.writableColumns = writableColumns;
   }
 
   @Override
@@ -225,54 +232,54 @@ public void update(int ordinal, Object value) {
 
   @Override
   public void setNullAt(int ordinal) {
-    columns[ordinal].putNull(rowId);
+    writableColumns[ordinal].putNull(rowId);
   }
 
   @Override
   public void setBoolean(int ordinal, boolean value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putBoolean(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putBoolean(rowId, value);
   }
 
   @Override
   public void setByte(int ordinal, byte value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putByte(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putByte(rowId, value);
   }
 
   @Override
   public void setShort(int ordinal, short value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putShort(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putShort(rowId, value);
   }
 
   @Override
   public void setInt(int ordinal, int value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putInt(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putInt(rowId, value);
   }
 
   @Override
   public void setLong(int ordinal, long value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putLong(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putLong(rowId, value);
   }
 
   @Override
   public void setFloat(int ordinal, float value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putFloat(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putFloat(rowId, value);
   }
 
   @Override
   public void setDouble(int ordinal, double value) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putDouble(rowId, value);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putDouble(rowId, value);
   }
 
   @Override
   public void setDecimal(int ordinal, Decimal value, int precision) {
-    columns[ordinal].putNotNull(rowId);
-    columns[ordinal].putDecimal(rowId, value, precision);
+    writableColumns[ordinal].putNotNull(rowId);
+    writableColumns[ordinal].putDecimal(rowId, value, precision);
   }
 }
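The two constructors split `MutableColumnarRow` into a writable mode and a read-only cursor: setters route through `writableColumns`, which is null when the row was built from plain `ColumnVector`s. A sketch of both modes, assuming `vectors` is an `OnHeapColumnVector[]` with at least two long columns (Java overload resolution binds it to the more specific `WritableColumnVector[]` constructor):

    MutableColumnarRow buffer = new MutableColumnarRow(vectors);
    buffer.rowId = 0;
    buffer.setLong(1, buffer.getLong(1) + 42L);  // reads and writes both work

    MutableColumnarRow cursor = new MutableColumnarRow(new ColumnVector[] {vectors[0], vectors[1]});
    cursor.rowId = 0;
    long v = cursor.getLong(1);  // reads work; any setter here would NPE since writableColumns is null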
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 806d0291a6c49..5f1b9885334b7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -547,7 +547,7 @@ protected void reserveInternal(int newCapacity) {
     } else if (type instanceof LongType || type instanceof DoubleType ||
         DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) {
       this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8);
-    } else if (resultStruct != null) {
+    } else if (childColumns != null) {
      // Nothing to store.
     } else {
       throw new RuntimeException("Unhandled " + type);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 6e7f74ce12f16..f12772ede575d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -558,7 +558,7 @@ protected void reserveInternal(int newCapacity) {
         if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity);
         doubleData = newData;
       }
-    } else if (resultStruct != null) {
+    } else if (childColumns != null) {
       // Nothing to store.
     } else {
       throw new RuntimeException("Unhandled " + type);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 0bea4cc97142d..7c053b579442c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -74,7 +74,6 @@ public void close() {
       dictionaryIds = null;
     }
     dictionary = null;
-    resultStruct = null;
   }
 
   public void reserve(int requiredCapacity) {
@@ -673,23 +672,19 @@ protected WritableColumnVector(int capacity, DataType type) {
       }
       this.childColumns = new WritableColumnVector[1];
       this.childColumns[0] = reserveNewColumn(childCapacity, childType);
-      this.resultStruct = null;
     } else if (type instanceof StructType) {
       StructType st = (StructType)type;
       this.childColumns = new WritableColumnVector[st.fields().length];
       for (int i = 0; i < childColumns.length; ++i) {
         this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType());
       }
-      this.resultStruct = new ColumnarRow(this.childColumns);
     } else if (type instanceof CalendarIntervalType) {
       // Two columns. Months as int. Microseconds as Long.
       this.childColumns = new WritableColumnVector[2];
       this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType);
       this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType);
-      this.resultStruct = new ColumnarRow(this.childColumns);
     } else {
       this.childColumns = null;
-      this.resultStruct = null;
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 26d8cd7278353..9cadd13999e72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -595,9 +595,7 @@ case class HashAggregateExec(
       ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
         s"$fastHashMapTerm = new $fastHashMapClassName();")
-      ctx.addMutableState(
-        s"java.util.Iterator<${classOf[ColumnarRow].getName}>",
-        iterTermForFastHashMap)
+      ctx.addMutableState(s"java.util.Iterator<InternalRow>", iterTermForFastHashMap)
     } else {
       val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
         fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
@@ -674,7 +672,7 @@ case class HashAggregateExec(
       """.stripMargin
     }
 
-    // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow
+    // Iterate over the aggregate rows and convert them from InternalRow to UnsafeRow
    def outputFromVectorizedMap: String = {
       val row = ctx.freshName("fastHashMapRow")
       ctx.currentVars = null
@@ -687,10 +685,9 @@ case class HashAggregateExec(
         bufferSchema.toAttributes.zipWithIndex.map { case (attr, i) =>
           BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
         })
-      val columnarRowCls = classOf[ColumnarRow].getName
       s"""
          |while ($iterTermForFastHashMap.hasNext()) {
-         |  $columnarRowCls $row = ($columnarRowCls) $iterTermForFastHashMap.next();
+         |  InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
          |  ${generateKeyRow.code}
          |  ${generateBufferRow.code}
          |  $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 44ba539ebf7c2..f04cd48072f17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
-import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector}
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector}
 import org.apache.spark.sql.types._
 
 /**
@@ -231,7 +232,7 @@ class VectorizedHashMapGenerator(
 
   protected def generateRowIterator(): String = {
     s"""
-       |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() {
+       |public java.util.Iterator<${classOf[InternalRow].getName}> rowIterator() {
        |  batch.setNumRows(numRows);
        |  return batch.rowIterator();
        |}
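For reference, after template substitution the loop produced by `outputFromVectorizedMap` compiles purely against `InternalRow`; roughly (identifiers like `fastHashMapIter` and `outputFunc` stand in for the fresh names the codegen context would actually produce):

    while (fastHashMapIter.hasNext()) {
      InternalRow fastHashMapRow = (InternalRow) fastHashMapIter.next();
      // ... key and buffer projections are evaluated against fastHashMapRow ...
      outputFunc(keyRow, bufferRow);
    }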
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 0ae4f2d117609..c9c6bee513b53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -751,11 +751,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
     c2.putDouble(1, 5.67)
 
     val s = column.getStruct(0)
-    assert(s.columns()(0).getInt(0) == 123)
-    assert(s.columns()(0).getInt(1) == 456)
-    assert(s.columns()(1).getDouble(0) == 3.45)
-    assert(s.columns()(1).getDouble(1) == 5.67)
-
     assert(s.getInt(0) == 123)
     assert(s.getDouble(1) == 3.45)
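Taken together, the fast hash map path now flows through `InternalRow` end to end: `findOrInsert` parks the shared writable row on the key's bucket, and the aggregation buffer is updated in place through it. An end-to-end sketch, assuming a `schema` describing a (key: long, sum: long) layout and `key`/`value` inputs:

    AggregateHashMap map = new AggregateHashMap(schema);
    MutableColumnarRow row = map.findOrInsert(key);  // writes (key, 0) on first sight of the key
    row.setLong(1, row.getLong(1) + value);          // in-place aggregate buffer update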