Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.language.existentials

import com.google.common.cache.{CacheBuilder, CacheLoader}
Expand Down Expand Up @@ -265,6 +266,43 @@ class CodeGenContext {
def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)

def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))

/**
* Splits the generated code of expressions into multiple functions, because function has
* 64kb code size limit in JVM
*/
def splitExpressions(input: String, expressions: Seq[String]): String = {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you document what "input" is here?

val blocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
for (code <- expressions) {
// We can't know how many byte code will be generated, so use the number of bytes as limit
if (blockBuilder.length > 64 * 1000) {
blocks.append(blockBuilder.toString())
blockBuilder.clear()
}
blockBuilder.append(code)
}
blocks.append(blockBuilder.toString())

if (blocks.length == 1) {
// inline execution if only one block
blocks.head
} else {
val apply = freshName("apply")
val functions = blocks.zipWithIndex.map { case (body, i) =>
val name = s"${apply}_$i"
val code = s"""
|private void $name(InternalRow $input) {
| $body
|}
""".stripMargin
addNewFunction(name, code)
name
}

functions.map(name => s"$name($input);").mkString("\n")
}
}
}

/**
Expand All @@ -289,15 +327,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def declareMutableStates(ctx: CodeGenContext): String = {
ctx.mutableStates.map { case (javaType, variableName, _) =>
s"private $javaType $variableName;"
}.mkString
}.mkString("\n")
}

protected def initMutableStates(ctx: CodeGenContext): String = {
ctx.mutableStates.map(_._3).mkString
ctx.mutableStates.map(_._3).mkString("\n")
}

protected def declareAddedFunctions(ctx: CodeGenContext): String = {
ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString
ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu

protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
val projectionCode = expressions.zipWithIndex.map {
val projectionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
Expand All @@ -65,35 +65,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
"""
}
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
for (projection <- projectionCode) {
if (blockBuilder.length > 16 * 1000) {
projectionBlocks.append(blockBuilder.toString())
blockBuilder.clear()
}
blockBuilder.append(projection)
}
projectionBlocks.append(blockBuilder.toString())

val (projectionFuns, projectionCalls) = {
// inline execution if codesize limit was not broken
if (projectionBlocks.length == 1) {
("", projectionBlocks.head)
} else {
(
projectionBlocks.zipWithIndex.map { case (body, i) =>
s"""
|private void apply$i(InternalRow i) {
| $body
|}
""".stripMargin
}.mkString,
projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
)
}
}
val allProjections = ctx.splitExpressions("i", projectionCodes)

val code = s"""
public Object generate($exprType[] expr) {
Expand Down Expand Up @@ -123,12 +95,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
return (InternalRow) mutableRow;
}

$projectionFuns

public Object apply(Object _i) {
InternalRow i = (InternalRow) _i;
$projectionCalls

$allProjections
return mutableRow;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
${genUpdater(ctx, rowTerm, dt, i, colTerm)};
}
"""
}.mkString("\n")
}
val allUpdates = ctx.splitExpressions(value, updates)
s"""
$genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length});
$updates
$allUpdates
$setter.update($ordinal, $rowTerm.copy());
"""
case _ =>
Expand All @@ -68,7 +69,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]

protected def create(expressions: Seq[Expression]): Projection = {
val ctx = newCodeGenContext()
val projectionCode = expressions.zipWithIndex.map {
val expressionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
Expand All @@ -81,36 +82,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
"""
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
for (projection <- projectionCode) {
if (blockBuilder.length > 16 * 1000) {
projectionBlocks.append(blockBuilder.toString())
blockBuilder.clear()
}
blockBuilder.append(projection)
}
projectionBlocks.append(blockBuilder.toString())

val (projectionFuns, projectionCalls) = {
// inline it if we have only one block
if (projectionBlocks.length == 1) {
("", projectionBlocks.head)
} else {
(
projectionBlocks.zipWithIndex.map { case (body, i) =>
s"""
|private void apply$i(InternalRow i) {
| $body
|}
""".stripMargin
}.mkString,
projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
)
}
}

val allExpressions = ctx.splitExpressions("i", expressionCodes)
val code = s"""
public Object generate($exprType[] expr) {
return new SpecificSafeProjection(expr);
Expand All @@ -121,19 +93,17 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
private $exprType[] expressions;
private $mutableRowType mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}

public SpecificSafeProjection($exprType[] expr) {
expressions = expr;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
}

$projectionFuns

public Object apply(Object _i) {
InternalRow i = (InternalRow) _i;
$projectionCalls

$allExpressions
return mutableRow;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro

def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
s" + $DecimalWriter.getSize(${ev.primitive})"
s"$DecimalWriter.getSize(${ev.primitive})"
case StringType =>
s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
s"${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive})"
case BinaryType =>
s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})"
case CalendarIntervalType =>
s" + (${ev.isNull} ? 0 : 16)"
s"${ev.isNull} ? 0 : 16"
case _: StructType =>
s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})"
case _: ArrayType =>
s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))"
s"${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive})"
case _: MapType =>
s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))"
s"${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive})"
case _ => ""
}

Expand Down Expand Up @@ -125,64 +125,68 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
*/
private def createCodeForStruct(
ctx: CodeGenContext,
row: String,
inputs: Seq[GeneratedExpressionCode],
inputTypes: Seq[DataType]): GeneratedExpressionCode = {

val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)

