From bd0221a6b745be938ade7596658e788dbddbab91 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 26 Jun 2017 04:26:00 +0000
Subject: [PATCH 1/3] Add lambda variables into the parameters of functions
 generated by splitExpressions.

---
 .../expressions/codegen/CodeGenerator.scala   | 22 +++++++++++++++-
 .../expressions/complexTypeCreator.scala      | 26 ++++++++++++++++---
 .../spark/sql/DatasetPrimitiveSuite.scala     |  8 ++++++
 3 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5158949b95629..3aa7887c63fa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -736,7 +736,27 @@ class CodegenContext {
       // Cannot split these expressions because they are not created from a row object.
       return expressions.mkString("\n")
     }
-    splitExpressions(expressions, "apply", ("InternalRow", row) :: Nil)
+    splitExpressions(row, expressions, Seq.empty)
+  }
+
+  /**
+   * Splits the generated code of expressions into multiple functions, because a function
+   * has a 64KB code size limit in the JVM.
+   *
+   * @param row the variable name of the row that is used by the expressions
+   * @param expressions the code to evaluate the expressions.
+   * @param arguments the additional arguments to the functions.
+   */
+  def splitExpressions(
+      row: String,
+      expressions: Seq[String],
+      arguments: Seq[(String, String)]): String = {
+    if (row == null || currentVars != null) {
+      // Cannot split these expressions because they are not created from a row object.
+      return expressions.mkString("\n")
+    }
+    val params = arguments ++ Seq(("InternalRow", row))
+    splitExpressions(expressions, "apply", params)
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 98c4cbee38dee..dabddded7b7d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
@@ -342,19 +343,38 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
     val values = ctx.freshName("values")
     ctx.addMutableState("Object[]", values, s"$values = null;")
 
+    // `splitExpressions` may split the generated code into multiple functions. The local
+    // variables of `LambdaVariable`s can't be accessed inside those functions, so we need to
+    // add them to the parameters of the functions.
+    val (valExprCodes, valExprParams) = valExprs.map { expr =>
+      val exprCode = expr.genCode(ctx)
+      val lambdaVars = expr.collect {
+        case l: LambdaVariable => l
+      }.flatMap { lambda =>
+        val valueParam = ctx.javaType(lambda.dataType) -> lambda.value
+        if (lambda.isNull == "false") {
+          Seq(valueParam)
+        } else {
+          Seq(valueParam, "boolean" -> lambda.isNull)
+        }
+      }
+      (exprCode, lambdaVars)
+    }.unzip
+
+    val splitFuncsParams = valExprParams.flatten.distinct
+
     ev.copy(code = s"""
       $values = new Object[${valExprs.size}];""" +
       ctx.splitExpressions(
         ctx.INPUT_ROW,
-        valExprs.zipWithIndex.map { case (e, i) =>
-          val eval = e.genCode(ctx)
+        valExprCodes.zipWithIndex.map { case (eval, i) =>
           eval.code + s"""
          if (${eval.isNull}) {
            $values[$i] = null;
          } else {
            $values[$i] = ${eval.value};
          }"""
-        }) +
+        }, splitFuncsParams) +
       s"""
       final InternalRow ${ev.value} = new $rowClass($values);
       $values = null;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 4126660b5d102..cc4391e786713 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -39,6 +39,9 @@ case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)
 
 case class ComplexMapClass(map: MapClass, lhmap: LHMapClass)
 
+case class InnerData(name: String, value: Int)
+case class NestedData(id: Int, param: Map[String, InnerData])
+
 package object packageobject {
   case class PackageClass(value: Int)
 }
@@ -354,4 +357,9 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
   }
 
+  test("SPARK-19104: lambda variables should work when parent expression splits generated code") {
+    val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100))))
+    val ds = spark.createDataset(data)
+    checkDataset(ds, data: _*)
+  }
 }

From f4cb190d57fc5763b85159623ea062641744b0b5 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 27 Jun 2017 00:11:01 +0000
Subject: [PATCH 2/3] Address comment.

---
 .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 3aa7887c63fa2..201005f2e80f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -732,10 +732,6 @@ class CodegenContext {
    * @param expressions the codes to evaluate expressions.
    */
   def splitExpressions(row: String, expressions: Seq[String]): String = {
-    if (row == null || currentVars != null) {
-      // Cannot split these expressions because they are not created from a row object.
-      return expressions.mkString("\n")
-    }
     splitExpressions(row, expressions, Seq.empty)
   }
 

From cffb09bf7e7b2efc0b56bbebbaecbb21676127b2 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 27 Jun 2017 09:16:26 +0000
Subject: [PATCH 3/3] Simpler approach.
---
 .../expressions/codegen/CodeGenerator.scala   | 18 +------------
 .../expressions/complexTypeCreator.scala      | 26 +++----------------
 .../expressions/objects/objects.scala         | 18 ++++++-----
 .../spark/sql/DatasetPrimitiveSuite.scala     |  2 +-
 4 files changed, 17 insertions(+), 47 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 201005f2e80f2..5158949b95629 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -732,27 +732,11 @@ class CodegenContext {
    * @param expressions the codes to evaluate expressions.
    */
   def splitExpressions(row: String, expressions: Seq[String]): String = {
-    splitExpressions(row, expressions, Seq.empty)
-  }
-
-  /**
-   * Splits the generated code of expressions into multiple functions, because a function
-   * has a 64KB code size limit in the JVM.
-   *
-   * @param row the variable name of the row that is used by the expressions
-   * @param expressions the code to evaluate the expressions.
-   * @param arguments the additional arguments to the functions.
-   */
-  def splitExpressions(
-      row: String,
-      expressions: Seq[String],
-      arguments: Seq[(String, String)]): String = {
     if (row == null || currentVars != null) {
       // Cannot split these expressions because they are not created from a row object.
       return expressions.mkString("\n")
     }
-    val params = arguments ++ Seq(("InternalRow", row))
-    splitExpressions(expressions, "apply", params)
+    splitExpressions(expressions, "apply", ("InternalRow", row) :: Nil)
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index dabddded7b7d5..98c4cbee38dee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
@@ -343,38 +342,19 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
     val values = ctx.freshName("values")
     ctx.addMutableState("Object[]", values, s"$values = null;")
 
-    // `splitExpressions` may split the generated code into multiple functions. The local
-    // variables of `LambdaVariable`s can't be accessed inside those functions, so we need to
-    // add them to the parameters of the functions.
-    val (valExprCodes, valExprParams) = valExprs.map { expr =>
-      val exprCode = expr.genCode(ctx)
-      val lambdaVars = expr.collect {
-        case l: LambdaVariable => l
-      }.flatMap { lambda =>
-        val valueParam = ctx.javaType(lambda.dataType) -> lambda.value
-        if (lambda.isNull == "false") {
-          Seq(valueParam)
-        } else {
-          Seq(valueParam, "boolean" -> lambda.isNull)
-        }
-      }
-      (exprCode, lambdaVars)
-    }.unzip
-
-    val splitFuncsParams = valExprParams.flatten.distinct
-
     ev.copy(code = s"""
       $values = new Object[${valExprs.size}];""" +
       ctx.splitExpressions(
         ctx.INPUT_ROW,
-        valExprCodes.zipWithIndex.map { case (eval, i) =>
+        valExprs.zipWithIndex.map { case (e, i) =>
+          val eval = e.genCode(ctx)
           eval.code + s"""
          if (${eval.isNull}) {
            $values[$i] = null;
          } else {
            $values[$i] = ${eval.value};
          }"""
-        }, splitFuncsParams) +
+        }) +
       s"""
       final InternalRow ${ev.value} = new $rowClass($values);
       $values = null;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 073993cccdf8a..4b651836ff4d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -911,6 +911,12 @@ case class ExternalMapToCatalyst private(
     val entry = ctx.freshName("entry")
     val entries = ctx.freshName("entries")
 
+    val keyElementJavaType = ctx.javaType(keyType)
+    val valueElementJavaType = ctx.javaType(valueType)
+    ctx.addMutableState(keyElementJavaType, key, "")
+    ctx.addMutableState("boolean", valueIsNull, "")
+    ctx.addMutableState(valueElementJavaType, value, "")
+
     val (defineEntries, defineKeyValue) = child.dataType match {
       case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
         val javaIteratorCls = classOf[java.util.Iterator[_]].getName
@@ -922,8 +928,8 @@ case class ExternalMapToCatalyst private(
         val defineKeyValue =
           s"""
             final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
-            ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey();
-            ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue();
+            $key = (${ctx.boxedType(keyType)}) $entry.getKey();
+            $value = (${ctx.boxedType(valueType)}) $entry.getValue();
           """
 
         defineEntries -> defineKeyValue
@@ -937,17 +943,17 @@ case class ExternalMapToCatalyst private(
         val defineKeyValue =
           s"""
             final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
-            ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1();
-            ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2();
+            $key = (${ctx.boxedType(keyType)}) $entry._1();
+            $value = (${ctx.boxedType(valueType)}) $entry._2();
           """
 
         defineEntries -> defineKeyValue
     }
 
     val valueNullCheck = if (ctx.isPrimitiveType(valueType)) {
-      s"boolean $valueIsNull = false;"
+      s"$valueIsNull = false;"
     } else {
-      s"boolean $valueIsNull = $value == null;"
+      s"$valueIsNull = $value == null;"
     }
 
     val arrayCls = classOf[GenericArrayData].getName
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index cc4391e786713..a6847dcfbffc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -357,7 +357,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
   }
 
-  test("SPARK-19104: lambda variables should work when parent expression splits generated code") {
+  test("SPARK-19104: Lambda variables in ExternalMapToCatalyst should be global") {
     val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100))))
     val ds = spark.createDataset(data)
     checkDataset(ds, data: _*)
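
Editor's note (not part of the patch series): the sketch below is a minimal, self-contained Scala illustration of the two approaches tried above; every class and method name in it is invented for illustration and is not Spark code. Spark's codegen splits a long apply body into helper methods to stay under the JVM's 64KB-per-method bytecode limit, and a local variable declared in the caller (such as an ExternalMapToCatalyst loop variable) is not visible inside those helpers, so the generated code fails to compile. Patch 1 threads each such variable through as an extra function parameter; patch 3 instead hoists it to a field of the generated class (what ctx.addMutableState does), which every method can see.

object SplitSketch {
  class GeneratedLike {
    // Patch 3's approach: a class-level field (cf. ctx.addMutableState) stays
    // visible to every split-out helper method, so no extra plumbing is needed.
    private var value: Int = _

    def applyAll(row: Array[Int]): Int = {
      value = row(0)            // assign the field instead of declaring a local
      apply1(row) + apply2(row) // split-out helpers read `value` directly
    }
    private def apply1(row: Array[Int]): Int = value + 1
    private def apply2(row: Array[Int]): Int = value * 2

    // Patch 1's approach, for contrast: pass the would-be local into each
    // split function as an explicit parameter.
    def applyAllWithParams(row: Array[Int]): Int = {
      val local = row(0)
      apply1p(row, local) + apply2p(row, local)
    }
    private def apply1p(row: Array[Int], value: Int): Int = value + 1
    private def apply2p(row: Array[Int], value: Int): Int = value * 2
  }

  def main(args: Array[String]): Unit = {
    val g = new GeneratedLike
    println(g.applyAll(Array(21)))           // prints 64
    println(g.applyAllWithParams(Array(21))) // prints 64
  }
}

The field-based version matches the final commit ("Simpler approach."): it avoids collecting, deduplicating, and forwarding a parameter list per expression, at the cost of a little per-instance state, which is the same trade-off ctx.addMutableState makes for the map loop variables in objects.scala.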