Skip to content

Commit 8c6039c

Browse files
committed
[SPARK-23934][SQL] Adding map_from_entries function
1 parent e35ad3c commit 8c6039c

File tree

6 files changed

+364
-1
lines changed

6 files changed

+364
-1
lines changed

python/pyspark/sql/functions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2304,6 +2304,26 @@ def map_values(col):
23042304
return Column(sc._jvm.functions.map_values(_to_java_column(col)))
23052305

23062306

2307+
@since(2.4)
def map_from_entries(col):
    """
    Collection function: Returns a map created from the given array of entries.

    Each entry is a two-field struct: the first field supplies the key and the
    second field supplies the value of the resulting map.

    :param col: name of column or expression

    >>> from pyspark.sql.functions import map_from_entries
    >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data")
    >>> df.select(map_from_entries("data").alias("map")).show()
    +----------------+
    |             map|
    +----------------+
    |[1 -> a, 2 -> b]|
    +----------------+
    """
    # Delegate to the JVM-side `functions.map_from_entries` through the active context.
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.map_from_entries(_to_java_column(col)))
2325+
2326+
23072327
# ---------------------------- User Defined Function ----------------------------------
23082328

23092329
class PandasUDFType(object):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ object FunctionRegistry {
409409
expression[ElementAt]("element_at"),
410410
expression[MapKeys]("map_keys"),
411411
expression[MapValues]("map_values"),
412+
expression[MapFromEntries]("map_from_entries"),
412413
expression[Size]("size"),
413414
expression[Slice]("slice"),
414415
expression[Size]("cardinality"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 226 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2323
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
2424
import org.apache.spark.sql.catalyst.expressions.codegen._
25-
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
25+
import org.apache.spark.sql.catalyst.util._
2626
import org.apache.spark.sql.types._
27+
import org.apache.spark.unsafe.Platform
2728
import org.apache.spark.unsafe.array.ByteArrayMethods
2829
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
30+
import org.apache.spark.util.collection.OpenHashSet
2931

3032
/**
3133
* Given an array or map, returns its size. Returns -1 if null.
@@ -118,6 +120,229 @@ case class MapValues(child: Expression)
118120
override def prettyName: String = "map_values"
119121
}
120122

123+
/**
 * Returns a map created from the given array of entries.
 *
 * Every entry is a struct with exactly two fields: the first field becomes the map key and the
 * second one the map value. The expression type-checks only when the child is an array declared
 * without null elements whose struct has a non-nullable first field (see `resolvedDataType`).
 * At evaluation time a `RuntimeException` is thrown if a key is null or if the same key occurs
 * more than once.
 */
@ExpressionDescription(
  usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));
       {1:"a",2:"b"}
  """,
  since = "2.4.0")
case class MapFromEntries(child: Expression) extends UnaryExpression
{
  /**
   * The resulting map type, or `None` when the child's type is not supported.
   *
   * Supported shape: `ArrayType` with `containsNull = false` over a two-field `StructType`
   * whose first field is non-nullable. Nullability of the map values follows the nullability
   * of the second struct field.
   */
  private lazy val resolvedDataType: Option[MapType] = child.dataType match {
    case ArrayType(
      StructType(Array(
        StructField(_, keyType, false, _),
        StructField(_, valueType, valueNullable, _))),
      false) => Some(MapType(keyType, valueType, valueNullable))
    case _ => None
  }

  // Relies on `resolvedDataType` being defined, i.e. on `checkInputDataTypes` having succeeded.
  override def dataType: MapType = resolvedDataType.get

  /** Succeeds exactly when the child's type matches the shape described on `resolvedDataType`. */
  override def checkInputDataTypes(): TypeCheckResult = resolvedDataType match {
    case Some(_) => TypeCheckResult.TypeCheckSuccess
    case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " +
      s"${child.dataType.simpleString} type. $prettyName accepts only null-free arrays " +
      "of pair structs. Values of the first struct field can't contain nulls and produce " +
      "duplicates.")
  }

  /**
   * Interpreted evaluation: copies keys and values of all entries into two parallel arrays,
   * rejecting null keys and keys that were already seen.
   *
   * @param input the non-null child value, an `ArrayData` of two-field structs
   * @return an `ArrayBasedMapData` holding the entries in their original order
   * @throws RuntimeException if a key is null or occurs more than once
   */
  override protected def nullSafeEval(input: Any): Any = {
    val arrayData = input.asInstanceOf[ArrayData]
    val length = arrayData.numElements()
    // Loop-invariant: resolve the element types once instead of once per entry.
    val keyType = dataType.keyType
    val valueType = dataType.valueType
    val keyArray = new Array[AnyRef](length)
    val keySet = new OpenHashSet[AnyRef]()
    val valueArray = new Array[AnyRef](length)
    var i = 0
    while (i < length) {
      val entry = arrayData.getStruct(i, 2)
      val key = entry.get(0, keyType)
      // Checked at runtime even though the accepted type declares the key field non-nullable.
      if (key == null) {
        throw new RuntimeException("The first field from a struct (key) can't be null.")
      }
      if (keySet.contains(key)) {
        throw new RuntimeException("The first field from a struct (key) can't produce duplicates.")
      }
      keySet.add(key)
      keyArray.update(i, key)
      val value = entry.get(1, valueType)
      valueArray.update(i, value)
      i += 1
    }
    ArrayBasedMapData(keyArray, valueArray)
  }

  /**
   * Picks the hash set flavour used by the generated code for duplicate-key detection.
   *
   * @return a pair of (suffix appended to the `OpenHashSet` class name, name of the
   *         `scala.reflect.ClassTag` factory method passed to its constructor); byte, short and
   *         int keys share the `Int` flavour, long keys use `Long`, everything else `Object`
   */
  private def getHashSetDetails(): (String, String) = dataType.keyType match {
    case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
    case LongType => ("$mcJ$sp", "Long")
    case _ => ("", "Object")
  }

  /**
   * Generates Java code that declares the element count and the key set, followed by either the
   * primitive-only path (both key and value types are primitive) or the generic path.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, c => {
      val numElements = ctx.freshName("numElements")
      val keySet = ctx.freshName("keySet")
      val hsClass = classOf[OpenHashSet[_]].getName
      val tagPrefix = "scala.reflect.ClassTag$.MODULE$."
      val (hsSuffix, tagSuffix) = getHashSetDetails()
      val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType)
      val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
      val code = if (isKeyPrimitive && isValuePrimitive) {
        genCodeForPrimitiveElements(ctx, c, ev.value, keySet, numElements)
      } else {
        genCodeForAnyElements(ctx, c, ev.value, keySet, numElements)
      }
      s"""
         |final int $numElements = $c.numElements();
         |final $hsClass$hsSuffix $keySet = new $hsClass$hsSuffix($tagPrefix$tagSuffix());
         |$code
       """.stripMargin
    })
  }

  /**
   * Generates the loop shared by both code paths: for every entry it rejects a null or duplicate
   * key, records the key in the key set and then emits the caller-supplied assignments.
   *
   * @param childVariable name of the Java variable holding the input array
   * @param numElements name of the Java variable holding the number of entries
   * @param keySet name of the Java variable holding the set of keys seen so far
   * @param keyAssignment given (key variable, index variable), returns the code storing the key
   * @param valueAssignment given (entry variable, index variable), returns the code storing
   *                        the value
   */
  private def genCodeForAssignmentLoop(
      ctx: CodegenContext,
      childVariable: String,
      numElements: String,
      keySet: String,
      keyAssignment: (String, String) => String,
      valueAssignment: (String, String) => String): String = {
    val entry = ctx.freshName("entry")
    val key = ctx.freshName("key")
    val idx = ctx.freshName("idx")
    val keyType = CodeGenerator.javaType(dataType.keyType)

    s"""
       |for (int $idx = 0; $idx < $numElements; $idx++) {
       |  InternalRow $entry = $childVariable.getStruct($idx, 2);
       |  if ($entry.isNullAt(0)) {
       |    throw new RuntimeException("The first field from a struct (key) can't be null.");
       |  }
       |  $keyType $key = ${CodeGenerator.getValue(entry, dataType.keyType, "0")};
       |  if ($keySet.contains($key)) {
       |    throw new RuntimeException(
       |      "The first field from a struct (key) can't produce duplicates.");
       |  }
       |  $keySet.add($key);
       |  ${keyAssignment(key, idx)}
       |  ${valueAssignment(entry, idx)}
       |}
     """.stripMargin
  }

  /**
   * Generates code for the case where keys and values are both primitive: the result is an
   * `UnsafeMapData` pointing at a single byte array that the generated code fills in place.
   *
   * The byte array is sized as 8 bytes plus the key section plus the value section. The key
   * section size is written at offset 0, and the element count is written at the start of the
   * key section (offset 8) and at the start of the value section (offset 8 + key section size).
   * When the total size exceeds `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`, the generated code
   * falls back to the generic path at runtime.
   */
  private def genCodeForPrimitiveElements(
      ctx: CodegenContext,
      childVariable: String,
      mapData: String,
      keySet: String,
      numElements: String): String = {
    val byteArraySize = ctx.freshName("byteArraySize")
    val keySectionSize = ctx.freshName("keySectionSize")
    val valueSectionSize = ctx.freshName("valueSectionSize")
    val data = ctx.freshName("byteArray")
    val unsafeMapData = ctx.freshName("unsafeMapData")
    val keyArrayData = ctx.freshName("keyArrayData")
    val valueArrayData = ctx.freshName("valueArrayData")

    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    val keySize = dataType.keyType.defaultSize
    val valueSize = dataType.valueType.defaultSize
    val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numElements, $keySize)"
    val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numElements, $valueSize)"
    val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType)
    val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType)

    val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);"
    val valueAssignment = (entry: String, idx: String) => {
      val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
      val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);"
      // The null branch is only emitted when the value field may actually hold nulls.
      if (dataType.valueContainsNull) {
        s"""
           |if ($entry.isNullAt(1)) {
           |  $valueArrayData.setNullAt($idx);
           |} else {
           |  $valueNullUnsafeAssignment
           |}
         """.stripMargin
      } else {
        valueNullUnsafeAssignment
      }
    }
    val assignmentLoop = genCodeForAssignmentLoop(
      ctx,
      childVariable,
      numElements,
      keySet,
      keyAssignment,
      valueAssignment
    )

    s"""
       |final long $keySectionSize = $kByteSize;
       |final long $valueSectionSize = $vByteSize;
       |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize;
       |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
       |  ${genCodeForAnyElements(ctx, childVariable, mapData, keySet, numElements)}
       |} else {
       |  final byte[] $data = new byte[(int)$byteArraySize];
       |  UnsafeMapData $unsafeMapData = new UnsafeMapData();
       |  Platform.putLong($data, $baseOffset, $keySectionSize);
       |  Platform.putLong($data, ${baseOffset + 8}, $numElements);
       |  Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numElements);
       |  $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize);
       |  ArrayData $keyArrayData = $unsafeMapData.keyArray();
       |  ArrayData $valueArrayData = $unsafeMapData.valueArray();
       |  $assignmentLoop
       |  $mapData = $unsafeMapData;
       |}
     """.stripMargin
  }

  /**
   * Generates code for the generic case: keys and values are collected into two `Object[]`
   * arrays which are then wrapped by `ArrayBasedMapData.apply`.
   */
  private def genCodeForAnyElements(
      ctx: CodegenContext,
      childVariable: String,
      mapData: String,
      keySet: String,
      numElements: String): String = {
    val keys = ctx.freshName("keys")
    val values = ctx.freshName("values")
    val mapDataClass = classOf[ArrayBasedMapData].getName()

    val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
    val valueAssignment = (entry: String, idx: String) => {
      val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
      // A nullable primitive value needs an explicit null check before it is stored as Object.
      if (dataType.valueContainsNull && isValuePrimitive) {
        s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;"
      } else {
        s"$values[$idx] = $value;"
      }
    }
    val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;"
    val assignmentLoop = genCodeForAssignmentLoop(
      ctx,
      childVariable,
      numElements,
      keySet,
      keyAssignment,
      valueAssignment)

    s"""
       |final Object[] $keys = new Object[$numElements];
       |final Object[] $values = new Object[$numElements];
       |$assignmentLoop
       |$mapData = $mapDataClass.apply($keys, $values);
     """.stripMargin
  }

  override def prettyName: String = "map_from_entries"
}
344+
345+
121346
/**
122347
* Common base class for [[SortArray]] and [[ArraySort]].
123348
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.InternalRow
2122
import org.apache.spark.sql.types._
2223

2324
class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,6 +57,63 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
5657
checkEvaluation(MapValues(m2), null)
5758
}
5859

60+
test("MapFromEntries") {
  // Array of (a, b) structs in the only shape MapFromEntries accepts: the array is declared
  // without null elements and the key field "a" is declared non-nullable.
  def arrayType(keyType: DataType, valueType: DataType): DataType = {
    ArrayType(
      StructType(Seq(
        StructField("a", keyType, false),
        StructField("b", valueType))),
      false)
  }
  def r(values: Any*): InternalRow = create_row(values: _*)

  // Primitive-type keys and values
  val aiType = arrayType(IntegerType, IntegerType)
  val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType)
  val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType)
  val ai2 = Literal.create(Seq.empty, aiType)
  val ai3 = Literal.create(null, aiType)
  // Duplicate key
  val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType)
  // Null key
  val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType)
  val aby = Literal.create(Seq(r(1.toByte, 10.toByte)), arrayType(ByteType, ByteType))
  val ash = Literal.create(Seq(r(1.toShort, 10.toShort)), arrayType(ShortType, ShortType))
  val alo = Literal.create(Seq(r(1L, 10L)), arrayType(LongType, LongType))

  checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20))
  checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null))
  checkEvaluation(MapFromEntries(ai2), Map.empty)
  checkEvaluation(MapFromEntries(ai3), null)
  checkExceptionInExpression[RuntimeException](
    MapFromEntries(ai4),
    "The first field from a struct (key) can't produce duplicates.")
  checkExceptionInExpression[RuntimeException](
    MapFromEntries(ai5),
    "The first field from a struct (key) can't be null.")
  checkEvaluation(MapFromEntries(aby), Map(1.toByte -> 10.toByte))
  checkEvaluation(MapFromEntries(ash), Map(1.toShort -> 10.toShort))
  checkEvaluation(MapFromEntries(alo), Map(1L -> 10L))

  // Non-primitive-type keys and values
  val asType = arrayType(StringType, StringType)
  val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType)
  val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType)
  val as2 = Literal.create(Seq.empty, asType)
  val as3 = Literal.create(null, asType)
  // Duplicate key
  val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType)
  // Null key
  val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType)

  checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb"))
  checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null))
  checkEvaluation(MapFromEntries(as2), Map.empty)
  checkEvaluation(MapFromEntries(as3), null)
  checkExceptionInExpression[RuntimeException](
    MapFromEntries(as4),
    "The first field from a struct (key) can't produce duplicates.")
  checkExceptionInExpression[RuntimeException](
    MapFromEntries(as5),
    "The first field from a struct (key) can't be null.")
}
116+
59117
test("Sort Array") {
60118
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
61119
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3414,6 +3414,13 @@ object functions {
34143414
*/
34153415
def map_values(e: Column): Column = withExpr { MapValues(e.expr) }
34163416

3417+
/**
 * Returns a map created from the given array of entries.
 *
 * Each entry is a struct with two fields: the first one is used as the key and the second one
 * as the value. Keys must not be null and must not repeat.
 *
 * @param e a column holding an array of two-field structs
 * @group collection_funcs
 * @since 2.4.0
 */
def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) }
3423+
34173424
// scalastyle:off line.size.limit
34183425
// scalastyle:off parameter.number
34193426

0 commit comments

Comments
 (0)