Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2380,6 +2380,26 @@ def map_entries(col):
return Column(sc._jvm.functions.map_entries(_to_java_column(col)))


@since(2.4)
def map_from_entries(col):
    """
    Collection function: Builds a map out of an array column whose elements are
    key/value entries.

    :param col: name of column or expression

    >>> from pyspark.sql.functions import map_from_entries
    >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data")
    >>> df.select(map_from_entries("data").alias("map")).show()
    +----------------+
    |             map|
    +----------------+
    |[1 -> a, 2 -> b]|
    +----------------+
    """
    # Resolve the JVM-side functions object once, then delegate to its map_from_entries.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    return Column(jvm_functions.map_from_entries(_to_java_column(col)))


@ignore_unicode_prefix
@since(2.4)
def array_repeat(col, count):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ object FunctionRegistry {
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
expression[MapFromEntries]("map_from_entries"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
import org.apache.spark.util.collection.OpenHashSet

/**
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
Expand Down Expand Up @@ -308,6 +309,234 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
override def prettyName: String = "map_entries"
}

/**
* Returns a map created from the given array of entries.
*/
@ExpressionDescription(
usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.",
examples = """
Examples:
> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));
{1:"a",2:"b"}
""",
since = "2.4.0")
case class MapFromEntries(child: Expression) extends UnaryExpression {

/**
 * Shape of the input, derived from `child.dataType`.
 *
 * Defined only when the child is an array of structs with exactly two fields. The tuple holds,
 * in order: the resulting map type, whether the first (key) field is nullable, and whether the
 * array itself may contain null entries.
 */
@transient
private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match {
  case ArrayType(StructType(fields), entriesNullable) if fields.length == 2 =>
    val keyField = fields(0)
    val valueField = fields(1)
    val mapType = MapType(keyField.dataType, valueField.dataType, valueField.nullable)
    Some((mapType, keyField.nullable, entriesNullable))
  case _ => None
}

// Whether the input array may contain null entries (third component of dataTypeDetails).
// Uses Option.get, so this throws if the child type did not match the expected shape.
private def nullEntries: Boolean = dataTypeDetails.get._3

// Resulting map type (first component of dataTypeDetails); also relies on Option.get.
override def dataType: MapType = dataTypeDetails.get._1

/**
 * Accepts the child only when `dataTypeDetails` is defined, i.e. when the child is an array of
 * two-field structs. Otherwise reports a failure naming the offending expression and its type.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  if (dataTypeDetails.isDefined) {
    TypeCheckResult.TypeCheckSuccess
  } else {
    TypeCheckResult.TypeCheckFailure(
      s"'${child.sql}' is of ${child.dataType.simpleString} type. " +
        s"$prettyName accepts only arrays of pair structs.")
  }
}

override protected def nullSafeEval(input: Any): Any = {
val arrayData = input.asInstanceOf[ArrayData]
val length = arrayData.numElements()
val numEntries = if (nullEntries) (0 until length).count(!arrayData.isNullAt(_)) else length
val keyArray = new Array[AnyRef](numEntries)
val valueArray = new Array[AnyRef](numEntries)
var i = 0
var j = 0
while (i < length) {
if (!arrayData.isNullAt(i)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should throw an exception if arrayData.isNullAt(i)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ueshin,
wouldn't it be better to return null in this case, and follow the null handling of other functions like flatten?

flatten(array(array(1,2), null, array(3,4))) => null

WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that sounds reasonable. Thanks.

val entry = arrayData.getStruct(i, 2)
val key = entry.get(0, dataType.keyType)
if (key == null) {
throw new RuntimeException("The first field from a struct (key) can't be null.")
}
keyArray.update(j, key)
val value = entry.get(1, dataType.valueType)
valueArray.update(j, value)
j += 1
}
i += 1
}
ArrayBasedMapData(keyArray, valueArray)
}

/**
 * Generates the Java code that builds the map from the input array.
 *
 * Inside the lambda, `c` is the name of the generated variable holding the child array; it is
 * used below as `$c.numElements()` and `$c.isNullAt(...)`. The emitted code first reads the
 * array length, then computes the number of entries, then runs the map-building code.
 */
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
// Names of the generated Java variables for the array length and the entry count.
val length = ctx.freshName("length")
val numEntries = ctx.freshName("numEntries")
val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType)
val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
// The primitive-specialized generator is used only when BOTH key and value are primitive.
val code = if (isKeyPrimitive && isValuePrimitive) {
genCodeForPrimitiveElements(ctx, c, ev.value, length, numEntries)
} else {
genCodeForAnyElements(ctx, c, ev.value, length, numEntries)
}
// When the array may hold null entries, count the non-null ones with a generated loop;
// otherwise the entry count is simply the array length.
val numEntriesAssignment = if (nullEntries) {
val idx = ctx.freshName("idx")
s"""
|int $numEntries = 0;
|for (int $idx = 0; $idx < $length; $idx++) {
| if (!$c.isNullAt($idx)) $numEntries++;
|}
""".stripMargin
} else {
s"final int $numEntries = $length;"
}

// Order matters in the emitted code: $code refers to both $length and $numEntries.
s"""
|final int $length = $c.numElements();
|$numEntriesAssignment
|$code
""".stripMargin
})
}

