@@ -736,7 +736,27 @@ class CodegenContext {
// Cannot split these expressions because they are not created from a row object.
return expressions.mkString("\n")
}
-splitExpressions(expressions, "apply", ("InternalRow", row) :: Nil)
+splitExpressions(row, expressions, Seq.empty)
Member:
Do we need the check in lines 735-738? Now, we do the same check at lines 754-757, too.

Member Author:
Removed, thanks.

}

/**
* Splits the generated code of expressions into multiple functions, because a single
* function has a 64KB code size limit in the JVM.
*
* @param row the variable name of row that is used by expressions
* @param expressions the codes to evaluate expressions.
* @param arguments the additional arguments to the functions.
*/
def splitExpressions(
row: String,
expressions: Seq[String],
arguments: Seq[(String, String)]): String = {
if (row == null || currentVars != null) {
// Cannot split these expressions because they are not created from a row object.
return expressions.mkString("\n")
}
val params = arguments ++ Seq(("InternalRow", row))
splitExpressions(expressions, "apply", params)
}

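As a hedged illustration of the new three-argument overload (not part of this patch): a hypothetical caller could forward a lambda variable as an extra (javaType, name) pair so that the functions produced by the split can still reference it. Only the signature shown in this hunk and ctx.INPUT_ROW are assumed; the helper name, the snippet variable, and the lambda variable name are made up for this sketch.

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

// Hypothetical helper, for illustration only: `snippets` holds generated code
// fragments that all reference a lambda variable named `value_0` of Java type int.
def splitWithLambdaArg(ctx: CodegenContext, snippets: Seq[String]): String = {
  // Forward the lambda variable as an extra (java type, variable name) parameter,
  // so each function produced by the split can still see it.
  val extraArgs: Seq[(String, String)] = Seq(("int", "value_0"))
  ctx.splitExpressions(ctx.INPUT_ROW, snippets, extraArgs)
}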
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
@@ -342,19 +343,38 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
val values = ctx.freshName("values")
ctx.addMutableState("Object[]", values, s"$values = null;")

// `splitExpressions` might split the code into multiple functions. The local variables of
// `LambdaVariable` can't be accessed inside those functions, so we need to add them to the
// parameters of the functions.
val (valExprCodes, valExprParams) = valExprs.map { expr =>
val exprCode = expr.genCode(ctx)
val lambdaVars = expr.collect {
Contributor:
This looks hacky, and why do you only do it for CreateNamedStruct?

Member Author:
For example, CreateArray is another expression using splitExpressions. However, if the external type is an array, MapObjects will be used to construct the internal array. CreateMap is the same, I think.

Contributor:
IIRC there are a lot of places where we use splitExpressions; shall we check all of them?

Member Author:
I'm not very sure about the other places. I moved the collection of lambda variables into splitExpressions and let the possible call sites do this check.

Member Author:
After rethinking this and doing more experiments, this approach doesn't work well.

A simpler approach is making those lambda variables global, as `MapObjects` and `CollectObjectsToMap` do.

case l: LambdaVariable => l
}.flatMap { lambda =>
val valueParam = ctx.javaType(lambda.dataType) -> lambda.value
if (lambda.isNull == "false") {
Seq(valueParam)
} else {
Seq(valueParam, "boolean" -> lambda.isNull)
}
}
(exprCode, lambdaVars)
}.unzip

val splitFuncsParams = valExprParams.flatten.distinct

ev.copy(code = s"""
$values = new Object[${valExprs.size}];""" +
ctx.splitExpressions(
ctx.INPUT_ROW,
-valExprs.zipWithIndex.map { case (e, i) =>
-val eval = e.genCode(ctx)
+valExprCodes.zipWithIndex.map { case (eval, i) =>
eval.code + s"""
if (${eval.isNull}) {
$values[$i] = null;
} else {
$values[$i] = ${eval.value};
}"""
-}) +
+}, splitFuncsParams) +
s"""
final InternalRow ${ev.value} = new $rowClass($values);
$values = null;
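The simpler alternative mentioned at the end of the review thread above (making the lambda variables global, as MapObjects and CollectObjectsToMap do) could look roughly like the sketch below. This is only a guess at the shape of that change, built from APIs already visible in this diff (ctx.addMutableState, ctx.javaType, and LambdaVariable's value/isNull fields); it is not the actual follow-up implementation.

// Rough sketch, assuming the three-argument addMutableState(javaType, name, initCode)
// used earlier in this file: register each lambda variable as global mutable state so
// that functions produced by splitExpressions can reach it without extra parameters.
valExprs.foreach { expr =>
  expr.collect { case l: LambdaVariable => l }.foreach { lambda =>
    ctx.addMutableState(ctx.javaType(lambda.dataType), lambda.value, "")
    if (lambda.isNull != "false") {
      ctx.addMutableState("boolean", lambda.isNull, "")
    }
  }
}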
@@ -39,6 +39,9 @@ case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)

case class ComplexMapClass(map: MapClass, lhmap: LHMapClass)

case class InnerData(name: String, value: Int)
case class NestedData(id: Int, param: Map[String, InnerData])

package object packageobject {
case class PackageClass(value: Int)
}
@@ -354,4 +357,9 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
}

test("SPARK-19104: lambda variables should work when parent expression splits generated codes") {
Contributor:
You should also update the test name...

Member Author:
Missed it, updated.

val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100))))
val ds = spark.createDataset(data)
checkDataset(ds, data: _*)
}
}