-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19088][SQL] Optimize sequence type deserialization codegen #16541
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
4aaef15
a330d5f
1092375
85edddd
b5f87bd
d04e043
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects | |
|
|
||
| import java.lang.reflect.Modifier | ||
|
|
||
| import scala.collection.mutable.Builder | ||
| import scala.language.existentials | ||
| import scala.reflect.ClassTag | ||
|
|
||
|
|
@@ -589,6 +590,171 @@ case class MapObjects private( | |
| } | ||
| } | ||
|
|
||
object CollectObjects {
  // Shared counter whose values are appended to generated variable names.
  private val curId = new java.util.concurrent.atomic.AtomicInteger()

  /**
   * Factory for the [[CollectObjects]] case class.
   *
   * Three names (loop value, loop null flag, builder) are generated by appending successive
   * values of an internal counter to fixed prefixes; `function` is then applied to a
   * [[LambdaVariable]] built from the first two names and `elementType`.
   *
   * @param function The function applied on the collection elements.
   * @param inputData An expression that when evaluated returns a collection object.
   * @param elementType The data type of elements in the collection.
   * @param collClass The type of the resulting collection.
   */
  def apply(
      function: Expression => Expression,
      inputData: Expression,
      elementType: DataType,
      collClass: Class[_]): CollectObjects = {
    // Appends the next counter value to `suffix` under the common "CollectObjects_" prefix.
    def nextName(suffix: String): String = s"CollectObjects_$suffix${curId.getAndIncrement()}"

    val valueName = nextName("loopValue")
    val isNullName = nextName("loopIsNull")
    val elementVar = LambdaVariable(valueName, isNullName, elementType)
    val builderName = nextName("builderValue")
    val mappedElement = function(elementVar)

    CollectObjects(
      valueName, isNullName, elementType, mappedElement, inputData, collClass, builderName)
  }
}
|
|
||
/**
 * An equivalent to the [[MapObjects]] case class but returning an ObjectType containing
 * a Scala collection constructed using the associated builder, obtained by calling `newBuilder`
 * on the collection's companion object.
 *
 * Only code-generated evaluation is supported; calling `eval` throws
 * `UnsupportedOperationException`.
 *
 * @param loopValue the name of the loop variable that is used when iterating over the
 *                  collection, and used as input for the `lambdaFunction`
 * @param loopIsNull the nullity of the loop variable that is used when iterating over the
 *                   collection, and used as input for the `lambdaFunction`
 * @param loopVarDataType the data type of the loop variable that is used when iterating over
 *                        the collection, and used as input for the `lambdaFunction`
 * @param lambdaFunction A function that takes the `loopVar` as input, and is used as the lambda
 *                       function to handle collection elements.
 * @param inputData An expression that when evaluated returns a collection object.
 * @param collClass The type of the resulting collection.
 * @param builderValue The name of the builder variable used to construct the resulting
 *                     collection.
 */
case class CollectObjects private(
    loopValue: String,
    loopIsNull: String,
    loopVarDataType: DataType,
    lambdaFunction: Expression,
    inputData: Expression,
    collClass: Class[_],
    builderValue: String) extends Expression with NonSQLExpression {

  override def nullable: Boolean = inputData.nullable

  override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil

  override def eval(input: InternalRow): Any =
    throw new UnsupportedOperationException("Only code-generated evaluation is supported")

  override def dataType: DataType = ObjectType(collClass)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Java expression for the collection's companion object (`<class>$.MODULE$`) and the call
    // that asks it for a new builder.
    val collObjectName = s"${collClass.getName}$$.MODULE$$"
    val getBuilderVar = s"$collObjectName.newBuilder()"

    val elementJavaType = ctx.javaType(loopVarDataType)
    ctx.addMutableState("boolean", loopIsNull, "")
    ctx.addMutableState(elementJavaType, loopValue, "")
    val genInputData = inputData.genCode(ctx)
    val genFunction = lambdaFunction.genCode(ctx)
    val dataLength = ctx.freshName("dataLength")
    val loopIndex = ctx.freshName("loopIndex")

    // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
    // of input collection at runtime for this case.
    val seq = ctx.freshName("seq")
    val array = ctx.freshName("array")
    val determineCollectionType = inputData.dataType match {
      case ObjectType(cls) if cls == classOf[Object] =>
        val seqClass = classOf[Seq[_]].getName
        s"""
          $seqClass $seq = null;
          $elementJavaType[] $array = null;
          if (${genInputData.value}.getClass().isArray()) {
            $array = ($elementJavaType[]) ${genInputData.value};
          } else {
            $seq = ($seqClass) ${genInputData.value};
          }
        """
      case _ => ""
    }

    // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
    // When we want to apply CollectObjects on it, we have to use it.
    val inputDataType = inputData.dataType match {
      case p: PythonUserDefinedType => p.sqlType
      case _ => inputData.dataType
    }

    // Java snippets that yield the number of elements and the element at `loopIndex`, chosen
    // according to how the input collection is represented.
    val (getLength, getLoopVar) = inputDataType match {
      case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
        s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
      case ObjectType(cls) if cls.isArray =>
        s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
      case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
        s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
      case ArrayType(et, _) =>
        s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
      case ObjectType(cls) if cls == classOf[Object] =>
        s"$seq == null ? $array.length : $seq.size()" ->
          s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
    }

    // Make a copy of the data if it's unsafe-backed
    def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
      s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value"
    val genFunctionValue = lambdaFunction.dataType match {
      case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
      case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
      case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
      case _ => genFunction.value
    }

    // Each alternative is spliced into the loop body as a standalone Java statement, so every
    // one of them must end with a semicolon.
    val loopNullCheck = inputDataType match {
      case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
      // The element of primitive array will never be null.
      case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
        s"$loopIsNull = false;"
      case _ => s"$loopIsNull = $loopValue == null;"
    }

    val code = s"""
      ${genInputData.code}
      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};

      if (!${genInputData.isNull}) {
        $determineCollectionType
        int $dataLength = $getLength;
        ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
        $builderValue.sizeHint($dataLength);

        int $loopIndex = 0;
        while ($loopIndex < $dataLength) {
          $loopValue = ($elementJavaType) ($getLoopVar);
          $loopNullCheck

          ${genFunction.code}
          if (${genFunction.isNull}) {
            $builderValue.$$plus$$eq(null);
          } else {
            $builderValue.$$plus$$eq($genFunctionValue);
          }

          $loopIndex += 1;
        }

        ${ev.value} = (${collClass.getName}) $builderValue.result();
      }
    """
    // The result is null exactly when the input collection is null.
    ev.copy(code = code, isNull = genInputData.isNull)
  }
}
|
|
||
| object ExternalMapToCatalyst { | ||
| private val curId = new java.util.concurrent.atomic.AtomicInteger() | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use 4-space indentation here.