/**
 * Generates the Java `for` loop that walks the input array and stores each entry's key and
 * value through the supplied code-producing callbacks.
 *
 * @param ctx codegen context, used here only to obtain fresh variable names
 * @param childVariable name of the generated variable holding the input array
 * @param length name of the generated variable holding the input array length
 * @param keyAssignment given (Java expression reading the key from field 0 of the entry,
 *                      name of the output index variable), returns the key-store statement
 * @param valueAssignment given (name of the entry variable, name of the output index
 *                        variable), returns the value-store statement
 * @return the Java source of the loop
 */
private def genCodeForAssignmentLoop(
ctx: CodegenContext,
childVariable: String,
length: String,
keyAssignment: (String, String) => String,
valueAssignment: (String, String) => String): String = {
val entry = ctx.freshName("entry")
// i indexes the input array; j indexes the output and advances only for entries that are
// not skipped by the null-entry `continue` below.
val i = ctx.freshName("idx")
val j = ctx.freshName("idx")

// Null entries are skipped, and the check is emitted only when the array may contain nulls.
val nullEntryCheck = if (nullEntries) s"if ($childVariable.isNullAt($i)) continue;" else ""
// The null-key check is emitted only when the key field is declared nullable
// (second component of dataTypeDetails).
val nullKeyCheck = if (dataTypeDetails.get._2) {
s"""
|if ($entry.isNullAt(0)) {
| throw new RuntimeException("The first field from a struct (key) can't be null.");
|}
""".stripMargin
} else {
""
}

s"""
|for (int $i = 0, $j = 0; $i < $length; $i++) {
| $nullEntryCheck
| InternalRow $entry = $childVariable.getStruct($i, 2);
| $nullKeyCheck
| ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), j)}
| ${valueAssignment(entry, j)}
| $j++;
|}
""".stripMargin
}

/**
 * Generates Java code that writes keys and values into a single byte array and exposes it
 * through an `UnsafeMapData`. Called from `doGenCode` when both key and value types are
 * primitive.
 *
 * The emitted code falls back, at runtime, to the code from `genCodeForAnyElements` when the
 * computed byte array size exceeds `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`.
 *
 * @param ctx codegen context, used to obtain fresh variable names
 * @param childVariable name of the generated variable holding the input array
 * @param mapData name of the generated variable that receives the resulting map
 * @param length name of the generated variable holding the input array length
 * @param numEntries name of the generated variable holding the number of map entries
 * @return the Java source building the map
 */
private def genCodeForPrimitiveElements(
ctx: CodegenContext,
childVariable: String,
mapData: String,
length: String,
numEntries: String): String = {
val byteArraySize = ctx.freshName("byteArraySize")
val keySectionSize = ctx.freshName("keySectionSize")
val valueSectionSize = ctx.freshName("valueSectionSize")
val data = ctx.freshName("byteArray")
val unsafeMapData = ctx.freshName("unsafeMapData")
val keyArrayData = ctx.freshName("keyArrayData")
val valueArrayData = ctx.freshName("valueArrayData")

val baseOffset = Platform.BYTE_ARRAY_OFFSET
// Per-element widths come from the types' defaultSize.
val keySize = dataType.keyType.defaultSize
val valueSize = dataType.valueType.defaultSize
// Java expressions (evaluated in the generated code) for the byte size of each section.
val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)"
val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)"
val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType)
val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType)