val output = ctx.freshName("convertedStruct")
ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();")
ctx.addMutableState("UnsafeRow", output, s"this.$output = new UnsafeRow();")
val buffer = ctx.freshName("buffer")
ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
val numBytes = ctx.freshName("numBytes")
ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];")
val cursor = ctx.freshName("cursor")
ctx.addMutableState("int", cursor, s"this.$cursor = 0;")
val tmp = ctx.freshName("tmpBuffer")

val convertedFields = inputTypes.zip(inputs).map { case (dt, input) =>
createConvertCode(ctx, input, dt)
}

val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) =>
genAdditionalSize(dt, ev)
}.mkString("")

val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) =>
val update = genFieldWriter(ctx, dt, ev, output, i, cursor)
if (dt.isInstanceOf[DecimalType]) {
// Can't call setNullAt() for DecimalType
val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) =>
val ev = createConvertCode(ctx, input, dt)
val growBuffer = if (!UnsafeRow.isFixedLength(dt)) {
val numBytes = ctx.freshName("numBytes")
s"""
int $numBytes = $cursor + (${genAdditionalSize(dt, ev)});
if ($buffer.length < $numBytes) {
// This will not happen frequently, because the buffer is re-used.
byte[] $tmp = new byte[$numBytes * 3 / 2];

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should do 8 byte alignment here

System.arraycopy($buffer, 0, $tmp, 0, $buffer.length);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and use PlatformDependent.copyMemory

$buffer = $tmp;
}
$output.pointTo($buffer, $PlatformDependent.BYTE_ARRAY_OFFSET,
${inputTypes.length}, $numBytes);
"""
} else {
""
}
val update = dt match {
case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS =>
// Can't call setNullAt() for DecimalType
s"""
if (${ev.isNull}) {
$cursor += $DecimalWriter.write($output, $i, $cursor, null);
$cursor += $DecimalWriter.write($output, $i, $cursor, null);
} else {
$update;
${genFieldWriter(ctx, dt, ev, output, i, cursor)};
}
"""
} else {
s"""
case _ =>
s"""
if (${ev.isNull}) {
$output.setNullAt($i);
} else {
$update;
${genFieldWriter(ctx, dt, ev, output, i, cursor)};
}
"""
}
}.mkString("\n")
s"""
${ev.code}
$growBuffer
$update
"""
}

val code = s"""
${convertedFields.map(_.code).mkString("\n")}

final int $numBytes = $fixedSize $additionalSize;
if ($numBytes > $buffer.length) {
$buffer = new byte[$numBytes];
}

$output.pointTo(
$buffer,
$PlatformDependent.BYTE_ARRAY_OFFSET,
${inputTypes.length},
$numBytes);

int $cursor = $fixedSize;

$fieldWriters
$cursor = $fixedSize;
$output.pointTo($buffer, $PlatformDependent.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor);
${ctx.splitExpressions(row, convertedFields)}
"""
GeneratedExpressionCode(code, "false", output)
}
Expand Down Expand Up @@ -400,7 +404,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val fieldIsNull = s"$tmp.isNullAt($i)"
GeneratedExpressionCode("", fieldIsNull, getFieldCode)
}
val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes)
val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes)
val code = s"""
${input.code}
UnsafeRow $output = null;
Expand All @@ -427,7 +431,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
val exprEvals = expressions.map(e => e.gen(ctx))
val exprTypes = expressions.map(_.dataType)
createCodeForStruct(ctx, exprEvals, exprTypes)
createCodeForStruct(ctx, "i", exprEvals, exprTypes)
}

protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
|$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
""".stripMargin
}
}.mkString
}.mkString("\n")

// ------------------------ Finally, put everything together --------------------------- //
val code = s"""
Expand Down
Loading