-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19088][SQL] Optimize sequence type deserialization codegen #16541
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
4aaef15
a330d5f
1092375
85edddd
b5f87bd
d04e043
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects | |
|
|
||
| import java.lang.reflect.Modifier | ||
|
|
||
| import scala.collection.mutable.Builder | ||
| import scala.language.existentials | ||
| import scala.reflect.ClassTag | ||
|
|
||
|
|
@@ -429,24 +430,33 @@ object MapObjects { | |
| * @param function The function applied on the collection elements. | ||
| * @param inputData An expression that when evaluated returns a collection object. | ||
| * @param elementType The data type of elements in the collection. | ||
| * @param collClass The class of the resulting collection | ||
| */ | ||
| def apply( | ||
| function: Expression => Expression, | ||
| inputData: Expression, | ||
| elementType: DataType): MapObjects = { | ||
| val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() | ||
| val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() | ||
| elementType: DataType, | ||
| collClass: Class[_] = classOf[Array[_]]): MapObjects = { | ||
|
||
| val id = curId.getAndIncrement() | ||
| val loopValue = s"MapObjects_loopValue$id" | ||
| val loopIsNull = s"MapObjects_loopIsNull$id" | ||
| val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) | ||
| MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData) | ||
| val builderValue = s"MapObjects_builderValue$id" | ||
| MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, | ||
| collClass, builderValue) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Applies the given expression to every element of a collection of items, returning the result | ||
| * as an ArrayType. This is similar to a typical map operation, but where the lambda function | ||
| * is expressed using catalyst expressions. | ||
| * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda | ||
| * function is expressed using catalyst expressions. | ||
| * | ||
| * The type of the result is determined as follows: | ||
| * - ArrayType - when collClass is an array class | ||
| * - ObjectType(collClass) - when collClass is a collection class | ||
| * | ||
| * The following collection ObjectTypes are currently supported: | ||
| * The following collection ObjectTypes are currently supported on input: | ||
| * Seq, Array, ArrayData, java.util.List | ||
| * | ||
| * @param loopValue the name of the loop variable that used when iterate the collection, and used | ||
|
|
@@ -458,13 +468,18 @@ object MapObjects { | |
| * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function | ||
| * to handle collection elements. | ||
| * @param inputData An expression that when evaluated returns a collection object. | ||
| * @param collClass The class of the resulting collection | ||
| * @param builderValue The name of the builder variable used to construct the resulting collection | ||
| * (used only when returning ObjectType) | ||
| */ | ||
| case class MapObjects private( | ||
| loopValue: String, | ||
| loopIsNull: String, | ||
| loopVarDataType: DataType, | ||
| lambdaFunction: Expression, | ||
| inputData: Expression) extends Expression with NonSQLExpression { | ||
| inputData: Expression, | ||
| collClass: Class[_], | ||
| builderValue: String) extends Expression with NonSQLExpression { | ||
|
|
||
| override def nullable: Boolean = inputData.nullable | ||
|
|
||
|
|
@@ -474,7 +489,8 @@ case class MapObjects private( | |
| throw new UnsupportedOperationException("Only code-generated evaluation is supported") | ||
|
|
||
| override def dataType: DataType = | ||
| ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) | ||
| if (!collClass.isArray) ObjectType(collClass) | ||
| else ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| val elementJavaType = ctx.javaType(loopVarDataType) | ||
|
|
@@ -557,15 +573,32 @@ case class MapObjects private( | |
| case _ => s"$loopIsNull = $loopValue == null;" | ||
| } | ||
|
|
||
| val (genInit, genAssign, genResult): (String, String => String, String) = | ||
|
||
| if (collClass.isArray) { | ||
| // array | ||
| (s"""$convertedType[] $convertedArray = null; | ||
| $convertedArray = $arrayConstructor;""", | ||
| genValue => s"$convertedArray[$loopIndex] = $genValue;", | ||
| s"new ${classOf[GenericArrayData].getName}($convertedArray);") | ||
| } else { | ||
| // collection | ||
| val collObjectName = s"${collClass.getName}$$.MODULE$$" | ||
| val getBuilderVar = s"$collObjectName.newBuilder()" | ||
|
|
||
| (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; | ||
| $builderValue.sizeHint($dataLength);""", | ||
| genValue => s"$builderValue.$$plus$$eq($genValue);", | ||
| s"(${collClass.getName}) $builderValue.result();") | ||
| } | ||
|
|
||
| val code = s""" | ||
| ${genInputData.code} | ||
| ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; | ||
|
|
||
| if (!${genInputData.isNull}) { | ||
| $determineCollectionType | ||
| $convertedType[] $convertedArray = null; | ||
| int $dataLength = $getLength; | ||
| $convertedArray = $arrayConstructor; | ||
| $genInit | ||
|
|
||
| int $loopIndex = 0; | ||
| while ($loopIndex < $dataLength) { | ||
|
|
@@ -574,15 +607,15 @@ case class MapObjects private( | |
|
|
||
| ${genFunction.code} | ||
| if (${genFunction.isNull}) { | ||
| $convertedArray[$loopIndex] = null; | ||
| ${genAssign("null")} | ||
| } else { | ||
| $convertedArray[$loopIndex] = $genFunctionValue; | ||
| ${genAssign(genFunctionValue)} | ||
| } | ||
|
|
||
| $loopIndex += 1; | ||
| } | ||
|
|
||
| ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); | ||
| ${ev.value} = $genResult | ||
| } | ||
| """ | ||
| ev.copy(code = code, isNull = genInputData.isNull) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked that the builder for
Seq[_]is amutable.ListBufferand its build result will beimmutable.List. But the original deserialized type ofSeq[_]is aWrappedArray.Can we keep the original expression if we can't find
newBuilderhere, i.e., returning theWrappedArray?Or use
WrappedArrayas thecollClassinstead ofSeq?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rethink about this. It may not be a problem. Let me try to test it first.