Closed
Changes from 26 commits
28 commits
a081649
initial working version
stefankandic Feb 8, 2024
1441549
add golden files
stefankandic Feb 11, 2024
1be06e3
add map sort to other languages
stefankandic Feb 14, 2024
249e903
fix typoes
stefankandic Feb 28, 2024
aaae883
fix scalastyle issue
stefankandic Feb 28, 2024
acaf95e
add proto golden files
stefankandic Feb 28, 2024
5619fdb
fix python function call
stefankandic Feb 28, 2024
7754c14
fix ci errors
stefankandic Feb 29, 2024
f0ebf5d
fix ci checks
stefankandic Feb 29, 2024
1f78167
Optimized map-sort by switching to array sorting
stevomitric Mar 12, 2024
a5eb480
Potential tests fix
stevomitric Mar 13, 2024
9497f99
Potential tests fix 2
stevomitric Mar 13, 2024
5e7a033
Removed TODOs and changed parmIndex to ordinal
stevomitric Mar 17, 2024
ab70f1e
Shortened map sort function and added more docs
stevomitric Mar 18, 2024
e79d65c
updated map_sort test suite
stevomitric Mar 18, 2024
a435355
Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunction…
stevomitric Mar 18, 2024
c9901d0
Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunction…
stevomitric Mar 18, 2024
da6a710
docs fix
stevomitric Mar 18, 2024
81008c2
Updated codegen and removed once test-case
stevomitric Mar 19, 2024
86b29c5
Update python/pyspark/sql/functions/builtin.py
stevomitric Mar 19, 2024
c08ab6c
Updated 'select.show' to give more info in map_sort desc
stevomitric Mar 19, 2024
31a797c
Restructured docs, removed unused variable and refactored code
stevomitric Mar 19, 2024
69e3b48
Removed map_sort function but left the MapSort expression
stevomitric Mar 21, 2024
51ab204
Merge branch 'master' into stevomitric/map-expr
stevomitric Mar 21, 2024
8d9ac51
aditional erasions
stevomitric Mar 21, 2024
2951bcc
removed ExpressionDescription
stevomitric Mar 21, 2024
0fc3c6a
Moved ordering outside of comapre function
stevomitric Mar 21, 2024
0c7d21a
Removed oredering type
stevomitric Mar 21, 2024
@@ -888,6 +888,157 @@ case class MapFromEntries(child: Expression)
copy(child = newChild)
}

case class MapSort(base: Expression, ascendingOrder: Expression)
Contributor

Suggested change
case class MapSort(base: Expression, ascendingOrder: Expression)
case class MapSort(base: Expression, ascendingOrder: Boolean)

Contributor Author

@stevomitric stevomitric Mar 21, 2024

Doesn't the BinaryExpression require two expressions here? Do we demote this to UnaryExpression?

EDIT: The array sorting expression also declares ascendingOrder as an Expression.

Member

What is the internal use case for the expression? Do we need this parameter at all?

It seems you are always going to pass true as ascendingOrder at
https://github.com/apache/spark/pull/45549/files#diff-11264d807efa58054cca2d220aae8fba644ee0f0f2a4722c46d52828394846efR2488

   case a @ Aggregate(groupingExpr, x, b) =>
      val newGrouping = groupingExpr.map { expr =>
        (expr, expr.dataType) match {
          case (_: MapSort, _) => expr
          case (_, _: MapType) =>
            MapSort(expr, Literal.TrueLiteral)
          case _ => expr

Contributor Author

From the point of view of internal use, we don't need it. I refactored the expression into a UnaryExpression and removed the ordering parameter altogether.

  extends BinaryExpression with NullIntolerant with QueryErrorsBase {

  def this(e: Expression) = this(e, Literal(true))

  val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType
  val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType

  override def left: Expression = base
  override def right: Expression = ascendingOrder
  override def dataType: DataType = base.dataType

  override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
    case m: MapType if RowOrdering.isOrderable(m.keyType) =>
      ascendingOrder match {
        case Literal(_: Boolean, BooleanType) =>
          TypeCheckResult.TypeCheckSuccess
Contributor

Do you have to be so strict on this argument? For example, could you imagine a case where you want to select the sort order based on the values in another column or the result of an expression? Is this needlessly restrictive?

Member

> For example, could you imagine a case where you want to select the sort order based on the values in another column or the result of an expression?

I can imagine the case, but so far we are going to use the expression internally for one case only. Supporting ascendingOrder = false, or even an arbitrary boolean expression, just overcomplicates the code.

        case _ =>
          DataTypeMismatch(
            errorSubClass = "UNEXPECTED_INPUT_TYPE",
            messageParameters = Map(
              "paramIndex" -> ordinalNumber(1),
              "requiredType" -> toSQLType(BooleanType),
              "inputSql" -> toSQLExpr(ascendingOrder),
              "inputType" -> toSQLType(ascendingOrder.dataType))
          )
      }
    case _: MapType =>
      DataTypeMismatch(
        errorSubClass = "INVALID_ORDERING_TYPE",
        messageParameters = Map(
          "functionName" -> toSQLId(prettyName),
          "dataType" -> toSQLType(base.dataType)
        )
      )
    case _ =>
      DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> ordinalNumber(0),
          "requiredType" -> toSQLType(MapType),
          "inputSql" -> toSQLExpr(base),
          "inputType" -> toSQLType(base.dataType))
      )
  }

  override def nullSafeEval(array: Any, ascending: Any): Any = {
    // put keys and their respective values inside a tuple and sort them
    // according to the key ordering. Extract the new sorted k/v pairs to form a sorted map

    val mapData = array.asInstanceOf[MapData]
    val numElements = mapData.numElements()
    val keys = mapData.keyArray()
    val values = mapData.valueArray()

    val ordering = if (ascending.asInstanceOf[Boolean]) {
      PhysicalDataType.ordering(keyType)
    } else {
      PhysicalDataType.ordering(keyType).reverse
    }

    val sortedMap = Array
      .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any],
        values.get(i, valueType).asInstanceOf[Any]))
      .sortBy(_._1)(ordering)

    new ArrayBasedMapData(new GenericArrayData(sortedMap.map(_._1)),
      new GenericArrayData(sortedMap.map(_._2)))
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order))
  }

  private def sortCodegen(ctx: CodegenContext, ev: ExprCode,
      base: String, order: String): String = {

    val arrayBasedMapData = classOf[ArrayBasedMapData].getName
    val genericArrayData = classOf[GenericArrayData].getName

    val numElements = ctx.freshName("numElements")
    val keys = ctx.freshName("keys")
    val values = ctx.freshName("values")
    val sortArray = ctx.freshName("sortArray")
    val i = ctx.freshName("i")
    val o1 = ctx.freshName("o1")
    val o1entry = ctx.freshName("o1entry")
    val o2 = ctx.freshName("o2")
    val o2entry = ctx.freshName("o2entry")
    val c = ctx.freshName("c")
    val newKeys = ctx.freshName("newKeys")
    val newValues = ctx.freshName("newValues")

    val boxedKeyType = CodeGenerator.boxedType(keyType)
    val boxedValueType = CodeGenerator.boxedType(valueType)
    val javaKeyType = CodeGenerator.javaType(keyType)

    val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, $boxedValueType>"

    val comp = if (CodeGenerator.isPrimitiveType(keyType)) {
      val v1 = ctx.freshName("v1")
      val v2 = ctx.freshName("v2")
      s"""
         |$javaKeyType $v1 = (($boxedKeyType) $o1).${javaKeyType}Value();
         |$javaKeyType $v2 = (($boxedKeyType) $o2).${javaKeyType}Value();
         |int $c = ${ctx.genComp(keyType, v1, v2)};
         """.stripMargin
    } else {
      s"int $c = ${ctx.genComp(keyType, s"(($javaKeyType) $o1)", s"(($javaKeyType) $o2)")};"
    }

    s"""
       |final int $numElements = $base.numElements();
       |ArrayData $keys = $base.keyArray();
       |ArrayData $values = $base.valueArray();
       |
       |Object[] $sortArray = new Object[$numElements];
       |
       |for (int $i = 0; $i < $numElements; $i++) {
       |  $sortArray[$i] = new $simpleEntryType(
       |    ${CodeGenerator.getValue(keys, keyType, i)},
       |    ${CodeGenerator.getValue(values, valueType, i)});
       |}
       |
       |java.util.Arrays.sort($sortArray, new java.util.Comparator<Object>() {
       |  @Override public int compare(Object $o1entry, Object $o2entry) {
       |    Object $o1 = (($simpleEntryType) $o1entry).getKey();
       |    Object $o2 = (($simpleEntryType) $o2entry).getKey();
       |    $comp;
       |    return $order ? $c : -$c;
Contributor

Sorry for only seeing this now, but maybe we should do the same thing here that ArraySort does: put the ordering into a variable outside of compare and just multiply the comparison result by it?

This way we avoid branching in every comparison.
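
A minimal sketch of that suggestion in plain Java, not Spark's generated code. The class `SignedComparatorSketch` and the helper `sortKeys` are hypothetical names introduced only for illustration; the idea shown is to compute the direction once as `1` or `-1` and multiply the comparison result by it, instead of evaluating a ternary inside every `compare` call.

```java
import java.util.Arrays;
import java.util.Comparator;

public class SignedComparatorSketch {
    // Hypothetical helper: returns a sorted copy of the keys, ascending or descending.
    static Integer[] sortKeys(Integer[] keys, final boolean ascending) {
        // Decide the direction once, outside compare().
        final int sign = ascending ? 1 : -1;
        Integer[] sorted = keys.clone();
        Arrays.sort(sorted, new Comparator<Integer>() {
            @Override public int compare(Integer a, Integer b) {
                // Flip the result by multiplying; no per-comparison branch on the direction.
                return sign * Integer.compare(a, b);
            }
        });
        return sorted;
    }

    public static void main(String[] args) {
        Integer[] keys = {2, 1, 3};
        System.out.println(Arrays.toString(sortKeys(keys, true)));  // prints [1, 2, 3]
        System.out.println(Arrays.toString(sortKeys(keys, false))); // prints [3, 2, 1]
    }
}
```

Whether this measurably beats the ternary depends on the JIT, so treat it as a readability and consistency choice first.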

       |  }
       |});
       |
       |Object[] $newKeys = new Object[$numElements];
       |Object[] $newValues = new Object[$numElements];
       |
       |for (int $i = 0; $i < $numElements; $i++) {
       |  $newKeys[$i] = (($simpleEntryType) $sortArray[$i]).getKey();
       |  $newValues[$i] = (($simpleEntryType) $sortArray[$i]).getValue();
       |}
       |
       |${ev.value} = new $arrayBasedMapData(
       |  new $genericArrayData($newKeys), new $genericArrayData($newValues));
       |""".stripMargin
  }

  override def prettyName: String = "map_sort"
Member

Remove this since the expression hasn't been bound to the function name.


  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression)
    : MapSort = copy(base = newLeft, ascendingOrder = newRight)
}
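
For readers following the diff, the approach used by both `nullSafeEval` and the generated code (pair each key with its value, sort the pairs by key, then split them back into key and value arrays) can be sketched in standalone Java. This is an illustration under assumptions, not the PR's code: `PairSortSketch` and `sortByKey` are hypothetical names, the key and value types are fixed to `Integer` and `String`, and the result is written back into the input arrays for brevity, whereas the PR builds new arrays.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class PairSortSketch {
    // Hypothetical helper: sorts two parallel arrays by key, ascending.
    static void sortByKey(Integer[] keys, String[] values) {
        int n = keys.length;

        // Step 1: zip each key with its value so they move together during the sort.
        List<SimpleEntry<Integer, String>> entries = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            entries.add(new SimpleEntry<>(keys[i], values[i]));
        }

        // Step 2: sort the pairs by key only.
        entries.sort((a, b) -> a.getKey().compareTo(b.getKey()));

        // Step 3: unzip the sorted pairs back into the two arrays.
        for (int i = 0; i < n; i++) {
            keys[i] = entries.get(i).getKey();
            values[i] = entries.get(i).getValue();
        }
    }

    public static void main(String[] args) {
        Integer[] keys = {2, 1, 3};
        String[] values = {"b", "a", "c"};
        sortByKey(keys, values);
        System.out.println(Arrays.toString(keys));   // prints [1, 2, 3]
        System.out.println(Arrays.toString(values)); // prints [a, b, c]
    }
}
```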

/**
* Common base class for [[SortArray]] and [[ArraySort]].
@@ -421,6 +421,46 @@ class CollectionExpressionsSuite
)
}

  test("Sort Map") {
    val intKey = Literal.create(Map(2 -> 2, 1 -> 1, 3 -> 3), MapType(IntegerType, IntegerType))
    val boolKey = Literal.create(Map(true -> 2, false -> 1), MapType(BooleanType, IntegerType))
    val stringKey = Literal.create(Map("2" -> 2, "1" -> 1, "3" -> 3),
      MapType(StringType, IntegerType))
    val arrayKey = Literal.create(Map(Seq(2) -> 2, Seq(1) -> 1, Seq(3) -> 3),
      MapType(ArrayType(IntegerType), IntegerType))
    val nestedArrayKey = Literal.create(Map(Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1, Seq(Seq(3)) -> 3),
      MapType(ArrayType(ArrayType(IntegerType)), IntegerType))
    val structKey = Literal.create(
      Map(create_row(2) -> 2, create_row(1) -> 1, create_row(3) -> 3),
      MapType(StructType(Seq(StructField("a", IntegerType))), IntegerType))

    checkEvaluation(new MapSort(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3))
    checkEvaluation(MapSort(intKey, Literal.create(false, BooleanType)),
      Map(3 -> 3, 2 -> 2, 1 -> 1))

    checkEvaluation(new MapSort(boolKey), Map(false -> 1, true -> 2))
    checkEvaluation(MapSort(boolKey, Literal.create(false, BooleanType)),
      Map(true -> 2, false -> 1))

    checkEvaluation(new MapSort(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3))
    checkEvaluation(MapSort(stringKey, Literal.create(false, BooleanType)),
      Map("3" -> 3, "2" -> 2, "1" -> 1))

    checkEvaluation(new MapSort(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3))
    checkEvaluation(MapSort(arrayKey, Literal.create(false, BooleanType)),
      Map(Seq(3) -> 3, Seq(2) -> 2, Seq(1) -> 1))

    checkEvaluation(new MapSort(nestedArrayKey),
      Map(Seq(Seq(1)) -> 1, Seq(Seq(2)) -> 2, Seq(Seq(3)) -> 3))
    checkEvaluation(MapSort(nestedArrayKey, Literal.create(false, BooleanType)),
      Map(Seq(Seq(3)) -> 3, Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1))

    checkEvaluation(new MapSort(structKey),
      Map(create_row(1) -> 1, create_row(2) -> 2, create_row(3) -> 3))
    checkEvaluation(MapSort(structKey, Literal.create(false, BooleanType)),
      Map(create_row(3) -> 3, create_row(2) -> 2, create_row(1) -> 1))
  }

  test("Sort Array") {
    val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
    val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))