// Keys are stored with the typed setter; no null branch is generated for keys here.
val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);"
// Values get an explicit null branch only when the map type allows null values.
val valueAssignment = (entry: String, idx: String) => {
val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);"
if (dataType.valueContainsNull) {
s"""
|if ($entry.isNullAt(1)) {
| $valueArrayData.setNullAt($idx);
|} else {
| $valueNullUnsafeAssignment
|}
""".stripMargin
} else {
valueNullUnsafeAssignment
}
}
val assignmentLoop = genCodeForAssignmentLoop(
ctx,
childVariable,
length,
keyAssignment,
valueAssignment
)

// Emitted layout: an 8-byte header holding the key section size, then the key section, then
// the value section; the entry count is written at the start of each section.
// NOTE(review): presumably this is the layout UnsafeMapData.pointTo expects — confirm.
// The fallback code is always generated and embedded; which branch runs is decided at runtime.
s"""
|final long $keySectionSize = $kByteSize;
|final long $valueSectionSize = $vByteSize;
|final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize;
|if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| ${genCodeForAnyElements(ctx, childVariable, mapData, length, numEntries)}
|} else {
| final byte[] $data = new byte[(int)$byteArraySize];
| UnsafeMapData $unsafeMapData = new UnsafeMapData();
| Platform.putLong($data, $baseOffset, $keySectionSize);
| Platform.putLong($data, ${baseOffset + 8}, $numEntries);
| Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries);
| $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize);
| ArrayData $keyArrayData = $unsafeMapData.keyArray();
| ArrayData $valueArrayData = $unsafeMapData.valueArray();
| $assignmentLoop
| $mapData = $unsafeMapData;
|}
""".stripMargin
}

/**
 * Generates Java code that collects keys and values into two `Object[]` arrays and builds the
 * result with `ArrayBasedMapData.apply`. Used for non-primitive element types and as the
 * runtime fallback embedded by `genCodeForPrimitiveElements`.
 *
 * @param ctx codegen context, used to obtain fresh variable names
 * @param childVariable name of the generated variable holding the input array
 * @param mapData name of the generated variable that receives the resulting map
 * @param length name of the generated variable holding the input array length
 * @param numEntries name of the generated variable holding the number of map entries
 * @return the Java source building the map
 */
private def genCodeForAnyElements(
ctx: CodegenContext,
childVariable: String,
mapData: String,
length: String,
numEntries: String): String = {
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val mapDataClass = classOf[ArrayBasedMapData].getName()

val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
// A nullable primitive value needs an explicit isNullAt check and a cast to Object so that
// null can be stored in the Object[]; every other case assigns the read value directly.
val valueAssignment = (entry: String, idx: String) => {
val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
if (dataType.valueContainsNull && isValuePrimitive) {
s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;"
} else {
s"$values[$idx] = $value;"
}
}
val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;"
val assignmentLoop = genCodeForAssignmentLoop(
ctx,
childVariable,
length,
keyAssignment,
valueAssignment)

// Both arrays are sized by the entry count, not by the input array length.
s"""
|final Object[] $keys = new Object[$numEntries];
|final Object[] $values = new Object[$numEntries];
|$assignmentLoop
|$mapData = $mapDataClass.apply($keys, $values);
""".stripMargin
}

// Function name of this expression; interpolated into its type check failure message.
override def prettyName: String = "map_from_entries"
}


