diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f3c1e4150017..bea0de4d90c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -335,7 +335,7 @@ object ScalaReflection extends ScalaReflection { // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t - CollectObjectsToMap( + CatalystToExternalMap( p => deserializerFor(keyType, Some(p), walkedTypePath), p => deserializerFor(valueType, Some(p), walkedTypePath), getPath, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d6d06aecc077..ce07f4a25c18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -465,7 +465,11 @@ object MapObjects { customCollectionCls: Option[Class[_]] = None): MapObjects = { val id = curId.getAndIncrement() val loopValue = s"MapObjects_loopValue$id" - val loopIsNull = s"MapObjects_loopIsNull$id" + val loopIsNull = if (elementNullable) { + s"MapObjects_loopIsNull$id" + } else { + "false" + } val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable) MapObjects( loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) @@ -517,7 +521,6 @@ case class MapObjects private( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState("boolean", loopIsNull, "") ctx.addMutableState(elementJavaType, loopValue, "") val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) @@ -588,12 +591,14 @@ case class MapObjects private( case _ => genFunction.value } - val loopNullCheck = inputDataType match { - case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" - // The element of primitive array will never be null. - case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => - s"$loopIsNull = false" - case _ => s"$loopIsNull = $loopValue == null;" + val loopNullCheck = if (loopIsNull != "false") { + ctx.addMutableState("boolean", loopIsNull, "") + inputDataType match { + case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + case _ => s"$loopIsNull = $loopValue == null;" + } + } else { + "" } val (initCollection, addElement, getResult): (String, String => String, String) = @@ -667,11 +672,11 @@ case class MapObjects private( } } -object CollectObjectsToMap { +object CatalystToExternalMap { private val curId = new java.util.concurrent.atomic.AtomicInteger() /** - * Construct an instance of CollectObjectsToMap case class. + * Construct an instance of CatalystToExternalMap case class. * * @param keyFunction The function applied on the key collection elements. * @param valueFunction The function applied on the value collection elements. @@ -682,15 +687,19 @@ object CollectObjectsToMap { keyFunction: Expression => Expression, valueFunction: Expression => Expression, inputData: Expression, - collClass: Class[_]): CollectObjectsToMap = { + collClass: Class[_]): CatalystToExternalMap = { val id = curId.getAndIncrement() - val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id" + val keyLoopValue = s"CatalystToExternalMap_keyLoopValue$id" val mapType = inputData.dataType.asInstanceOf[MapType] val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false) - val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id" - val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id" + val valueLoopValue = s"CatalystToExternalMap_valueLoopValue$id" + val valueLoopIsNull = if (mapType.valueContainsNull) { + s"CatalystToExternalMap_valueLoopIsNull$id" + } else { + "false" + } val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) - CollectObjectsToMap( + CatalystToExternalMap( keyLoopValue, keyFunction(keyLoopVar), valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar), inputData, collClass) @@ -716,7 +725,7 @@ object CollectObjectsToMap { * @param inputData An expression that when evaluated returns a map object. * @param collClass The type of the resulting collection. */ -case class CollectObjectsToMap private( +case class CatalystToExternalMap private( keyLoopValue: String, keyLambdaFunction: Expression, valueLoopValue: String, @@ -748,7 +757,6 @@ case class CollectObjectsToMap private( ctx.addMutableState(keyElementJavaType, keyLoopValue, "") val genKeyFunction = keyLambdaFunction.genCode(ctx) val valueElementJavaType = ctx.javaType(mapType.valueType) - ctx.addMutableState("boolean", valueLoopIsNull, "") ctx.addMutableState(valueElementJavaType, valueLoopValue, "") val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) @@ -781,7 +789,12 @@ case class CollectObjectsToMap private( val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) - val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + val valueLoopNullCheck = if (valueLoopIsNull != "false") { + ctx.addMutableState("boolean", valueLoopIsNull, "") + s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + } else { + "" + } val builderClass = classOf[Builder[_, _]].getName val constructBuilder = s""" @@ -847,9 +860,17 @@ object ExternalMapToCatalyst { valueNullable: Boolean): ExternalMapToCatalyst = { val id = curId.getAndIncrement() val keyName = "ExternalMapToCatalyst_key" + id - val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id + val keyIsNull = if (keyNullable) { + "ExternalMapToCatalyst_key_isNull" + id + } else { + "false" + } val valueName = "ExternalMapToCatalyst_value" + id - val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id + val valueIsNull = if (valueNullable) { + "ExternalMapToCatalyst_value_isNull" + id + } else { + "false" + } ExternalMapToCatalyst( keyName, @@ -919,9 +940,7 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) - ctx.addMutableState("boolean", keyIsNull, "") ctx.addMutableState(keyElementJavaType, key, "") - ctx.addMutableState("boolean", valueIsNull, "") ctx.addMutableState(valueElementJavaType, value, "") val (defineEntries, defineKeyValue) = child.dataType match { @@ -957,16 +976,18 @@ case class ExternalMapToCatalyst private( defineEntries -> defineKeyValue } - val keyNullCheck = if (ctx.isPrimitiveType(keyType)) { - s"$keyIsNull = false;" - } else { + val keyNullCheck = if (keyIsNull != "false") { + ctx.addMutableState("boolean", keyIsNull, "") s"$keyIsNull = $key == null;" + } else { + "" } - val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { - s"$valueIsNull = false;" - } else { + val valueNullCheck = if (valueIsNull != "false") { + ctx.addMutableState("boolean", valueIsNull, "") s"$valueIsNull = $value == null;" + } else { + "" } val arrayCls = classOf[GenericArrayData].getName