Changes from 1 commit
20 changes: 20 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2304,6 +2304,26 @@ def map_values(col):
return Column(sc._jvm.functions.map_values(_to_java_column(col)))


@since(2.4)
def map_from_entries(col):
"""
Collection function: Returns a map created from the given array of entries.

:param col: name of column or expression

>>> from pyspark.sql.functions import map_from_entries
>>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data")
>>> df.select(map_from_entries("data").alias("map")).show()
+----------------+
| map|
+----------------+
|[1 -> a, 2 -> b]|
+----------------+
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.map_from_entries(_to_java_column(col)))


# ---------------------------- User Defined Function ----------------------------------

class PandasUDFType(object):
@@ -409,6 +409,7 @@ object FunctionRegistry {
expression[ElementAt]("element_at"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapFromEntries]("map_from_entries"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
@@ -22,10 +22,12 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
import org.apache.spark.util.collection.OpenHashSet

/**
* Given an array or map, returns its size. Returns -1 if null.
@@ -118,6 +120,229 @@ case class MapValues(child: Expression)
override def prettyName: String = "map_values"
}

/**
* Returns a map created from the given array of entries.
*/
@ExpressionDescription(
usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.",
examples = """
Examples:
> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));
{1:"a",2:"b"}
""",
since = "2.4.0")
case class MapFromEntries(child: Expression) extends UnaryExpression
{
Contributor: Nit: style

private lazy val resolvedDataType: Option[MapType] = child.dataType match {
Member: @transient?

case ArrayType(
StructType(Array(
StructField(_, keyType, false, _),
Member: We don't need the key field to be nullable = false because we check the nullability when creating an array?

StructField(_, valueType, valueNullable, _))),
false) => Some(MapType(keyType, valueType, valueNullable))
Member: Can we reject an array with containsNull = true here? The array might not contain nulls.

case _ => None
}

override def dataType: MapType = resolvedDataType.get

override def checkInputDataTypes(): TypeCheckResult = resolvedDataType match {
case Some(_) => TypeCheckResult.TypeCheckSuccess
case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " +
s"${child.dataType.simpleString} type. $prettyName accepts only null-free arrays " +
"of pair structs. Values of the first struct field can't contain nulls and produce " +
"duplicates.")
}
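
The pattern match above only resolves when the array itself cannot contain nulls and its element is a two-field struct whose first (key) field is non-nullable; everything else falls through to a TypeCheckFailure. A standalone sketch of the same check, assuming only org.apache.spark.sql.types on the classpath (resolveMapType is an illustrative helper, not part of this PR):

import org.apache.spark.sql.types._

def resolveMapType(dt: DataType): Option[MapType] = dt match {
  case ArrayType(StructType(Array(
      StructField(_, keyType, false, _),
      StructField(_, valueType, valueNullable, _))), false) =>
    Some(MapType(keyType, valueType, valueNullable))
  case _ => None
}

// Accepted: array<struct<k: int not null, v: string>> with containsNull = false
resolveMapType(ArrayType(StructType(Seq(
  StructField("k", IntegerType, nullable = false),
  StructField("v", StringType))), containsNull = false))
// => Some(MapType(IntegerType, StringType, true))

// Rejected: the array itself may contain null entries
resolveMapType(ArrayType(StructType(Seq(
  StructField("k", IntegerType, nullable = false),
  StructField("v", StringType))), containsNull = true))
// => None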

override protected def nullSafeEval(input: Any): Any = {
val arrayData = input.asInstanceOf[ArrayData]
val length = arrayData.numElements()
val keyArray = new Array[AnyRef](length)
val keySet = new OpenHashSet[AnyRef]()
val valueArray = new Array[AnyRef](length)
var i = 0;
while (i < length) {
val entry = arrayData.getStruct(i, 2)
val key = entry.get(0, dataType.keyType)
if (key == null) {
throw new RuntimeException("The first field from a struct (key) can't be null.")
}
if (keySet.contains(key)) {
Member: Is this check necessary for now? Other operations (e.g. CreateMap) allow creating a map with duplicated keys; would it be better to be consistent within Spark?

Contributor (Author): Yeah, we've already touched on this topic in your PR for SPARK-23933. I think if hashing is ever added to maps, these duplicate-key checks will have to be introduced anyway, so adding them now avoids a breaking change later. But I understand your point of view.

Presto also doesn't support duplicates:

presto:default> SELECT map_from_entries(ARRAY[(1, 'x'), (1, 'y')]);
Query 20180510_090536_00005_468a9 failed: Duplicate keys (1) are not allowed

WDYT @ueshin @gatorsmile

Member: Sorry for the long delay. Let's just ignore duplicated keys like CreateMap does for now; we still need to discuss map-related topics such as duplicate keys, equality, and ordering.

Contributor (Author): Ok, no problem. I've removed the duplicate-key checks.

throw new RuntimeException("The first field from a struct (key) can't produce duplicates.")
}
keySet.add(key)
keyArray.update(i, key)
val value = entry.get(1, dataType.valueType)
valueArray.update(i, value)
i += 1
}
ArrayBasedMapData(keyArray, valueArray)
}

private def getHashSetDetails(): (String, String) = dataType.keyType match {
case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
case LongType => ("$mcJ$sp", "Long")
case _ => ("", "Object")
}
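
getHashSetDetails returns the class-name suffix of the @specialized OpenHashSet variant (plus the matching ClassTag factory method) that the generated Java code must reference by its mangled name, since javac knows nothing about Scala specialization. A quick sketch of where the "$mcI$sp" / "$mcJ$sp" names come from, assuming a Scala REPL with Spark on the classpath:

import org.apache.spark.util.collection.OpenHashSet

println(new OpenHashSet[Int]().getClass.getName)    // ...OpenHashSet$mcI$sp (used for byte/short/int keys)
println(new OpenHashSet[Long]().getClass.getName)   // ...OpenHashSet$mcJ$sp (used for long keys)
println(new OpenHashSet[String]().getClass.getName) // ...OpenHashSet (generic "Object" fallback)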

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val numElements = ctx.freshName("numElements")
val keySet = ctx.freshName("keySet")
val hsClass = classOf[OpenHashSet[_]].getName
val tagPrefix = "scala.reflect.ClassTag$.MODULE$."
val (hsSuffix, tagSuffix) = getHashSetDetails()
val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType)
val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
val code = if (isKeyPrimitive && isValuePrimitive) {
genCodeForPrimitiveElements(ctx, c, ev.value, keySet, numElements)
} else {
genCodeForAnyElements(ctx, c, ev.value, keySet, numElements)
}
s"""
|final int $numElements = $c.numElements();
|final $hsClass$hsSuffix $keySet = new $hsClass$hsSuffix($tagPrefix$tagSuffix());
|$code
""".stripMargin
})
}

private def genCodeForAssignmentLoop(
ctx: CodegenContext,
childVariable: String,
numElements: String,
keySet: String,
keyAssignment: (String, String) => String,
valueAssignment: (String, String) => String): String = {
val entry = ctx.freshName("entry")
val key = ctx.freshName("key")
val idx = ctx.freshName("idx")
val keyType = CodeGenerator.javaType(dataType.keyType)

s"""
|for (int $idx = 0; $idx < $numElements; $idx++) {
| InternalRow $entry = $childVariable.getStruct($idx, 2);
| if ($entry.isNullAt(0)) {
| throw new RuntimeException("The first field from a struct (key) can't be null.");
| }
| $keyType $key = ${CodeGenerator.getValue(entry, dataType.keyType, "0")};
| if ($keySet.contains($key)) {
| throw new RuntimeException(
| "The first field from a struct (key) can't produce duplicates.");
| }
| $keySet.add($key);
| ${keyAssignment(key, idx)}
| ${valueAssignment(entry, idx)}
|}
""".stripMargin
}

private def genCodeForPrimitiveElements(
ctx: CodegenContext,
childVariable: String,
mapData: String,
keySet: String,
numElements: String): String = {
val byteArraySize = ctx.freshName("byteArraySize")
val keySectionSize = ctx.freshName("keySectionSize")
val valueSectionSize = ctx.freshName("valueSectionSize")
val data = ctx.freshName("byteArray")
val unsafeMapData = ctx.freshName("unsafeMapData")
val keyArrayData = ctx.freshName("keyArrayData")
val valueArrayData = ctx.freshName("valueArrayData")

val baseOffset = Platform.BYTE_ARRAY_OFFSET
val keySize = dataType.keyType.defaultSize
val valueSize = dataType.valueType.defaultSize
val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numElements, $keySize)"
val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numElements, $valueSize)"
val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType)
val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType)

val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);"
val valueAssignment = (entry: String, idx: String) => {
val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);"
if (dataType.valueContainsNull) {
s"""
|if ($entry.isNullAt(1)) {
| $valueArrayData.setNullAt($idx);
|} else {
| $valueNullUnsafeAssignment
|}
""".stripMargin
} else {
valueNullUnsafeAssignment
}
}
val assignmentLoop = genCodeForAssignmentLoop(
ctx,
childVariable,
numElements,
keySet,
keyAssignment,
valueAssignment
)

s"""
|final long $keySectionSize = $kByteSize;
|final long $valueSectionSize = $vByteSize;
|final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize;
|if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| ${genCodeForAnyElements(ctx, childVariable, mapData, keySet, numElements)}
|} else {
| final byte[] $data = new byte[(int)$byteArraySize];
| UnsafeMapData $unsafeMapData = new UnsafeMapData();
| Platform.putLong($data, $baseOffset, $keySectionSize);
| Platform.putLong($data, ${baseOffset + 8}, $numElements);
| Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numElements);
| $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize);
| ArrayData $keyArrayData = $unsafeMapData.keyArray();
| ArrayData $valueArrayData = $unsafeMapData.valueArray();
| $assignmentLoop
| $mapData = $unsafeMapData;
|}
""".stripMargin
}
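
genCodeForPrimitiveElements lays the whole map out in a single byte array: an 8-byte long holding the key section size, followed by the key UnsafeArrayData and then the value UnsafeArrayData, each of which begins with its own 8-byte element count (hence the putLong calls at offsets 8 and 8 + keySectionSize); if the total would exceed MAX_ROUNDED_ARRAY_LENGTH it falls back to genCodeForAnyElements. A small sketch of the same size arithmetic for an int -> long map with three entries, assuming Spark's catalyst classes on the classpath (exact byte counts depend on UnsafeArrayData's header and word rounding):

import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData

val numElements = 3
val keySectionSize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(numElements, 4)   // IntegerType.defaultSize
val valueSectionSize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(numElements, 8) // LongType.defaultSize
val totalSize = 8L + keySectionSize + valueSectionSize
println(s"keys: $keySectionSize bytes, values: $valueSectionSize bytes, map: $totalSize bytes")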

private def genCodeForAnyElements(
ctx: CodegenContext,
childVariable: String,
mapData: String,
keySet: String,
numElements: String): String = {
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val mapDataClass = classOf[ArrayBasedMapData].getName()

val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
val valueAssignment = (entry: String, idx: String) => {
val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
if (dataType.valueContainsNull && isValuePrimitive) {
s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;"
} else {
s"$values[$idx] = $value;"
}
}
val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;"
val assignmentLoop = genCodeForAssignmentLoop(
ctx,
childVariable,
numElements,
keySet,
keyAssignment,
valueAssignment)

s"""
|final Object[] $keys = new Object[$numElements];
|final Object[] $values = new Object[$numElements];
|$assignmentLoop
|$mapData = $mapDataClass.apply($keys, $values);
""".stripMargin
}

override def prettyName: String = "map_from_entries"
}
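
With the checks as they stand in this commit (the duplicate-key check is slated for removal per the review thread above), a null key or a repeated key fails at runtime rather than silently producing a map. An illustrative spark-shell interaction; output is abbreviated, not verbatim:

spark.sql("SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b')))").show(false)
// [1 -> a, 2 -> b]

spark.sql("SELECT map_from_entries(array(struct(1, 'a'), struct(1, 'b')))").show(false)
// java.lang.RuntimeException: The first field from a struct (key) can't produce duplicates.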


/**
* Common base class for [[SortArray]] and [[ArraySort]].
*/
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,6 +57,63 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapValues(m2), null)
}

test("MapFromEntries") {
def arrayType(keyType: DataType, valueType: DataType) : DataType = {
ArrayType(StructType(Seq(
StructField("a", keyType, false),
StructField("b", valueType))),
false)
}
def r(values: Any*): InternalRow = create_row(values: _*)

// Primitive-type keys and values
val aiType = arrayType(IntegerType, IntegerType)
val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType)
val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType)
val ai2 = Literal.create(Seq.empty, aiType)
val ai3 = Literal.create(null, aiType)
val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType)
val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType)
val aby = Literal.create(Seq(r(1.toByte, 10.toByte)), arrayType(ByteType, ByteType))
val ash = Literal.create(Seq(r(1.toShort, 10.toShort)), arrayType(ShortType, ShortType))
val alo = Literal.create(Seq(r(1L, 10L)), arrayType(LongType, LongType))

checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20))
checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null))
checkEvaluation(MapFromEntries(ai2), Map.empty)
checkEvaluation(MapFromEntries(ai3), null)
checkExceptionInExpression[RuntimeException](
MapFromEntries(ai4),
"The first field from a struct (key) can't produce duplicates.")
checkExceptionInExpression[RuntimeException](
MapFromEntries(ai5),
"The first field from a struct (key) can't be null.")
checkEvaluation(MapFromEntries(aby), Map(1.toByte -> 10.toByte))
checkEvaluation(MapFromEntries(ash), Map(1.toShort -> 10.toShort))
checkEvaluation(MapFromEntries(alo), Map(1L -> 10L))

// Non-primitive-type keys and values
val asType = arrayType(StringType, StringType)
val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType)
val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType)
val as2 = Literal.create(Seq.empty, asType)
val as3 = Literal.create(null, asType)
val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType)
val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType)

checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb"))
checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null))
checkEvaluation(MapFromEntries(as2), Map.empty)
checkEvaluation(MapFromEntries(as3), null)
checkExceptionInExpression[RuntimeException](
MapFromEntries(as4),
"The first field from a struct (key) can't produce duplicates.")
checkExceptionInExpression[RuntimeException](
MapFromEntries(as5),
"The first field from a struct (key) can't be null.")

}

test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
7 changes: 7 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3414,6 +3414,13 @@ object functions {
*/
def map_values(e: Column): Column = withExpr { MapValues(e.expr) }

/**
* Returns a map created from the given array of entries.
* @group collection_funcs
* @since 2.4.0
*/
def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) }
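
A short DataFrame-level usage sketch of the new function (column and alias names here are illustrative, not from the PR):

import org.apache.spark.sql.functions._

val df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) AS entries")
df.select(map_from_entries(col("entries")).as("m")).show(false)
// +----------------+
// |m               |
// +----------------+
// |[1 -> a, 2 -> b]|
// +----------------+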

// scalastyle:off line.size.limit
// scalastyle:off parameter.number
