diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 279972052a48..c97303be1d27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -21,13 +21,13 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection, WalkedTypePath} +import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance} -import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts -import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} +import org.apache.spark.sql.catalyst.optimizer.{ReassignLambdaVariableID, SimplifyCasts} +import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LeafNode, LocalRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -301,13 +301,25 @@ case class ExpressionEncoder[T]( } @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer) + private lazy val extractProjection = GenerateUnsafeProjection.generate({ + // When using `ExpressionEncoder` directly, we will skip the normal query processing steps + // (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's + // important to codegen performance. + val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer)) + optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs + }) @transient private lazy val inputRow = new GenericInternalRow(1) @transient - private lazy val constructProjection = SafeProjection.create(deserializer :: Nil) + private lazy val constructProjection = SafeProjection.create({ + // When using `ExpressionEncoder` directly, we will skip the normal query processing steps + // (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's + // important to codegen performance. + val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer))) + optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs + }) /** * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form @@ -371,3 +383,9 @@ case class ExpressionEncoder[T]( override def toString: String = s"class[$schemaString]" } + +// A dummy logical plan that can hold expressions and go through optimizer rules. 
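For context, a minimal sketch (not part of this patch) of the direct-encoder path the new comments describe, assuming a Seq[Int] encoder; an ExpressionEncoder used this way never passes through the analyzer or optimizer, which is why the ReassignLambdaVariableID rule has to be applied by hand around the serializer and deserializer:

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // The serializer/deserializer of a Seq encoder are built from MapObjects + LambdaVariable.
    val enc = ExpressionEncoder[Seq[Int]]()
    // Resolve and bind against the encoder's own schema; no query plan is involved.
    val bound = enc.resolveAndBind()
    val row = bound.toRow(Seq(1, 2, 3))   // exercises extractProjection (generated code)
    val obj = bound.fromRow(row)          // exercises constructProjection (generated code)
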
+case class DummyExpressionHolder(exprs: Seq[Expression]) extends LeafNode { + override lazy val resolved = true + override def output: Seq[Attribute] = Nil +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a6a48b6be03f..871aba67cf13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -575,15 +575,43 @@ case class WrapOption(child: Expression, optType: DataType) } } +object LambdaVariable { + private val curId = new java.util.concurrent.atomic.AtomicLong() + + // Returns the codegen-ed `LambdaVariable` and add it to mutable states, so that it can be + // accessed anywhere in the generated code. + def prepareLambdaVariable(ctx: CodegenContext, variable: LambdaVariable): ExprCode = { + val variableCode = variable.genCode(ctx) + assert(variableCode.code.isEmpty) + + ctx.addMutableState( + CodeGenerator.javaType(variable.dataType), + variableCode.value, + forceInline = true, + useFreshName = false) + + if (variable.nullable) { + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, + variableCode.isNull, + forceInline = true, + useFreshName = false) + } + + variableCode + } +} + /** - * A placeholder for the loop variable used in [[MapObjects]]. This should never be constructed + * A placeholder for the loop variable used in [[MapObjects]]. This should never be constructed * manually, but will instead be passed into the provided lambda function. */ +// TODO: Merge this and `NamedLambdaVariable`. case class LambdaVariable( - value: String, - isNull: String, + name: String, dataType: DataType, - nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + nullable: Boolean, + id: Long = LambdaVariable.curId.incrementAndGet) extends LeafExpression with NonSQLExpression { private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable) @@ -595,12 +623,16 @@ case class LambdaVariable( } override def genCode(ctx: CodegenContext): ExprCode = { - val isNullValue = if (nullable) { - JavaCode.isNullVariable(isNull) + // If `LambdaVariable` IDs are reassigned by the `ReassignLambdaVariableID` rule, the IDs will + // all be negative. + val suffix = "lambda_variable_" + math.abs(id) + val isNull = if (nullable) { + JavaCode.isNullVariable(s"isNull_${name}_$suffix") } else { FalseLiteral } - ExprCode(value = JavaCode.variable(value, dataType), isNull = isNullValue) + val value = JavaCode.variable(s"value_${name}_$suffix", dataType) + ExprCode(isNull, value) } // This won't be called as `genCode` is overrided, just overriding it to make @@ -630,8 +662,6 @@ case class UnresolvedMapObjects( } object MapObjects { - private val curId = new java.util.concurrent.atomic.AtomicInteger() - /** * Construct an instance of MapObjects case class. 
* @@ -649,16 +679,8 @@ object MapObjects { elementType: DataType, elementNullable: Boolean = true, customCollectionCls: Option[Class[_]] = None): MapObjects = { - val id = curId.getAndIncrement() - val loopValue = s"MapObjects_loopValue$id" - val loopIsNull = if (elementNullable) { - s"MapObjects_loopIsNull$id" - } else { - "false" - } - val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable) - MapObjects( - loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) + val loopVar = LambdaVariable("MapObject", elementType, elementNullable) + MapObjects(loopVar, function(loopVar), inputData, customCollectionCls) } } @@ -674,12 +696,8 @@ object MapObjects { * The following collection ObjectTypes are currently supported on input: * Seq, Array, ArrayData, java.util.List * - * @param loopValue the name of the loop variable that used when iterate the collection, and used - * as input for the `lambdaFunction` - * @param loopIsNull the nullity of the loop variable that used when iterate the collection, and - * used as input for the `lambdaFunction` - * @param loopVarDataType the data type of the loop variable that used when iterate the collection, - * and used as input for the `lambdaFunction` + * @param loopVar the [[LambdaVariable]] expression representing the loop variable that used to + * iterate the collection, and used as input for the `lambdaFunction`. * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. @@ -687,16 +705,14 @@ object MapObjects { * or None (returning ArrayType) */ case class MapObjects private( - loopValue: String, - loopIsNull: String, - loopVarDataType: DataType, + loopVar: LambdaVariable, lambdaFunction: Expression, inputData: Expression, customCollectionCls: Option[Class[_]]) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable - override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil + override def children: Seq[Expression] = Seq(loopVar, lambdaFunction, inputData) // The data with UserDefinedType are actually stored with the data type of its sqlType. // When we want to apply MapObjects on it, we have to use it. 
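As a rough illustration (again, not part of the patch) of the refactored factory shown above, MapObjects.apply now builds the loop variable itself; the Literal input below is only a stand-in for a real array-typed child expression:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
    import org.apache.spark.sql.types.{ArrayType, IntegerType}

    // MapObjects.apply creates the loop variable internally, as
    // LambdaVariable("MapObject", IntegerType, nullable = true) with a fresh globally unique ID;
    // callers no longer thread loopValue/loopIsNull strings through the expression.
    val mapped = MapObjects(
      function = elem => elem,                                          // identity lambda per element
      inputData = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)), // stand-in input
      elementType = IntegerType)

    // In the generated code the variable is named value_MapObject_lambda_variable_<|id|>, so two
    // structurally identical queries only share a codegen cache entry once
    // ReassignLambdaVariableID rewrites the globally unique IDs into per-query ones.
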
@@ -790,8 +806,8 @@ case class MapObjects private( ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementJavaType = CodeGenerator.javaType(loopVarDataType) - ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) + val elementJavaType = CodeGenerator.javaType(loopVar.dataType) + val loopVarCode = LambdaVariable.prepareLambdaVariable(ctx, loopVar) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -879,12 +895,10 @@ case class MapObjects private( case _ => genFunction.value } - val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState( - CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) + val loopNullCheck = if (loopVar.nullable) { inputDataType match { - case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" - case _ => s"$loopIsNull = $loopValue == null;" + case _: ArrayType => s"${loopVarCode.isNull} = ${genInputData.value}.isNullAt($loopIndex);" + case _ => s"${loopVarCode.isNull} = ${loopVarCode.value} == null;" } } else { "" @@ -942,7 +956,7 @@ case class MapObjects private( int $loopIndex = 0; $prepareLoop while ($loopIndex < $dataLength) { - $loopValue = ($elementJavaType) ($getLoopVar); + ${loopVarCode.value} = ($elementJavaType) ($getLoopVar); $loopNullCheck ${genFunction.code} @@ -982,23 +996,15 @@ case class UnresolvedCatalystToExternalMap( } object CatalystToExternalMap { - private val curId = new java.util.concurrent.atomic.AtomicInteger() - def apply(u: UnresolvedCatalystToExternalMap): CatalystToExternalMap = { - val id = curId.getAndIncrement() - val keyLoopValue = s"CatalystToExternalMap_keyLoopValue$id" val mapType = u.child.dataType.asInstanceOf[MapType] - val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false) - val valueLoopValue = s"CatalystToExternalMap_valueLoopValue$id" - val valueLoopIsNull = if (mapType.valueContainsNull) { - s"CatalystToExternalMap_valueLoopIsNull$id" - } else { - "false" - } - val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) + val keyLoopVar = LambdaVariable( + "CatalystToExternalMap_key", mapType.keyType, nullable = false) + val valueLoopVar = LambdaVariable( + "CatalystToExternalMap_value", mapType.valueType, mapType.valueContainsNull) CatalystToExternalMap( - keyLoopValue, u.keyFunction(keyLoopVar), - valueLoopValue, valueLoopIsNull, u.valueFunction(valueLoopVar), + keyLoopVar, u.keyFunction(keyLoopVar), + valueLoopVar, u.valueFunction(valueLoopVar), u.child, u.collClass) } } @@ -1008,33 +1014,31 @@ object CatalystToExternalMap { * The collection is constructed using the associated builder, obtained by calling `newBuilder` * on the collection's companion object. * - * @param keyLoopValue the name of the loop variable that is used when iterating over the key - * collection, and which is used as input for the `keyLambdaFunction` + * @param keyLoopVar the [[LambdaVariable]] expression representing the loop variable that is used + * when iterating over the key collection, and which is used as input for the + * `keyLambdaFunction`. * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as * a lambda function to handle collection elements. 
- * @param valueLoopValue the name of the loop variable that is used when iterating over the value - * collection, and which is used as input for the `valueLambdaFunction` - * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over - * the value collection, and which is used as input for the - * `valueLambdaFunction` + * @param valueLoopVar the [[LambdaVariable]] expression representing the loop variable that is used + * when iterating over the value collection, and which is used as input for the + * `valueLambdaFunction`. * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as * a lambda function to handle collection elements. * @param inputData An expression that when evaluated returns a map object. * @param collClass The type of the resulting collection. */ case class CatalystToExternalMap private( - keyLoopValue: String, + keyLoopVar: LambdaVariable, keyLambdaFunction: Expression, - valueLoopValue: String, - valueLoopIsNull: String, + valueLoopVar: LambdaVariable, valueLambdaFunction: Expression, inputData: Expression, collClass: Class[_]) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable - override def children: Seq[Expression] = - keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil + override def children: Seq[Expression] = Seq( + keyLoopVar, keyLambdaFunction, valueLoopVar, valueLambdaFunction, inputData) private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType] @@ -1075,20 +1079,11 @@ case class CatalystToExternalMap private( override def dataType: DataType = ObjectType(collClass) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. - // When we want to apply MapObjects on it, we have to use it. 
- def inputDataType(dataType: DataType) = dataType match { - case p: PythonUserDefinedType => p.sqlType - case _ => dataType - } - - val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] - val keyElementJavaType = CodeGenerator.javaType(mapType.keyType) - ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) + val keyCode = LambdaVariable.prepareLambdaVariable(ctx, keyLoopVar) + val valueCode = LambdaVariable.prepareLambdaVariable(ctx, valueLoopVar) + val keyElementJavaType = CodeGenerator.javaType(keyLoopVar.dataType) val genKeyFunction = keyLambdaFunction.genCode(ctx) - val valueElementJavaType = CodeGenerator.javaType(mapType.valueType) - ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, - useFreshName = false) + val valueElementJavaType = CodeGenerator.javaType(valueLoopVar.dataType) val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -1098,9 +1093,8 @@ case class CatalystToExternalMap private( val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") - val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) - val getValueLoopVar = CodeGenerator.getValue( - valueArray, inputDataType(mapType.valueType), loopIndex) + val getKeyLoopVar = CodeGenerator.getValue(keyArray, keyLoopVar.dataType, loopIndex) + val getValueLoopVar = CodeGenerator.getValue(valueArray, valueLoopVar.dataType, loopIndex) // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = @@ -1115,10 +1109,8 @@ case class CatalystToExternalMap private( val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) - val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, - useFreshName = false) - s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + val valueLoopNullCheck = if (valueLoopVar.nullable) { + s"${valueCode.isNull} = $valueArray.isNullAt($loopIndex);" } else { "" } @@ -1154,8 +1146,8 @@ case class CatalystToExternalMap private( int $loopIndex = 0; while ($loopIndex < $dataLength) { - $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar); - $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar); + ${keyCode.value} = ($keyElementJavaType) ($getKeyLoopVar); + ${valueCode.value} = ($valueElementJavaType) ($getValueLoopVar); $valueLoopNullCheck ${genKeyFunction.code} @@ -1174,8 +1166,6 @@ case class CatalystToExternalMap private( } object ExternalMapToCatalyst { - private val curId = new java.util.concurrent.atomic.AtomicInteger() - def apply( inputMap: Expression, keyType: DataType, @@ -1184,31 +1174,14 @@ object ExternalMapToCatalyst { valueType: DataType, valueConverter: Expression => Expression, valueNullable: Boolean): ExternalMapToCatalyst = { - val id = curId.getAndIncrement() - val keyName = "ExternalMapToCatalyst_key" + id - val keyIsNull = if (keyNullable) { - "ExternalMapToCatalyst_key_isNull" + id - } else { - "false" - } - val valueName = "ExternalMapToCatalyst_value" + id - val valueIsNull = if (valueNullable) { - "ExternalMapToCatalyst_value_isNull" + id - } else { - "false" - } - + val keyLoopVar = LambdaVariable("ExternalMapToCatalyst_key", keyType, keyNullable) + val valueLoopVar = 
LambdaVariable("ExternalMapToCatalyst_value", valueType, valueNullable) ExternalMapToCatalyst( - keyName, - keyIsNull, - keyType, - keyConverter(LambdaVariable(keyName, keyIsNull, keyType, keyNullable)), - valueName, - valueIsNull, - valueType, - valueConverter(LambdaVariable(valueName, valueIsNull, valueType, valueNullable)), - inputMap - ) + keyLoopVar, + keyConverter(keyLoopVar), + valueLoopVar, + valueConverter(valueLoopVar), + inputMap) } } @@ -1216,37 +1189,32 @@ object ExternalMapToCatalyst { * Converts a Scala/Java map object into catalyst format, by applying the key/value converter when * iterate the map. * - * @param key the name of the map key variable that used when iterate the map, and used as input for - * the `keyConverter` - * @param keyIsNull the nullability of the map key variable that used when iterate the map, and - * used as input for the `keyConverter` - * @param keyType the data type of the map key variable that used when iterate the map, and used as - * input for the `keyConverter` + * @param keyLoopVar the [[LambdaVariable]] expression representing the loop variable that is used + * when iterating over the key collection, and which is used as input for the + * `keyConverter`. * @param keyConverter A function that take the `key` as input, and converts it to catalyst format. - * @param value the name of the map value variable that used when iterate the map, and used as input - * for the `valueConverter` - * @param valueIsNull the nullability of the map value variable that used when iterate the map, and - * used as input for the `valueConverter` - * @param valueType the data type of the map value variable that used when iterate the map, and - * used as input for the `valueConverter` + * @param valueLoopVar the [[LambdaVariable]] expression representing the loop variable that is used + * when iterating over the value collection, and which is used as input for the + * `valueConverter`. * @param valueConverter A function that take the `value` as input, and converts it to catalyst * format. - * @param child An expression that when evaluated returns the input map object. + * @param inputData An expression that when evaluated returns the input map object. 
*/ case class ExternalMapToCatalyst private( - key: String, - keyIsNull: String, - keyType: DataType, + keyLoopVar: LambdaVariable, keyConverter: Expression, - value: String, - valueIsNull: String, - valueType: DataType, + valueLoopVar: LambdaVariable, valueConverter: Expression, - child: Expression) - extends UnaryExpression with NonSQLExpression { + inputData: Expression) + extends Expression with NonSQLExpression { override def foldable: Boolean = false + override def nullable: Boolean = inputData.nullable + + override def children: Seq[Expression] = Seq( + keyLoopVar, keyConverter, valueLoopVar, valueConverter, inputData) + override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) @@ -1257,7 +1225,7 @@ case class ExternalMapToCatalyst private( rowBuffer } - child.dataType match { + inputData.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => (input: Any) => { val data = input.asInstanceOf[java.util.Map[Any, Any]] @@ -1308,7 +1276,7 @@ case class ExternalMapToCatalyst private( } override def eval(input: InternalRow): Any = { - val result = child.eval(input) + val result = inputData.eval(input) if (result != null) { val (keys, values) = mapCatalystConverter(result) new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) @@ -1318,7 +1286,7 @@ case class ExternalMapToCatalyst private( } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val inputMap = child.genCode(ctx) + val inputMap = inputData.genCode(ctx) val genKeyConverter = keyConverter.genCode(ctx) val genValueConverter = valueConverter.genCode(ctx) val length = ctx.freshName("length") @@ -1328,12 +1296,12 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") - val keyElementJavaType = CodeGenerator.javaType(keyType) - val valueElementJavaType = CodeGenerator.javaType(valueType) - ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) - ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false) + val keyJavaType = CodeGenerator.javaType(keyLoopVar.dataType) + val valueJavaType = CodeGenerator.javaType(valueLoopVar.dataType) + val keyCode = LambdaVariable.prepareLambdaVariable(ctx, keyLoopVar) + val valueCode = LambdaVariable.prepareLambdaVariable(ctx, valueLoopVar) - val (defineEntries, defineKeyValue) = child.dataType match { + val (defineEntries, defineKeyValue) = inputData.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => val javaIteratorCls = classOf[java.util.Iterator[_]].getName val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName @@ -1344,8 +1312,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - $key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey(); - $value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue(); + ${keyCode.value} = (${CodeGenerator.boxedType(keyJavaType)}) $entry.getKey(); + ${valueCode.value} = (${CodeGenerator.boxedType(valueJavaType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -1359,25 +1327,21 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - $key = (${CodeGenerator.boxedType(keyType)}) $entry._1(); - $value = 
(${CodeGenerator.boxedType(valueType)}) $entry._2(); + ${keyCode.value} = (${CodeGenerator.boxedType(keyJavaType)}) $entry._1(); + ${valueCode.value} = (${CodeGenerator.boxedType(valueJavaType)}) $entry._2(); """ defineEntries -> defineKeyValue } - val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState( - CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) - s"$keyIsNull = $key == null;" + val keyNullCheck = if (keyLoopVar.nullable) { + s"${keyCode.isNull} = ${keyCode.value} == null;" } else { "" } - val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState( - CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) - s"$valueIsNull = $value == null;" + val valueNullCheck = if (valueLoopVar.nullable) { + s"${valueCode.isNull} = ${valueCode.value} == null;" } else { "" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8c52ff9e9be7..17b4ff758c9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -172,7 +172,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("Object Expressions Optimization", fixedPoint, EliminateMapObjects, CombineTypedFilters, - ObjectSerializerPruning) :+ + ObjectSerializerPruning, + ReassignLambdaVariableID) :+ Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 216c125bee87..ad93ef347a97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -110,7 +110,7 @@ object CombineTypedFilters extends Rule[LogicalPlan] { */ object EliminateMapObjects extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData + case MapObjects(_, LambdaVariable(_, _, false, _), inputData, None) => inputData } } @@ -228,3 +228,54 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] { } } } + +/** + * Reassigns per-query unique IDs to `LambdaVariable`s, whose original IDs are globally unique. This + * can help Spark to hit codegen cache more often and improve performance. + */ +object ReassignLambdaVariableID extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!SQLConf.get.getConf(SQLConf.OPTIMIZER_REASSIGN_LAMBDA_VARIABLE_ID)) return plan + + // The original LambdaVariable IDs are all positive. To avoid conflicts, the new IDs are all + // negative and starts from -1. + var newId = 0L + val oldIdToNewId = scala.collection.mutable.Map.empty[Long, Long] + + // The `LambdaVariable` IDs in a query should be all positive or negative. Otherwise it's a bug + // and we should fail earlier. 
+ var hasNegativeIds = false + var hasPositiveIds = false + + plan.transformAllExpressions { + case lr: LambdaVariable if lr.id == 0 => + throw new IllegalStateException("LambdaVariable should never has 0 as its ID.") + + case lr: LambdaVariable if lr.id < 0 => + hasNegativeIds = true + if (hasPositiveIds) { + throw new IllegalStateException( + "LambdaVariable IDs in a query should be all positive or negative.") + + } + lr + + case lr: LambdaVariable if lr.id > 0 => + hasPositiveIds = true + if (hasNegativeIds) { + throw new IllegalStateException( + "LambdaVariable IDs in a query should be all positive or negative.") + } + + if (oldIdToNewId.contains(lr.id)) { + // This `LambdaVariable` has appeared before, reuse the newly generated ID. + lr.copy(id = oldIdToNewId(lr.id)) + } else { + // This is the first appearance of this `LambdaVariable`, generate a new ID. + newId -= 1 + oldIdToNewId(lr.id) = newId + lr.copy(id = newId) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d26cd2ca7343..bb6894a198bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -208,6 +208,13 @@ object SQLConf { .stringConf .createOptional + val OPTIMIZER_REASSIGN_LAMBDA_VARIABLE_ID = + buildConf("spark.sql.optimizer.reassignLambdaVariableID") + .doc("When true, Spark optimizer reassigns per-query unique IDs to LambdaVariable, so that " + + "it's more likely to hit codegen cache.") + .booleanConf + .createWithDefault(true) + val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed") .doc("When set to true Spark SQL will automatically select a compression codec for each " + "column based on statistics of the data.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 8520300ca59c..b6ca52f1d967 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -448,7 +448,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row = RandomDataGenerator.randomRow(random, schema) val rowConverter = RowEncoder(schema) val internalRow = rowConverter.toRow(row) - val lambda = LambdaVariable("dummy", "dummuIsNull", schema(0).dataType, schema(0).nullable) + val lambda = LambdaVariable("dummy", schema(0).dataType, schema(0).nullable, id = 0) checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala new file mode 100644 index 000000000000..06a32c77ac5e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.BooleanType + +class ReassignLambdaVariableIDSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("Optimizer Batch", FixedPoint(100), ReassignLambdaVariableID) :: Nil + } + + test("basic: replace positive IDs with unique negative IDs") { + val testRelation = LocalRelation('col.int) + val var1 = LambdaVariable("a", BooleanType, true, id = 2) + val var2 = LambdaVariable("b", BooleanType, true, id = 4) + val query = testRelation.where(var1 && var2) + val optimized = Optimize.execute(query) + val expected = testRelation.where(var1.copy(id = -1) && var2.copy(id = -2)) + comparePlans(optimized, expected) + } + + test("ignore LambdaVariable with negative IDs") { + val testRelation = LocalRelation('col.int) + val var1 = LambdaVariable("a", BooleanType, true, id = -2) + val var2 = LambdaVariable("b", BooleanType, true, id = -4) + val query = testRelation.where(var1 && var2) + val optimized = Optimize.execute(query) + comparePlans(optimized, query) + } + + test("fail if positive ID LambdaVariable and negative LambdaVariable both exist") { + val testRelation = LocalRelation('col.int) + val var1 = LambdaVariable("a", BooleanType, true, id = -2) + val var2 = LambdaVariable("b", BooleanType, true, id = 4) + val query = testRelation.where(var1 && var2) + val e = intercept[IllegalStateException](Optimize.execute(query)) + assert(e.getMessage.contains("should be all positive or negative")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a80aadebe353..7b210ecefb7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -39,7 +39,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} @@ -70,7 +69,7 @@ private[sql] object Dataset { // do not do this check in that case. 
this check can be expensive since it requires running // the whole [[Analyzer]] to resolve the deserializer if (dataset.exprEnc.clsTag.runtimeClass != classOf[Row]) { - dataset.deserializer + dataset.resolvedEnc } dataset } @@ -217,10 +216,11 @@ class Dataset[T] private[sql]( */ private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder) - // The deserializer expression which can be used to build a projection and turn rows to objects - // of type T, after collecting rows to the driver side. - private lazy val deserializer = - exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer + // The resolved `ExpressionEncoder` which can be used to turn rows to objects of type T, after + // collecting rows to the driver side. + private lazy val resolvedEnc = { + exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) + } private implicit def classTag = exprEnc.clsTag @@ -2776,15 +2776,9 @@ class Dataset[T] private[sql]( */ def toLocalIterator(): java.util.Iterator[T] = { withAction("toLocalIterator", queryExecution) { plan => - // This projection writes output to a `InternalRow`, which means applying this projection is - // not thread-safe. Here we create the projection inside this method to make `Dataset` - // thread-safe. - val objProj = GenerateSafeProjection.generate(deserializer :: Nil) - plan.executeToIterator().map { row => - // The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type - // parameter of its `get` method, so it's safe to use null here. - objProj(row).get(0, null).asInstanceOf[T] - }.asJava + // `ExpressionEncoder` is not thread-safe, here we create a new encoder. + val enc = resolvedEnc.copy() + plan.executeToIterator().map(enc.fromRow).asJava } } @@ -3403,14 +3397,9 @@ class Dataset[T] private[sql]( * Collect all elements from a spark plan. */ private def collectFromPlan(plan: SparkPlan): Array[T] = { - // This projection writes output to a `InternalRow`, which means applying this projection is not - // thread-safe. Here we create the projection inside this method to make `Dataset` thread-safe. - val objProj = GenerateSafeProjection.generate(deserializer :: Nil) - plan.executeCollect().map { row => - // The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type - // parameter of its `get` method, so it's safe to use null here. - objProj(row).get(0, null).asInstanceOf[T] - } + // `ExpressionEncoder` is not thread-safe, here we create a new encoder. + val enc = resolvedEnc.copy() + plan.executeCollect().map(enc.fromRow) } private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index a2def6b510eb..ea44c6013b7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -51,9 +51,8 @@ object TypedAggregateExpression { // If the buffer object is simple, use `SimpleTypedAggregateExpression`, which supports whole // stage codegen. 
if (isSimpleBuffer) { - val bufferDeserializer = UnresolvedDeserializer( - bufferEncoder.deserializer, - bufferSerializer.map(_.toAttribute)) + val bufferAttrs = bufferSerializer.map(_.toAttribute) + val bufferDeserializer = UnresolvedDeserializer(bufferEncoder.deserializer, bufferAttrs) SimpleTypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], @@ -61,6 +60,7 @@ object TypedAggregateExpression { None, None, bufferSerializer, + bufferAttrs.map(_.asInstanceOf[AttributeReference]), bufferDeserializer, outputEncoder.serializer, outputEncoder.deserializer.dataType, @@ -116,7 +116,8 @@ case class SimpleTypedAggregateExpression( inputDeserializer: Option[Expression], inputClass: Option[Class[_]], inputSchema: Option[StructType], - bufferSerializer: Seq[NamedExpression], + bufferSerializer: Seq[Expression], + aggBufferAttributes: Seq[AttributeReference], bufferDeserializer: Expression, outputSerializer: Seq[Expression], outputExternalType: DataType, @@ -126,7 +127,10 @@ case class SimpleTypedAggregateExpression( override lazy val deterministic: Boolean = true - override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer + override def children: Seq[Expression] = { + inputDeserializer.toSeq ++ bufferSerializer ++ aggBufferAttributes ++ + Seq(bufferDeserializer) ++ outputSerializer + } override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved @@ -137,9 +141,6 @@ case class SimpleTypedAggregateExpression( private def bufferExternalType = bufferDeserializer.dataType - override lazy val aggBufferAttributes: Seq[AttributeReference] = - bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference]) - private def serializeToBuffer(expr: Expression): Seq[Expression] = { bufferSerializer.map(_.transform { case _: BoundReference => expr @@ -209,7 +210,7 @@ case class ComplexTypedAggregateExpression( inputDeserializer: Option[Expression], inputClass: Option[Class[_]], inputSchema: Option[StructType], - bufferSerializer: Seq[NamedExpression], + bufferSerializer: Seq[Expression], bufferDeserializer: Expression, outputSerializer: Expression, dataType: DataType, @@ -220,7 +221,9 @@ case class ComplexTypedAggregateExpression( override lazy val deterministic: Boolean = true - override def children: Seq[Expression] = inputDeserializer.toSeq + override def children: Seq[Expression] = { + inputDeserializer.toSeq ++ bufferSerializer :+ bufferDeserializer :+ outputSerializer + } override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala index b9242541abcb..ae051e43fbcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression} import org.apache.spark.sql.catalyst.expressions.objects.ExternalMapToCatalyst import org.apache.spark.sql.catalyst.plans.logical.SerializeFromObject @@ -49,13 +50,10 @@ class DatasetOptimizationSuite extends QueryTest with SharedSQLContext { def collectNamedStruct: PartialFunction[Expression, Seq[CreateNamedStruct]] = { case c: CreateNamedStruct => Seq(c) - case m: ExternalMapToCatalyst => - 
m.keyConverter.collect(collectNamedStruct).flatten ++ - m.valueConverter.collect(collectNamedStruct).flatten } - serializer.serializer.zip(structFields).foreach { case (serializer, fields) => - val structs: Seq[CreateNamedStruct] = serializer.collect(collectNamedStruct).flatten + serializer.serializer.zip(structFields).foreach { case (ser, fields) => + val structs: Seq[CreateNamedStruct] = ser.collect(collectNamedStruct).flatten assert(structs.size == fields.size) structs.zip(fields).foreach { case (struct, fieldNames) => assert(struct.names.map(_.toString) == fieldNames) @@ -166,4 +164,44 @@ class DatasetOptimizationSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row("1"), Row("2"), Row("3"))) } } + + test("SPARK-27871: Dataset encoder should benefit from codegen cache") { + def checkCodegenCache(createDataset: () => Dataset[_]): Unit = { + def getCodegenCount(): Long = CodegenMetrics.METRIC_COMPILATION_TIME.getCount() + + val count1 = getCodegenCount() + // trigger codegen for Dataset + createDataset().collect() + val count2 = getCodegenCount() + // codegen happens + assert(count2 > count1) + + // trigger codegen for another Dataset of same type + createDataset().collect() + // codegen cache should work for Datasets of same type. + val count3 = getCodegenCount() + assert(count3 == count2) + + withSQLConf(SQLConf.OPTIMIZER_REASSIGN_LAMBDA_VARIABLE_ID.key -> "false") { + // trigger codegen for another Dataset of same type + createDataset().collect() + // with the rule disabled, codegen happens again for encoder serializer and encoder + // deserializer + val count4 = getCodegenCount() + assert(count4 == (count3 + 2)) + } + } + + withClue("array type") { + checkCodegenCache(() => Seq(Seq("abc")).toDS()) + } + + withClue("map type") { + checkCodegenCache(() => Seq(Map("abc" -> 1)).toDS()) + } + + withClue("array of map") { + checkCodegenCache(() => Seq(Seq(Map("abc" -> 1))).toDS()) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 721ce65bc61d..4b08a4b0d1a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1571,17 +1571,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("SPARK-22472: add null check for top-level primitive values") { // If the primitive values are from Option, we need to do runtime null check. val ds = Seq(Some(1), None).toDS().as[Int] - intercept[NullPointerException](ds.collect()) - val e = intercept[SparkException](ds.map(_ * 2).collect()) - assert(e.getCause.isInstanceOf[NullPointerException]) + val e1 = intercept[RuntimeException](ds.collect()) + assert(e1.getCause.isInstanceOf[NullPointerException]) + val e2 = intercept[SparkException](ds.map(_ * 2).collect()) + assert(e2.getCause.isInstanceOf[NullPointerException]) withTempPath { path => Seq(Integer.valueOf(1), null).toDF("i").write.parquet(path.getCanonicalPath) // If the primitive values are from files, we need to do runtime null check. 
val ds = spark.read.parquet(path.getCanonicalPath).as[Int] - intercept[NullPointerException](ds.collect()) - val e = intercept[SparkException](ds.map(_ * 2).collect()) - assert(e.getCause.isInstanceOf[NullPointerException]) + val e1 = intercept[RuntimeException](ds.collect()) + assert(e1.getCause.isInstanceOf[NullPointerException]) + val e2 = intercept[SparkException](ds.map(_ * 2).collect()) + assert(e2.getCause.isInstanceOf[NullPointerException]) } } @@ -1599,7 +1601,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("SPARK-23835: null primitive data type should throw NullPointerException") { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() - intercept[NullPointerException](ds.as[(Int, Int)].collect()) + val e = intercept[RuntimeException](ds.as[(Int, Int)].collect()) + assert(e.getCause.isInstanceOf[NullPointerException]) } test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 3c9a0908147a..9462ee190a31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -295,7 +295,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { import testImplicits._ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") { - val bytecodeSizeHisto = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE + val bytecodeSizeHisto = CodegenMetrics.METRIC_COMPILATION_TIME // the same query run twice should hit the codegen cache spark.range(3).select('id + 2).collect
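
To summarize the user-visible effect the new DatasetOptimizationSuite test checks, here is a rough sketch, assuming a local SparkSession named spark with its implicits in scope; with spark.sql.optimizer.reassignLambdaVariableID left at its default of true, a second Dataset of the same type reuses the compiled classes instead of triggering new compilation:

    import org.apache.spark.metrics.source.CodegenMetrics
    import spark.implicits._

    def compilations(): Long = CodegenMetrics.METRIC_COMPILATION_TIME.getCount()

    Seq(Seq("abc")).toDS().collect()   // first Dataset of this type: codegen happens
    val afterFirst = compilations()
    Seq(Seq("def")).toDS().collect()   // same element type: served from the codegen cache
    assert(compilations() == afterFirst)

    // Disabling the rule restores the old behavior, where every encoder carries globally unique
    // LambdaVariable IDs and therefore generates slightly different code each time:
    // spark.conf.set("spark.sql.optimizer.reassignLambdaVariableID", "false")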