-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23934][SQL] Adding map_from_entries function #21282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
8c6039c
25aa879
8d12d9f
7fd824e
45e4633
83165e0
10ace84
6cca713
44c513c
599656e
4eaedc5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,10 +22,12 @@ import org.apache.spark.sql.catalyst.InternalRow | |
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
| import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder | ||
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} | ||
| import org.apache.spark.sql.catalyst.util._ | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.unsafe.Platform | ||
| import org.apache.spark.unsafe.array.ByteArrayMethods | ||
| import org.apache.spark.unsafe.types.{ByteArray, UTF8String} | ||
| import org.apache.spark.util.collection.OpenHashSet | ||
|
|
||
| /** | ||
| * Given an array or map, returns its size. Returns -1 if null. | ||
|
|
@@ -118,6 +120,229 @@ case class MapValues(child: Expression) | |
| override def prettyName: String = "map_values" | ||
| } | ||
|
|
||
| /** | ||
| * Returns a map created from the given array of entries. | ||
| */ | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); | ||
| {1:"a",2:"b"} | ||
| """, | ||
| since = "2.4.0") | ||
| case class MapFromEntries(child: Expression) extends UnaryExpression | ||
| { | ||
| /** | ||
|  * Map type derived from the child's type. It is defined only when the child is an array | ||
|  * without null elements whose elements are structs of exactly two fields, the first of | ||
|  * which (the key) is non-nullable. Any other child type yields `None`. | ||
|  */ | ||
| private lazy val resolvedDataType: Option[MapType] = child.dataType match { | ||
|   case ArrayType(StructType(fields), false) if fields.length == 2 && !fields(0).nullable => | ||
|     val keyField = fields(0) | ||
|     val valueField = fields(1) | ||
|     // Value nullability of the map mirrors the nullability of the second struct field. | ||
|     Some(MapType(keyField.dataType, valueField.dataType, valueField.nullable)) | ||
|   case _ => None | ||
| } | ||
|
|
||
| /** Map type resolved from the child's type; `.get` throws when no map type could be resolved. */ override def dataType: MapType = resolvedDataType.get | ||
|
|
||
| /** | ||
|  * Succeeds when a map type could be resolved from the child's type, and otherwise fails | ||
|  * with a message naming the child expression, its actual type and the accepted input shape. | ||
|  */ | ||
| override def checkInputDataTypes(): TypeCheckResult = resolvedDataType match { | ||
|   case Some(_) => TypeCheckResult.TypeCheckSuccess | ||
|   case None => | ||
|     // Keys must satisfy both constraints independently, hence "or" rather than "and". | ||
|     TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + | ||
|       s"${child.dataType.simpleString} type. $prettyName accepts only null-free arrays " + | ||
|       "of pair structs. Values of the first struct field can't contain nulls or produce " + | ||
|       "duplicates.") | ||
| } | ||
|
|
||
| /** | ||
|  * Interpreted evaluation: copies the key and value of every struct in the input array into | ||
|  * two parallel arrays and wraps them in an array-based map. | ||
|  * | ||
|  * Throws a RuntimeException when a key is null or when a key has already been seen. | ||
|  */ | ||
| override protected def nullSafeEval(input: Any): Any = { | ||
|   val entries = input.asInstanceOf[ArrayData] | ||
|   val numEntries = entries.numElements() | ||
|   val MapType(keyType, valueType, _) = dataType | ||
|   val keys = new Array[AnyRef](numEntries) | ||
|   val values = new Array[AnyRef](numEntries) | ||
|   // Tracks the keys encountered so far so that duplicates can be rejected. | ||
|   val seenKeys = new OpenHashSet[AnyRef]() | ||
|   var pos = 0 | ||
|   while (pos < numEntries) { | ||
|     val struct = entries.getStruct(pos, 2) | ||
|     val key = struct.get(0, keyType) | ||
|     if (key == null) { | ||
|       throw new RuntimeException("The first field from a struct (key) can't be null.") | ||
|     } | ||
|     if (seenKeys.contains(key)) { | ||
|       throw new RuntimeException("The first field from a struct (key) can't produce duplicates.") | ||
|     } | ||
|     seenKeys.add(key) | ||
|     keys(pos) = key | ||
|     values(pos) = struct.get(1, valueType) | ||
|     pos += 1 | ||
|   } | ||
|   ArrayBasedMapData(keys, values) | ||
| } | ||
|
|
||
| /** | ||
|  * Picks, from the key type, the pair (class-name suffix, ClassTag accessor name) that | ||
|  * doGenCode appends to the OpenHashSet class name and to the ClassTag module reference. | ||
|  * Byte, short and int keys share the Int variant, long keys use the Long variant, and | ||
|  * every other key type uses the unsuffixed class with the Object tag. | ||
|  */ | ||
| private def getHashSetDetails(): (String, String) = { | ||
|   val keyType = dataType.keyType | ||
|   keyType match { | ||
|     case LongType => ("$mcJ$sp", "Long") | ||
|     case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") | ||
|     case _ => ("", "Object") | ||
|   } | ||
| } | ||
|
|
||
| /** | ||
|  * Generated-code evaluation. Declares the element count and the hash set used for duplicate | ||
|  * detection, then delegates to the primitive builder when both the key and the value type | ||
|  * are primitive, and to the generic builder otherwise. | ||
|  */ | ||
| override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
|   nullSafeCodeGen(ctx, ev, entries => { | ||
|     val size = ctx.freshName("numElements") | ||
|     val seenKeys = ctx.freshName("keySet") | ||
|     val setClass = classOf[OpenHashSet[_]].getName | ||
|     val classTags = "scala.reflect.ClassTag$.MODULE$." | ||
|     val (specialization, tagName) = getHashSetDetails() | ||
|     val keyIsPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) | ||
|     val valueIsPrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) | ||
|     val buildMap = | ||
|       if (keyIsPrimitive && valueIsPrimitive) { | ||
|         genCodeForPrimitiveElements(ctx, entries, ev.value, seenKeys, size) | ||
|       } else { | ||
|         genCodeForAnyElements(ctx, entries, ev.value, seenKeys, size) | ||
|       } | ||
|     s""" | ||
|       |final int $size = $entries.numElements(); | ||
|       |final $setClass$specialization $seenKeys = new $setClass$specialization($classTags$tagName()); | ||
|       |$buildMap | ||
|     """.stripMargin | ||
|   }) | ||
| } | ||
|
|
||
| /** | ||
|  * Emits the Java loop shared by both builders. For each struct of the input array it rejects | ||
|  * a null key, rejects a key already present in the hash set, records the key in the set and | ||
|  * then runs the caller-supplied key and value assignment statements. | ||
|  * | ||
|  * @param childVariable name of the generated variable holding the input array | ||
|  * @param numElements name of the generated variable holding the element count | ||
|  * @param keySet name of the generated variable holding the hash set of seen keys | ||
|  * @param keyAssignment builds the key store statement from (key variable, index variable) | ||
|  * @param valueAssignment builds the value store statement from (entry variable, index variable) | ||
|  */ | ||
| private def genCodeForAssignmentLoop( | ||
|     ctx: CodegenContext, | ||
|     childVariable: String, | ||
|     numElements: String, | ||
|     keySet: String, | ||
|     keyAssignment: (String, String) => String, | ||
|     valueAssignment: (String, String) => String): String = { | ||
|   val struct = ctx.freshName("entry") | ||
|   val keyVar = ctx.freshName("key") | ||
|   val pos = ctx.freshName("idx") | ||
|   val keyJavaType = CodeGenerator.javaType(dataType.keyType) | ||
|   val readKey = CodeGenerator.getValue(struct, dataType.keyType, "0") | ||
|   val storeKey = keyAssignment(keyVar, pos) | ||
|   val storeValue = valueAssignment(struct, pos) | ||
|   s""" | ||
|     |for (int $pos = 0; $pos < $numElements; $pos++) { | ||
|     | InternalRow $struct = $childVariable.getStruct($pos, 2); | ||
|     | if ($struct.isNullAt(0)) { | ||
|     | throw new RuntimeException("The first field from a struct (key) can't be null."); | ||
|     | } | ||
|     | $keyJavaType $keyVar = $readKey; | ||
|     | if ($keySet.contains($keyVar)) { | ||
|     | throw new RuntimeException( | ||
|     | "The first field from a struct (key) can't produce duplicates."); | ||
|     | } | ||
|     | $keySet.add($keyVar); | ||
|     | $storeKey | ||
|     | $storeValue | ||
|     |} | ||
|   """.stripMargin | ||
| } | ||
|
|
||
| /** | ||
|  * Emits Java code that builds the result as an UnsafeMapData backed by a single byte array. | ||
|  * The generated code writes the key section size into the first 8 bytes, writes the element | ||
|  * count at the start of the key section and of the value section, and then fills both | ||
|  * sections through the shared assignment loop. When the computed byte array size exceeds | ||
|  * ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH the generated code runs the generic builder | ||
|  * instead. | ||
|  * | ||
|  * @param childVariable name of the generated variable holding the input array | ||
|  * @param mapData name of the generated variable receiving the resulting map | ||
|  * @param keySet name of the generated variable holding the hash set of seen keys | ||
|  * @param numElements name of the generated variable holding the element count | ||
|  */ | ||
| private def genCodeForPrimitiveElements( | ||
|     ctx: CodegenContext, | ||
|     childVariable: String, | ||
|     mapData: String, | ||
|     keySet: String, | ||
|     numElements: String): String = { | ||
|   val totalSize = ctx.freshName("byteArraySize") | ||
|   val keyBytes = ctx.freshName("keySectionSize") | ||
|   val valueBytes = ctx.freshName("valueSectionSize") | ||
|   val buffer = ctx.freshName("byteArray") | ||
|   val unsafeMap = ctx.freshName("unsafeMapData") | ||
|   val keyArray = ctx.freshName("keyArrayData") | ||
|   val valueArray = ctx.freshName("valueArrayData") | ||
|   val MapType(keyType, valueType, valueContainsNull) = dataType | ||
|   val bufferOffset = Platform.BYTE_ARRAY_OFFSET | ||
|   // The key section starts right after the 8-byte slot holding the key section size. | ||
|   val keySectionOffset = bufferOffset + 8 | ||
|   def sectionSize(elementSize: Int): String = | ||
|     s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numElements, $elementSize)" | ||
|   val keySectionBytes = sectionSize(keyType.defaultSize) | ||
|   val valueSectionBytes = sectionSize(valueType.defaultSize) | ||
|   val keySetter = "set" + CodeGenerator.primitiveTypeName(keyType) | ||
|   val valueSetter = "set" + CodeGenerator.primitiveTypeName(valueType) | ||
|   val keyAssignment = (key: String, idx: String) => s"$keyArray.$keySetter($idx, $key);" | ||
|   val valueAssignment = (entry: String, idx: String) => { | ||
|     val value = CodeGenerator.getValue(entry, valueType, "1") | ||
|     val storeValue = s"$valueArray.$valueSetter($idx, $value);" | ||
|     if (!valueContainsNull) { | ||
|       storeValue | ||
|     } else { | ||
|       s""" | ||
|         |if ($entry.isNullAt(1)) { | ||
|         | $valueArray.setNullAt($idx); | ||
|         |} else { | ||
|         | $storeValue | ||
|         |} | ||
|       """.stripMargin | ||
|     } | ||
|   } | ||
|   val assignmentLoop = genCodeForAssignmentLoop( | ||
|     ctx, childVariable, numElements, keySet, keyAssignment, valueAssignment) | ||
|   // Generated after the loop above so that fresh names are handed out in the same order. | ||
|   val oversizedFallback = genCodeForAnyElements(ctx, childVariable, mapData, keySet, numElements) | ||
|   s""" | ||
|     |final long $keyBytes = $keySectionBytes; | ||
|     |final long $valueBytes = $valueSectionBytes; | ||
|     |final long $totalSize = 8 + $keyBytes + $valueBytes; | ||
|     |if ($totalSize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | ||
|     | $oversizedFallback | ||
|     |} else { | ||
|     | final byte[] $buffer = new byte[(int)$totalSize]; | ||
|     | UnsafeMapData $unsafeMap = new UnsafeMapData(); | ||
|     | Platform.putLong($buffer, $bufferOffset, $keyBytes); | ||
|     | Platform.putLong($buffer, $keySectionOffset, $numElements); | ||
|     | Platform.putLong($buffer, $keySectionOffset + $keyBytes, $numElements); | ||
|     | $unsafeMap.pointTo($buffer, $bufferOffset, (int)$totalSize); | ||
|     | ArrayData $keyArray = $unsafeMap.keyArray(); | ||
|     | ArrayData $valueArray = $unsafeMap.valueArray(); | ||
|     | $assignmentLoop | ||
|     | $mapData = $unsafeMap; | ||
|     |} | ||
|   """.stripMargin | ||
| } | ||
|
|
||
| /** | ||
|  * Emits Java code that collects keys and values into two Object arrays through the shared | ||
|  * assignment loop and builds the result with ArrayBasedMapData.apply. A primitive value that | ||
|  * may be null is stored as null when the struct field is null and is otherwise cast to | ||
|  * Object. | ||
|  * | ||
|  * @param childVariable name of the generated variable holding the input array | ||
|  * @param mapData name of the generated variable receiving the resulting map | ||
|  * @param keySet name of the generated variable holding the hash set of seen keys | ||
|  * @param numElements name of the generated variable holding the element count | ||
|  */ | ||
| private def genCodeForAnyElements( | ||
|     ctx: CodegenContext, | ||
|     childVariable: String, | ||
|     mapData: String, | ||
|     keySet: String, | ||
|     numElements: String): String = { | ||
|   val keyArray = ctx.freshName("keys") | ||
|   val valueArray = ctx.freshName("values") | ||
|   val mapFactory = classOf[ArrayBasedMapData].getName() | ||
|   val needsNullCheck = | ||
|     dataType.valueContainsNull && CodeGenerator.isPrimitiveType(dataType.valueType) | ||
|   val keyAssignment = (key: String, idx: String) => s"$keyArray[$idx] = $key;" | ||
|   val valueAssignment = (entry: String, idx: String) => { | ||
|     val value = CodeGenerator.getValue(entry, dataType.valueType, "1") | ||
|     val stored = if (needsNullCheck) s"$entry.isNullAt(1) ? null : (Object)$value" else s"$value" | ||
|     s"$valueArray[$idx] = $stored;" | ||
|   } | ||
|   val assignmentLoop = genCodeForAssignmentLoop( | ||
|     ctx, childVariable, numElements, keySet, keyAssignment, valueAssignment) | ||
|   s""" | ||
|     |final Object[] $keyArray = new Object[$numElements]; | ||
|     |final Object[] $valueArray = new Object[$numElements]; | ||
|     |$assignmentLoop | ||
|     |$mapData = $mapFactory.apply($keyArray, $valueArray); | ||
|   """.stripMargin | ||
| } | ||
|
|
||
| /** Function name for this expression; also interpolated into the type-check failure message. */ override def prettyName: String = "map_from_entries" | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Common base class for [[SortArray]] and [[ArraySort]]. | ||
| */ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: style