/**
* Common base class for [[SortArray]] and [[ArraySort]].
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,63 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapEntries(ms2), null)
}

test("MapFromEntries") {
  // Nullable array of two-field structs: the input shape MapFromEntries accepts.
  def entriesOf(keyType: DataType, valueType: DataType): DataType = {
    val entryStruct = StructType(Seq(
      StructField("a", keyType),
      StructField("b", valueType)))
    ArrayType(entryStruct, containsNull = true)
  }
  def entry(fields: Any*): InternalRow = create_row(fields: _*)

  // Primitive-type keys and values
  val intEntries = entriesOf(IntegerType, IntegerType)

  val intPlain = Literal.create(Seq(entry(1, 10), entry(2, 20), entry(3, 20)), intEntries)
  checkEvaluation(MapFromEntries(intPlain), Map(1 -> 10, 2 -> 20, 3 -> 20))

  val intNullValues =
    Literal.create(Seq(entry(1, null), entry(2, 20), entry(3, null)), intEntries)
  checkEvaluation(MapFromEntries(intNullValues), Map(1 -> null, 2 -> 20, 3 -> null))

  val intEmpty = Literal.create(Seq.empty, intEntries)
  checkEvaluation(MapFromEntries(intEmpty), Map.empty)

  val intNullArray = Literal.create(null, intEntries)
  checkEvaluation(MapFromEntries(intNullArray), null)

  // Duplicate keys are both kept in the key array.
  val intDuplicateKeys = Literal.create(Seq(entry(1, 10), entry(1, 20)), intEntries)
  checkEvaluation(MapKeys(MapFromEntries(intDuplicateKeys)), Seq(1, 1))

  // A null key is rejected at evaluation time.
  val intNullKey = Literal.create(Seq(entry(1, 10), entry(null, 20)), intEntries)
  checkExceptionInExpression[RuntimeException](
    MapFromEntries(intNullKey),
    "The first field from a struct (key) can't be null.")

  // Null entries are dropped from the result.
  val intNullEntries = Literal.create(Seq(null, entry(2, 20), null), intEntries)
  checkEvaluation(MapFromEntries(intNullEntries), Map(2 -> 20))

  val byteEntries =
    Literal.create(Seq(entry(1.toByte, 10.toByte)), entriesOf(ByteType, ByteType))
  checkEvaluation(MapFromEntries(byteEntries), Map(1.toByte -> 10.toByte))

  val shortEntries =
    Literal.create(Seq(entry(1.toShort, 10.toShort)), entriesOf(ShortType, ShortType))
  checkEvaluation(MapFromEntries(shortEntries), Map(1.toShort -> 10.toShort))

  val longEntries = Literal.create(Seq(entry(1L, 10L)), entriesOf(LongType, LongType))
  checkEvaluation(MapFromEntries(longEntries), Map(1L -> 10L))

  // Non-primitive-type keys and values
  val stringEntries = entriesOf(StringType, StringType)

  val strPlain = Literal.create(
    Seq(entry("a", "aa"), entry("b", "bb"), entry("c", "bb")), stringEntries)
  checkEvaluation(MapFromEntries(strPlain), Map("a" -> "aa", "b" -> "bb", "c" -> "bb"))

  val strNullValues = Literal.create(
    Seq(entry("a", null), entry("b", "bb"), entry("c", null)), stringEntries)
  checkEvaluation(MapFromEntries(strNullValues), Map("a" -> null, "b" -> "bb", "c" -> null))

  val strEmpty = Literal.create(Seq.empty, stringEntries)
  checkEvaluation(MapFromEntries(strEmpty), Map.empty)

  val strNullArray = Literal.create(null, stringEntries)
  checkEvaluation(MapFromEntries(strNullArray), null)

  val strDuplicateKeys = Literal.create(Seq(entry("a", "aa"), entry("a", "bb")), stringEntries)
  checkEvaluation(MapKeys(MapFromEntries(strDuplicateKeys)), Seq("a", "a"))

  val strNullKey = Literal.create(Seq(entry("a", "aa"), entry(null, "bb")), stringEntries)
  checkExceptionInExpression[RuntimeException](
    MapFromEntries(strNullKey),
    "The first field from a struct (key) can't be null.")

  val strNullEntries = Literal.create(Seq(null, entry("b", "bb"), null), stringEntries)
  checkEvaluation(MapFromEntries(strNullEntries), Map("b" -> "bb"))
}

test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
Expand Down
7 changes: 7 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3508,6 +3508,13 @@ object functions {
*/
def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }

/**
 * Builds a map column from a column holding an array of entries.
 * @group collection_funcs
 * @since 2.4.0
 */
def map_from_entries(e: Column): Column = {
  val entries = e.expr
  withExpr(MapFromEntries(entries))
}

//////////////////////////////////////////////////////////////////////////////////////////////
// Mask functions
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Loading