Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.language.existentials

import com.google.common.cache.{CacheBuilder, CacheLoader}
Expand Down Expand Up @@ -265,6 +266,45 @@ class CodeGenContext {
def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)

def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))

/**
 * Splits the generated code of expressions into multiple functions, because a single
 * function is limited to 64kb of code in the JVM.
 *
 * Snippets are concatenated, in order, into a block until the block's source text has grown
 * past the size limit. The limit is checked before each append, so a block can exceed it by
 * up to the length of one snippet. If everything ends up in a single block, that block is
 * returned unchanged so it executes inline. Otherwise every block is wrapped in a
 * `private void` function taking the row, registered through `addNewFunction`, and the
 * returned code is the list of calls to those functions, one call per line.
 *
 * @param row the variable name of row that is used by expressions
 * @param expressions the generated code snippets, in the order they must run
 * @return the code to place at the call site: the single inlined block, or the calls to the
 *         generated functions
 */
def splitExpressions(row: String, expressions: Seq[String]): String = {
  // We can't know how much byte code will be generated, so the number of characters of
  // source text is used as the limit.
  val maxBlockLength = 64 * 1000

  val blocks = new ArrayBuffer[String]()
  val blockBuilder = new StringBuilder()
  for (code <- expressions) {
    if (blockBuilder.length > maxBlockLength) {
      blocks.append(blockBuilder.toString())
      blockBuilder.clear()
    }
    blockBuilder.append(code)
  }
  // Flush whatever is left; for empty input this yields one empty block.
  blocks.append(blockBuilder.toString())

  if (blocks.length == 1) {
    // inline execution if only one block
    blocks.head
  } else {
    val apply = freshName("apply")
    val calls = blocks.zipWithIndex.map { case (body, i) =>
      val name = s"${apply}_$i"
      // Built with plain interpolation instead of `stripMargin`: `stripMargin` runs on the
      // already-interpolated string, so it would also remove a leading `|` from any line of
      // `body` (for example a continuation line that starts with `||`).
      val code = s"\nprivate void $name(InternalRow $row) {\n $body\n}\n"
      addNewFunction(name, code)
      s"$name($row);"
    }

    calls.mkString("\n")
  }
}
}

/**
Expand All @@ -289,15 +329,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
/**
 * Renders a `private` field declaration for every mutable state registered in `ctx`,
 * one declaration per line.
 *
 * @param ctx the codegen context whose mutable states are declared
 * @return the field declarations joined with newlines
 */
protected def declareMutableStates(ctx: CodeGenContext): String = {
  ctx.mutableStates.map { case (javaType, variableName, _) =>
    s"private $javaType $variableName;"
  }.mkString("\n")
}

/**
 * Joins the third element of every mutable state registered in `ctx` (the code snippet
 * passed as the last argument of `addMutableState`), one snippet per line.
 *
 * @param ctx the codegen context whose mutable states are rendered
 * @return the snippets joined with newlines
 */
protected def initMutableStates(ctx: CodeGenContext): String = {
  ctx.mutableStates.map(_._3).mkString("\n")
}

/**
 * Joins the source code of every function registered in `ctx`, separated by newlines.
 *
 * @param ctx the codegen context holding the (name, code) pairs of the added functions
 * @return the function sources joined with newlines
 */
protected def declareAddedFunctions(ctx: CodeGenContext): String = {
  // Only the code is emitted; the name half of each pair is not needed here.
  ctx.addedFuntions.map { case (_, funcCode) => funcCode }.mkString("\n")
}

/**
Expand Down Expand Up @@ -328,6 +368,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(getClass.getClassLoader)
// Cannot be under package codegen, or fail with java.lang.InstantiationException
evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass")
evaluator.setDefaultImports(Array(
classOf[PlatformDependent].getName,
classOf[InternalRow].getName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu

protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
val projectionCode = expressions.zipWithIndex.map {
val projectionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
Expand All @@ -65,49 +65,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
"""
}
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
for (projection <- projectionCode) {
if (blockBuilder.length > 16 * 1000) {
projectionBlocks.append(blockBuilder.toString())
blockBuilder.clear()
}
blockBuilder.append(projection)
}
projectionBlocks.append(blockBuilder.toString())

val (projectionFuns, projectionCalls) = {
// inline execution if codesize limit was not broken
if (projectionBlocks.length == 1) {
("", projectionBlocks.head)
} else {
(
projectionBlocks.zipWithIndex.map { case (body, i) =>
s"""
|private void apply$i(InternalRow i) {
| $body
|}
""".stripMargin
}.mkString,
projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
)
}
}
val allProjections = ctx.splitExpressions("i", projectionCodes)

val code = s"""
public Object generate($exprType[] expr) {
return new SpecificProjection(expr);
return new SpecificMutableProjection(expr);
}

class SpecificProjection extends ${classOf[BaseMutableProjection].getName} {
class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} {

private $exprType[] expressions;
private $mutableRowType mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}

public SpecificProjection($exprType[] expr) {
public SpecificMutableProjection($exprType[] expr) {
expressions = expr;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
Expand All @@ -123,12 +95,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
return (InternalRow) mutableRow;
}

$projectionFuns

public Object apply(Object _i) {
InternalRow i = (InternalRow) _i;
$projectionCalls

$allProjections
return mutableRow;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.types._
Expand All @@ -43,6 +41,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val tmp = ctx.freshName("tmp")
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
// These expressions could be split into multiple functions, so the array is kept as a field
ctx.addMutableState("Object[]", values, s"this.$values = null;")

val rowClass = classOf[GenericInternalRow].getName

val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
Expand All @@ -53,12 +54,12 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
$values[$i] = ${converter.primitive};
}
"""
}.mkString("\n")

}
val allFields = ctx.splitExpressions(tmp, fieldWriters)
val code = s"""
final InternalRow $tmp = $input;
final Object[] $values = new Object[${schema.length}];
$fieldWriters
this.$values = new Object[${schema.length}];
$allFields
final InternalRow $output = new $rowClass($values);
"""

Expand Down Expand Up @@ -128,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]

protected def create(expressions: Seq[Expression]): Projection = {
val ctx = newCodeGenContext()
val projectionCode = expressions.zipWithIndex.map {
val expressionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
Expand All @@ -143,36 +144,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
"""
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
for (projection <- projectionCode) {
if (blockBuilder.length > 16 * 1000) {
projectionBlocks.append(blockBuilder.toString())
blockBuilder.clear()
}
blockBuilder.append(projection)
}
projectionBlocks.append(blockBuilder.toString())

val (projectionFuns, projectionCalls) = {
// inline it if we have only one block
if (projectionBlocks.length == 1) {
("", projectionBlocks.head)
} else {
(
projectionBlocks.zipWithIndex.map { case (body, i) =>
s"""
|private void apply$i(InternalRow i) {
| $body
|}
""".stripMargin
}.mkString,
projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
)
}
}

val allExpressions = ctx.splitExpressions("i", expressionCodes)
val code = s"""
public Object generate($exprType[] expr) {
return new SpecificSafeProjection(expr);
Expand All @@ -183,19 +155,17 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
private $exprType[] expressions;
private $mutableRowType mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}

public SpecificSafeProjection($exprType[] expr) {
expressions = expr;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
}

$projectionFuns

public Object apply(Object _i) {
InternalRow i = (InternalRow) _i;
$projectionCalls

$allExpressions
return mutableRow;
}
}
Expand Down
Loading