Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
a081649
initial working version
stefankandic Feb 8, 2024
1441549
add golden files
stefankandic Feb 11, 2024
1be06e3
add map sort to other languages
stefankandic Feb 14, 2024
249e903
fix typoes
stefankandic Feb 28, 2024
aaae883
fix scalastyle issue
stefankandic Feb 28, 2024
acaf95e
add proto golden files
stefankandic Feb 28, 2024
5619fdb
fix python function call
stefankandic Feb 28, 2024
7754c14
fix ci errors
stefankandic Feb 29, 2024
f0ebf5d
fix ci checks
stefankandic Feb 29, 2024
1f78167
Optimized map-sort by switching to array sorting
stevomitric Mar 12, 2024
a5eb480
Potential tests fix
stevomitric Mar 13, 2024
9497f99
Potential tests fix 2
stevomitric Mar 13, 2024
5e7a033
Removed TODOs and changed parmIndex to ordinal
stevomitric Mar 17, 2024
ab70f1e
Shortened map sort function and added more docs
stevomitric Mar 18, 2024
e79d65c
updated map_sort test suite
stevomitric Mar 18, 2024
a435355
Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunction…
stevomitric Mar 18, 2024
c9901d0
Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunction…
stevomitric Mar 18, 2024
da6a710
docs fix
stevomitric Mar 18, 2024
81008c2
Updated codegen and removed once test-case
stevomitric Mar 19, 2024
86b29c5
Update python/pyspark/sql/functions/builtin.py
stevomitric Mar 19, 2024
c08ab6c
Updated 'select.show' to give more info in map_sort desc
stevomitric Mar 19, 2024
31a797c
Restructured docs, removed unused variable and refactored code
stevomitric Mar 19, 2024
69e3b48
Removed map_sort function but left the MapSort expression
stevomitric Mar 21, 2024
51ab204
Merge branch 'master' into stevomitric/map-expr
stevomitric Mar 21, 2024
8d9ac51
aditional erasions
stevomitric Mar 21, 2024
2951bcc
removed ExpressionDescription
stevomitric Mar 21, 2024
0fc3c6a
Moved ordering outside of comapre function
stevomitric Mar 21, 2024
0c7d21a
Removed oredering type
stevomitric Mar 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,137 @@ case class MapFromEntries(child: Expression)
copy(child = newChild)
}

/**
 * Sorts the entries of a map by key, using the ordering of the map's key type, and returns a
 * new map containing the same key/value pairs in sorted key order.
 *
 * The input must be a map whose key type is orderable; any other input is rejected by
 * [[checkInputDataTypes]] with a data type mismatch.
 *
 * @param base the expression producing the map to sort
 */
