Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,17 @@ class EquivalentExpressions {

// There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
// 2. If: common subexpressions will always be evaluated at the beginning, but the true and
// 2. LambdaFunction: it's children operate in the context of local lambdas and can't be split
// 3. If: common subexpressions will always be evaluated at the beginning, but the true and
// false expressions in `If` may not get accessed, according to the predicate
// expression. We should only recurse into the predicate expression.
// 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
// 4. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
// condition. We should only recurse into the first condition expression as it
// will always get accessed.
// 4. Coalesce: it's also a conditional expression, we should only recurse into the first
// 5. Coalesce: it's also a conditional expression, we should only recurse into the first
// children, because others may not get accessed.
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case _: CodegenFallback | _: LambdaFunction => Nil
case i: If => i.predicate :: Nil
case c: CaseWhen => c.children.head :: Nil
case c: Coalesce => c.children.head :: Nil
Expand All @@ -122,7 +123,7 @@ class EquivalentExpressions {
// For some special expressions we cannot just recurse into all of its children, but we can
// recursively add the common expressions shared between all of its children.
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
case _: CodegenFallback => Nil
case _: CodegenFallback | _: LambdaFunction => Nil
case i: If => Seq(Seq(i.trueValue, i.falseValue))
case c: CaseWhen =>
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ abstract class Expression extends TreeNode[Expression] {
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
val splitThreshold = SQLConf.get.methodSplitThreshold
if (eval.code.length > splitThreshold && ctx.INPUT_ROW != null && ctx.currentVars == null) {
if (eval.code.length > splitThreshold && ctx.INPUT_ROW != null && ctx.currentVars == null
&& ctx.currentLambdaVars.isEmpty) {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,45 @@ class CodegenContext extends Logging {
*/
var currentVars: Seq[ExprCode] = null

/**
* Holding a map of current lambda variables.
*/
var currentLambdaVars: mutable.Map[String, ExprCode] = mutable.HashMap.empty

def withLambdaVar(namedLambda: NamedLambdaVariable, f: ExprCode => ExprCode): ExprCode = {
val name = namedLambda.variableName
if (currentLambdaVars.get(name).nonEmpty) {
throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(name)
}
val isNull = if (namedLambda.nullable) {
JavaCode.isNullVariable(s"lambda_${name}_isNull")
} else {
FalseLiteral
}
val lambdaVar = ExprCode(isNull, JavaCode.variable(s"lambda_$name", namedLambda.dataType))

currentLambdaVars.put(name, lambdaVar)
val result = f(lambdaVar)
currentLambdaVars.remove(name)
result
}

def withOptionalLambdaVar(namedLambda: Option[NamedLambdaVariable],
f: Option[ExprCode] => ExprCode): ExprCode = {
namedLambda.map { lambdaVar =>
def wrapperFunc(ev: ExprCode): ExprCode = f(Some(ev))
withLambdaVar(lambdaVar, wrapperFunc)
}.getOrElse {
f(None)
}
}

def getLambdaVar(name: String): ExprCode = {
currentLambdaVars.getOrElse(name, {
throw QueryExecutionErrors.lambdaVariableNotDefinedError(name)
})
}

/**
* Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a
* 2-tuple: java type, variable name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

import scala.collection.mutable

import org.apache.spark.sql.catalyst.CatalystTypeConverters.isPrimitive
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
Expand Down Expand Up @@ -76,8 +78,7 @@ case class NamedLambdaVariable(
exprId: ExprId = NamedExpression.newExprId,
value: AtomicReference[Any] = new AtomicReference())
extends LeafExpression
with NamedExpression
with CodegenFallback {
with NamedExpression {

override def qualifier: Seq[String] = Seq.empty

Expand All @@ -98,6 +99,17 @@ case class NamedLambdaVariable(
override def simpleString(maxFields: Int): String = {
s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}"
}

// We need to include the Expr ID in the Codegen variable name since several tests bypass
// `UnresolvedNamedLambdaVariable.freshVarName`
lazy val variableName = s"${name}_${exprId.id}"

override def genCode(ctx: CodegenContext): ExprCode = {
ctx.getLambdaVar(variableName)
}

// This won't be called as `genCode` is overridden, just overriding it to make non-abstract.
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
}

/**
Expand All @@ -109,7 +121,7 @@ case class LambdaFunction(
function: Expression,
arguments: Seq[NamedExpression],
hidden: Boolean = false)
extends Expression with CodegenFallback {
extends Expression {

override def children: Seq[Expression] = function +: arguments
override def dataType: DataType = function.dataType
Expand All @@ -127,6 +139,23 @@ case class LambdaFunction(

override def eval(input: InternalRow): Any = function.eval(input)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val functionCode = function.genCode(ctx)

if (nullable) {
ev.copy(code = code"""
|${functionCode.code}
|boolean ${ev.isNull} = ${functionCode.isNull};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value};
""".stripMargin)
} else {
ev.copy(code = code"""
|${functionCode.code}
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value};
""".stripMargin, isNull = FalseLiteral)
}
}

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): LambdaFunction =
copy(
Expand Down Expand Up @@ -269,6 +298,29 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expr
}
}

protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: String => String): ExprCode = {
val argumentGen = argument.genCode(ctx)
val resultCode = f(argumentGen.value)

if (nullable) {
val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode)
ev.copy(code = code"""
|${argumentGen.code}
|boolean ${ev.isNull} = ${argumentGen.isNull};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$nullSafeEval
""")
} else {
ev.copy(code = code"""
|${argumentGen.code}
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$resultCode
""", isNull = FalseLiteral)
}
}
}

trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
Expand Down Expand Up @@ -297,7 +349,7 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
case class ArrayTransform(
argument: Expression,
function: Expression)
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
extends ArrayBasedSimpleHigherOrderFunction {

override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)

Expand Down Expand Up @@ -338,6 +390,68 @@ case class ArrayTransform(
result
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of this can probably be abstracted out into the parent traits, but I figured that will be easier to do when implementing a second function

ctx.withLambdaVar(elementVar, elementCode => {
ctx.withOptionalLambdaVar(indexVar, indexCode => {
nullSafeCodeGen(ctx, ev, arg => {
val numElements = ctx.freshName("numElements")
val arrayData = ctx.freshName("arrayData")
val i = ctx.freshName("i")

val argumentType = argument.dataType.asInstanceOf[ArrayType]
val argumentElementJavaType = CodeGenerator.javaType(argumentType.elementType)
val initialization = CodeGenerator.createArrayData(
arrayData, dataType.elementType, numElements, s" $prettyName failed.")
val extractElement = CodeGenerator.getValue(arg, argumentType.elementType, i)

val functionCode = function.genCode(ctx)

val elementAtomic = ctx.addReferenceObj(elementVar.variableName, elementVar.value)
val elementAssignment = if (elementVar.nullable) {
s"""
$argumentElementJavaType ${elementCode.value} = $extractElement;
boolean ${elementCode.isNull} = ${arg}.isNullAt($i);
$elementAtomic.set(${elementCode.value});
"""
} else {
s"""
$argumentElementJavaType ${elementCode.value} = $extractElement;
$elementAtomic.set(${elementCode.value});
"""
}
val indexAssignment = indexCode.map(c => {
val indexAtomic = ctx.addReferenceObj(indexVar.get.variableName, indexVar.get.value)
s"""
int ${c.value} = $i;
$indexAtomic.set(${c.value});
"""
})
val varAssignments = (Seq(elementAssignment) ++: indexAssignment).mkString("\n")

// Some expressions return internal buffers that we have to copy
val copy = if (isPrimitive(function.dataType)) {
s"${functionCode.value}"
} else {
s"InternalRow.copyValue(${functionCode.value})"
}
val resultAssignment = CodeGenerator.setArrayElement(arrayData, dataType.elementType,
i, copy, isNull = Some(functionCode.isNull))

s"""
|final int $numElements = ${arg}.numElements();
|$initialization
|for (int $i = 0; $i < $numElements; $i++) {
| $varAssignments
| ${functionCode.code}
| $resultAssignment
|}
|${ev.value} = $arrayData;
""".stripMargin
})
})
})
}

override def prettyName: String = "transform"

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,15 @@ object QueryExecutionErrors {
new IllegalArgumentException(s"$funcName is not matched at addNewFunction")
}

def lambdaVariableAlreadyDefinedError(name: String): Throwable = {
new IllegalArgumentException(s"Lambda variable $name cannot be redefined")
}

def lambdaVariableNotDefinedError(name: String): Throwable = {
new IllegalArgumentException(
s"Lambda variable $name is not defined in the current codegen scope")
}

def cannotGenerateCodeForUncomparableTypeError(
codeType: String, dataType: DataType): Throwable = {
new IllegalArgumentException(
Expand Down