@@ -98,6 +98,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val CACHE_CODEGEN = buildConf("spark.sql.inMemoryColumnarStorage.codegen")
.internal()
.doc("When true, use generated code to build column batches for caching. This is only " +
"supported for basic types and improves caching performance for such types.")
.booleanConf
.createWithDefault(true)

val PREFER_SORTMERGEJOIN = buildConf("spark.sql.join.preferSortMergeJoin")
.internal()
.doc("When true, prefer sort merge join over shuffle hash join.")
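For context, a minimal sketch of toggling the new flag from user code, assuming a running SparkSession named `spark`; only the conf key comes from the diff above, the rest is illustrative:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("cache-codegen").getOrCreate()
// Internal flag defined above; defaults to true.
spark.conf.set("spark.sql.inMemoryColumnarStorage.codegen", "true")
// int and double columns are among the basic types the doc string refers to.
val df = spark.range(0, 1000).selectExpr("cast(id as int) i", "cast(id as double) d")
df.cache().count() // the caching path may now build column batches with generated code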
@@ -31,6 +31,8 @@ import org.apache.spark.sql.types.DataType
*/
private[sql] trait ColumnarBatchScan extends CodegenSupport {

/** Maps each output ordinal to a physical column index in the batch; null means the identity mapping. */
val columnIndexes: Array[Int] = null

val inMemoryTableScan: InMemoryTableScanExec = null

override lazy val metrics = Map(
@@ -89,7 +91,8 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
ctx.addMutableState(columnVectorClz, name, s"$name = null;")
s"$name = $batch.column($i);"
val index = if (columnIndexes == null) i else columnIndexes(i)
s"$name = $batch.column($index);"
}

val nextBatch = ctx.freshName("nextBatch")
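The `columnIndexes` hook above lets generated code read output column `i` from a different physical slot in the batch. A self-contained sketch of just the mapping logic, with a made-up `Array(2, 0)` mapping (the real values would come from the scan):

// Mirrors `val index = if (columnIndexes == null) i else columnIndexes(i)` above.
val columnIndexes: Array[Int] = Array(2, 0)
val assigns = Seq("colInstance0", "colInstance1").zipWithIndex.map { case (name, i) =>
  val index = if (columnIndexes == null) i else columnIndexes(i)
  s"$name = batch.column($index);"
}
// assigns == Seq("colInstance0 = batch.column(2);", "colInstance1 = batch.column(0);")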
@@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator, UnsafeRowWriter}
import org.apache.spark.sql.execution.vectorized.ColumnarBatch
import org.apache.spark.sql.types._

/**
@@ -62,12 +63,16 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalRow {
/**
* Generates bytecode for a [[ColumnarIterator]] for columnar cache.
*/
object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging {
class GenerateColumnAccessor(useColumnarBatch: Boolean)
extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging {

protected def canonicalize(in: Seq[DataType]): Seq[DataType] = in
protected def bind(in: Seq[DataType], inputSchema: Seq[Attribute]): Seq[DataType] = in

protected def create(columnTypes: Seq[DataType]): ColumnarIterator = {
if (useColumnarBatch) {
return createItrForCacheColumnarBatch(columnTypes)
}
val ctx = newCodeGenContext()
val numFields = columnTypes.size
val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) =>
@@ -152,6 +157,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging {
(0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n"))
}

val cachedBatchBytesCls = classOf[CachedBatchBytes].getName
val codeBody = s"""
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -205,9 +211,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging {
return false;
}

${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next();
$cachedBatchBytesCls batch = ($cachedBatchBytesCls) input.next();
currentRow = 0;
numRowsInBatch = batch.numRows();
numRowsInBatch = batch.getNumRows();
for (int i = 0; i < columnIndexes.length; i ++) {
buffers[i] = batch.buffers()[columnIndexes[i]];
}
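The generated code above reads `getNumRows()` and `buffers()` from a `CachedBatchBytes`, and (below) `columnarBatch()` from a `CachedColumnarBatch`. A hedged sketch of shapes consistent with those call sites; the actual definitions are not part of this excerpt and may differ:

import org.apache.spark.sql.execution.vectorized.ColumnarBatch

// Assumed split of the old CachedBatch into two carriers; only the members
// the generated code touches are shown here.
case class CachedBatchBytes(numRows: Int, buffers: Array[Array[Byte]]) {
  def getNumRows: Int = numRows
}
case class CachedColumnarBatch(numRows: Int, columnarBatch: ColumnarBatch) {
  def getNumRows: Int = numRows
}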
@@ -232,4 +238,110 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging {

CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator]
}

protected def createItrForCacheColumnarBatch(columnTypes: Seq[DataType])
: ColumnarIterator = {
val ctx = newCodeGenContext()
val numFields = columnTypes.size

val setters = ctx.splitExpressions(
columnTypes.zipWithIndex.map { case (dt, index) =>
val setter = dt match {
case IntegerType | DateType => s"setInt($index, colInstances[$index].getInt(rowIdx))"
case DoubleType => s"setDouble($index, colInstances[$index].getDouble(rowIdx))"
case _ => throw new UnsupportedOperationException(s"Unsupported type $dt")
}

s"""
if (colInstances[$index].isNullAt(rowIdx)) {
mutableRow.setNullAt($index);
} else {
mutableRow.$setter;
}
"""
},
"apply",
Seq.empty
)

val codeBody = s"""
import scala.collection.Iterator;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
import org.apache.spark.sql.execution.columnar.MutableUnsafeRow;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapUnsafeColumnVector;

public SpecificColumnarIterator generate(Object[] references) {
return new SpecificColumnarIterator(references);
}

class SpecificColumnarIterator extends ${classOf[ColumnarIterator].getName} {
private ColumnVector[] colInstances;
private UnsafeRow unsafeRow = new UnsafeRow($numFields);
private BufferHolder bufferHolder = new BufferHolder(unsafeRow);
private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields);
private MutableUnsafeRow mutableRow = null;

private int rowIdx = 0;
private int numRowsInBatch = 0;

private scala.collection.Iterator input = null;
private DataType[] columnTypes = null;
private int[] columnIndexes = null;

${ctx.declareMutableStates()}

public SpecificColumnarIterator(Object[] references) {
${ctx.initMutableStates()}
this.mutableRow = new MutableUnsafeRow(rowWriter);
}

public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) {
this.input = input;
this.columnTypes = columnTypes;
this.columnIndexes = columnIndexes;
}

${ctx.declareAddedFunctions()}

public boolean hasNext() {
if (rowIdx < numRowsInBatch) {
return true;
}
if (!input.hasNext()) {
return false;
}

${classOf[CachedColumnarBatch].getName} cachedBatch =
(${classOf[CachedColumnarBatch].getName}) input.next();
${classOf[ColumnarBatch].getName} batch = cachedBatch.columnarBatch();
rowIdx = 0;
numRowsInBatch = cachedBatch.getNumRows();
colInstances = new ColumnVector[columnIndexes.length];
for (int i = 0; i < columnIndexes.length; i ++) {
colInstances[i] = batch.column(columnIndexes[i]);
// Materialize compressed on-heap buffers before row-wise access.
((OnHeapUnsafeColumnVector)colInstances[i]).decompress();
}

// Recurse once so an empty batch falls through to the next one.
return hasNext();
}

public InternalRow next() {
bufferHolder.reset();
rowWriter.zeroOutNullBytes();
${setters}
unsafeRow.setTotalSize(bufferHolder.totalSize());
rowIdx += 1;
return unsafeRow;
}
}"""

val code = CodeFormatter.stripOverlappingComments(
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
logDebug(s"Generated ColumnarIteratorForCachedColumnarBatch:\n${CodeFormatter.format(code)}")

CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator]
}
}
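A hedged sketch of a call site for the reworked generator; `sqlConf`, `columnTypes`, and the supported-type check are assumptions layered on this diff, not code from it:

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

// Only the types wired up in createItrForCacheColumnarBatch qualify for the
// ColumnarBatch path; anything else falls back to the byte-buffer iterator.
def batchSupported(dt: DataType): Boolean = dt match {
  case IntegerType | DateType | DoubleType => true
  case _ => false
}
val useColumnarBatch =
  sqlConf.getConf(SQLConf.CACHE_CODEGEN) && columnTypes.forall(batchSupported)
val iter = new GenerateColumnAccessor(useColumnarBatch).generate(columnTypes)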