case class MapSort(base: Expression)
  extends UnaryExpression with NullIntolerant with QueryErrorsBase {

  // These are resolved lazily on purpose. Evaluated eagerly, the cast below would throw a
  // ClassCastException for a non-map child while the expression is still being constructed,
  // i.e. before `checkInputDataTypes` gets the chance to report a proper type mismatch.
  lazy val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType
  lazy val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType

  override def child: Expression = base

  override def dataType: DataType = base.dataType

  /**
   * Accepts only map inputs with an orderable key type.
   *
   * A map with a non-orderable key yields `INVALID_ORDERING_TYPE`; a non-map input yields
   * `UNEXPECTED_INPUT_TYPE`.
   */
  override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
    case m: MapType if RowOrdering.isOrderable(m.keyType) =>
      TypeCheckResult.TypeCheckSuccess
    case _: MapType =>
      DataTypeMismatch(
        errorSubClass = "INVALID_ORDERING_TYPE",
        messageParameters = Map(
          "functionName" -> toSQLId(prettyName),
          "dataType" -> toSQLType(base.dataType)
        )
      )
    case _ =>
      DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> ordinalNumber(0),
          "requiredType" -> toSQLType(MapType),
          "inputSql" -> toSQLExpr(base),
          "inputType" -> toSQLType(base.dataType))
      )
  }

  /**
   * Interpreted evaluation: pairs every key with its value, sorts the pairs by key and rebuilds
   * a map from the sorted keys and values.
   *
   * @param array the non-null input value; despite the parameter name it is the map
   *              (a `MapData`) to sort
   */
  override def nullSafeEval(array: Any): Any = {
    val mapData = array.asInstanceOf[MapData]
    val numElements = mapData.numElements()
    val keys = mapData.keyArray()
    val values = mapData.valueArray()

    val ordering = PhysicalDataType.ordering(keyType)

    // Keep each value attached to its key while sorting so the pairing survives the reorder.
    val sortedEntries = Array
      .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any],
        values.get(i, valueType).asInstanceOf[Any]))
      .sortBy(_._1)(ordering)

    new ArrayBasedMapData(new GenericArrayData(sortedEntries.map(_._1)),
      new GenericArrayData(sortedEntries.map(_._2)))
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, b => sortCodegen(ctx, ev, b))
  }

  /**
   * Generates Java code that copies the map entries into an array of `SimpleEntry` objects,
   * sorts that array by key with a comparator built from `ctx.genComp`, and assigns a new
   * `ArrayBasedMapData` made of the sorted keys and values to `ev.value`.
   *
   * @param ctx  the codegen context used for fresh names and key comparison code
   * @param ev   the expression code whose `value` receives the sorted map
   * @param base the name of the generated variable holding the input map
   */
  private def sortCodegen(ctx: CodegenContext, ev: ExprCode,
      base: String): String = {

    val arrayBasedMapData = classOf[ArrayBasedMapData].getName
    val genericArrayData = classOf[GenericArrayData].getName

    val numElements = ctx.freshName("numElements")
    val keys = ctx.freshName("keys")
    val values = ctx.freshName("values")
    val sortArray = ctx.freshName("sortArray")
    val i = ctx.freshName("i")
    val value = ctx.freshName("value")
    val o1 = ctx.freshName("o1")
    val o1entry = ctx.freshName("o1entry")
    val o2 = ctx.freshName("o2")
    val o2entry = ctx.freshName("o2entry")
    val c = ctx.freshName("c")
    val newKeys = ctx.freshName("newKeys")
    val newValues = ctx.freshName("newValues")

    val boxedKeyType = CodeGenerator.boxedType(keyType)
    val boxedValueType = CodeGenerator.boxedType(valueType)
    val javaKeyType = CodeGenerator.javaType(keyType)

    val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, $boxedValueType>"

    // Declares `int $c` holding the comparison result of the two keys. Primitive keys are
    // stored boxed inside the entries, so they are unboxed before being compared.
    val comp = if (CodeGenerator.isPrimitiveType(keyType)) {
      val v1 = ctx.freshName("v1")
      val v2 = ctx.freshName("v2")
      s"""
         |$javaKeyType $v1 = (($boxedKeyType) $o1).${javaKeyType}Value();
         |$javaKeyType $v2 = (($boxedKeyType) $o2).${javaKeyType}Value();
         |int $c = ${ctx.genComp(keyType, v1, v2)};
       """.stripMargin
    } else {
      s"int $c = ${ctx.genComp(keyType, s"(($javaKeyType) $o1)", s"(($javaKeyType) $o2)")};"
    }

    // Map values may be null. The null check is done explicitly before reading the value so
    // that a null is carried through as null instead of being read via the typed getter,
    // matching what the interpreted path (`values.get(i, valueType)`) produces.
    s"""
       |final int $numElements = $base.numElements();
       |ArrayData $keys = $base.keyArray();
       |ArrayData $values = $base.valueArray();
       |
       |Object[] $sortArray = new Object[$numElements];
       |
       |for (int $i = 0; $i < $numElements; $i++) {
       |  $boxedValueType $value = null;
       |  if (!$values.isNullAt($i)) {
       |    $value = ${CodeGenerator.getValue(values, valueType, i)};
       |  }
       |  $sortArray[$i] = new $simpleEntryType(
       |    ${CodeGenerator.getValue(keys, keyType, i)}, $value);
       |}
       |
       |java.util.Arrays.sort($sortArray, new java.util.Comparator<Object>() {
       |  @Override public int compare(Object $o1entry, Object $o2entry) {
       |    Object $o1 = (($simpleEntryType) $o1entry).getKey();
       |    Object $o2 = (($simpleEntryType) $o2entry).getKey();
       |    $comp
       |    return $c;
       |  }
       |});
       |
       |Object[] $newKeys = new Object[$numElements];
       |Object[] $newValues = new Object[$numElements];
       |
       |for (int $i = 0; $i < $numElements; $i++) {
       |  $newKeys[$i] = (($simpleEntryType) $sortArray[$i]).getKey();
       |  $newValues[$i] = (($simpleEntryType) $sortArray[$i]).getValue();
       |}
       |
       |${ev.value} = new $arrayBasedMapData(
       |  new $genericArrayData($newKeys), new $genericArrayData($newValues));
       |""".stripMargin
  }

  override protected def withNewChildInternal(newChild: Expression)
    : MapSort = copy(base = newChild)
}

/**
* Common base class for [[SortArray]] and [[ArraySort]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,29 @@ class CollectionExpressionsSuite
)
}

test("Sort Map") {
  // Each case pairs a map literal whose entries are out of key order with the map that
  // MapSort is expected to produce for it. Covers int, boolean, string, array, nested array
  // and struct keys.
  val structKeyType = StructType(Seq(StructField("a", IntegerType)))

  val cases: Seq[(Literal, Any)] = Seq(
    (Literal.create(Map(2 -> 2, 1 -> 1, 3 -> 3), MapType(IntegerType, IntegerType)),
      Map(1 -> 1, 2 -> 2, 3 -> 3)),
    (Literal.create(Map(true -> 2, false -> 1), MapType(BooleanType, IntegerType)),
      Map(false -> 1, true -> 2)),
    (Literal.create(Map("2" -> 2, "1" -> 1, "3" -> 3), MapType(StringType, IntegerType)),
      Map("1" -> 1, "2" -> 2, "3" -> 3)),
    (Literal.create(Map(Seq(2) -> 2, Seq(1) -> 1, Seq(3) -> 3),
      MapType(ArrayType(IntegerType), IntegerType)),
      Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3)),
    (Literal.create(Map(Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1, Seq(Seq(3)) -> 3),
      MapType(ArrayType(ArrayType(IntegerType)), IntegerType)),
      Map(Seq(Seq(1)) -> 1, Seq(Seq(2)) -> 2, Seq(Seq(3)) -> 3)),
    (Literal.create(
      Map(create_row(2) -> 2, create_row(1) -> 1, create_row(3) -> 3),
      MapType(structKeyType, IntegerType)),
      Map(create_row(1) -> 1, create_row(2) -> 2, create_row(3) -> 3))
  )

  cases.foreach { case (unsorted, sorted) =>
    checkEvaluation(MapSort(unsorted), sorted)
  }
}

test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
Expand Down