Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
134e9e4
[SPARK-29020] Improving array_sort behaviour
Sep 9, 2019
aeee71c
[SPARK-29020] [SQL] Keep array_sort original behaviour
Sep 10, 2019
ebde544
[SPARK-29020] [SQL] ascending parameter is now a Column
Sep 10, 2019
3f4c328
Array sorting as HOF with Integers Asc and Desc
Sep 15, 2019
7264fc5
4/6 cases working
Sep 15, 2019
210baf7
undo array_sort asc flag
Sep 15, 2019
f7e7d39
Name and indent refactor
Sep 16, 2019
e651094
Added null handling
Sep 16, 2019
b32a9fa
fix import for scalastyle
Sep 16, 2019
cc44b90
added checkInputDataTypes function
Sep 16, 2019
2285678
rename HOF array_sort to sort
Sep 17, 2019
d35d6cb
Remove ArraySort and unifying it in new ArraySort HOF
Sep 18, 2019
ad04599
added sortEval in ArraySortLike
Sep 18, 2019
7ad574a
add array_sort in functions for scala API
Sep 18, 2019
d75a671
Refactor changes
Sep 18, 2019
0c65ff8
Consistency with array_sort constructor
Sep 23, 2019
83ddb8b
fix comment in ArraySort
Sep 24, 2019
adf0e0e
Indentation changes
Sep 27, 2019
2224354
change dataType from ArraySort
Sep 28, 2019
92586cc
Unregistered Udfs, null handle in comparators
Oct 22, 2019
c836dae
Remove ArraySortLike from array_sort
Oct 25, 2019
f7a93c5
Fix ArraySort to support comparator function.
ueshin Nov 4, 2019
6b08cdc
Merge pull request #1 from ueshin/pr/25728/array_sort
Gschiavon Nov 7, 2019
a345142
fixed test after changing comparator order
Nov 7, 2019
3fe059a
fixing ArraySort Query Example
Nov 8, 2019
53b563e
Fix lambda function names and readability of query examples
Nov 13, 2019
7c38b8a
revert indent
Nov 14, 2019
ef28d4f
Changed ArraySort description
Nov 17, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -900,54 +900,6 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
override def prettyName: String = "sort_array"
}


/**
* Sorts the input array in ascending order according to the natural ordering of
* the array elements and returns it.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must
be orderable. Null elements will be placed at the end of the returned array.
""",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
["a","b","c","d",null]
""",
since = "2.4.0")
// scalastyle:on line.size.limit
case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike {

  override def prettyName: String = "array_sort"

  // The single child is the array being sorted; nulls use the Greatest null order.
  override def arrayExpression: Expression = child
  override def nullOrder: NullOrder = NullOrder.Greatest

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
  override def dataType: DataType = child.dataType

  /**
   * Accepts only arrays whose element type is orderable; any other input yields a
   * type-check failure describing the problem.
   */
  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case ArrayType(elemType, _) =>
      if (RowOrdering.isOrderable(elemType)) {
        TypeCheckResult.TypeCheckSuccess
      } else {
        val typeName = elemType.catalogString
        TypeCheckResult.TypeCheckFailure(
          s"$prettyName does not support sorting array of type $typeName which is not orderable")
      }
    case _ =>
      TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
  }

  /** Interpreted path: delegates to `sortEval` with the flag fixed to `true`. */
  override def nullSafeEval(array: Any): Any = sortEval(array, true)

  /** Codegen path: delegates to `sortCodegen` with the flag fixed to the literal "true". */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val genSort: String => String = arr => sortCodegen(ctx, ev, arr, "true")
    nullSafeCodeGen(ctx, ev, genSort)
  }
}

/**
* Returns a random permutation of the given array.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.util.Comparator
import java.util.concurrent.atomic.AtomicReference

import scala.collection.mutable
Expand Down Expand Up @@ -285,6 +286,113 @@ case class ArrayTransform(
override def prettyName: String = "transform"
}

/**
* Sorts elements in an array using a comparator function.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """_FUNC_(expr, func) - Sorts the input array in ascending order. The elements of the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this phrase in ascending order always true?

input array must be orderable. Null elements will be placed at the end of the returned
array. Since 3.0.0 this function also sorts and returns the array based on the given
comparator function. The comparator will take two arguments
representing two elements of the array.
It returns -1, 0, or 1 as the first element is less than, equal to, or greater
than the second element. If the comparator function returns other
values (including null), the function will fail and raise an error.
""",
examples = """
Examples:
> SELECT _FUNC_(array(5, 6, 1), (left, right) -> case when left < right then -1 when left > right then 1 else 0 end);
[1,5,6]
> SELECT _FUNC_(array('bc', 'ab', 'dc'), (left, right) -> case when left is null and right is null then 0 when left is null then -1 when right is null then 1 when left < right then 1 when left > right then -1 else 0 end);
["dc","bc","ab"]
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
["a","b","c","d",null]
""",
since = "2.4.0")
// scalastyle:on line.size.limit
case class ArraySort(
    argument: Expression,
    function: Expression)
  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

  /** Single-argument form: sorts with `ArraySort.defaultComparator`. */
  def this(argument: Expression) = this(argument, ArraySort.defaultComparator)

  /** Element type of the input array; only valid once `argument` is known to be an array. */
  @transient lazy val elementType: DataType =
    argument.dataType.asInstanceOf[ArrayType].elementType

  /** The result has exactly the same array type as the input. */
  override def dataType: ArrayType = argument.dataType.asInstanceOf[ArrayType]

  /**
   * Runs the generic argument checks first, then requires that the input is an array of an
   * orderable element type and that the comparator function returns `IntegerType`.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    checkArgumentDataTypes() match {
      case TypeCheckResult.TypeCheckSuccess =>
        argument.dataType match {
          case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
            if (function.dataType == IntegerType) {
              TypeCheckResult.TypeCheckSuccess
            } else {
              TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " +
                "IntegerType")
            }
          case ArrayType(dt, _) =>
            val dtSimple = dt.catalogString
            TypeCheckResult.TypeCheckFailure(
              s"$prettyName does not support sorting array of type $dtSimple which is not " +
                "orderable")
          case _ =>
            TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
        }
      case failure => failure
    }
  }

  /**
   * Binds the comparator lambda, declaring both of its arguments with the array's element
   * type and nullability.
   */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySort = {
    val ArrayType(elementType, containsNull) = argument.dataType
    copy(function =
      f(function, (elementType, containsNull) :: (elementType, containsNull) :: Nil))
  }

  // Destructures the comparator into its two lambda variables. The pattern only matches a
  // LambdaFunction whose two arguments are NamedLambdaVariables.
  @transient lazy val LambdaFunction(_,
    Seq(firstElemVar: NamedLambdaVariable, secondElemVar: NamedLambdaVariable), _) = function

  /**
   * Builds a Java comparator that evaluates the lambda against `inputRow` for each pair of
   * elements.
   *
   * @param inputRow the row the enclosing expression is being evaluated on
   * @throws IllegalArgumentException (from `compare`) if the lambda evaluates to null
   */
  def comparator(inputRow: InternalRow): Comparator[Any] = {
    val f = functionForEval
    (o1: Any, o2: Any) => {
      firstElemVar.value.set(o1)
      secondElemVar.value.set(o2)
      val result = f.eval(inputRow)
      // `null.asInstanceOf[Int]` silently unboxes to 0, which would make the two elements
      // compare as equal. The function description promises an error for a null result.
      if (result == null) {
        throw new IllegalArgumentException(
          s"The comparator function of $prettyName returned null when comparing $o1 and " +
            s"$o2. It must return a negative integer, 0, or a positive integer.")
      }
      result.asInstanceOf[Int]
    }
  }

  /**
   * Copies the input into an object array, sorts it with the comparator, and wraps the
   * result in a new `GenericArrayData`. Sorting is skipped when the element type is
   * `NullType` (presumably because every element is null, so there is nothing to order).
   */
  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val arr = argumentValue.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
    if (elementType != NullType) {
      java.util.Arrays.sort(arr, comparator(inputRow))
    }
    new GenericArrayData(arr.asInstanceOf[Array[Any]])
  }

  override def prettyName: String = "array_sort"
}

object ArraySort {

  /**
   * Builds the comparison expression used when no comparator is supplied. It evaluates to
   * 0 when both sides are null, 1 when only `left` is null, -1 when only `right` is null,
   * and otherwise -1 / 1 / 0 as `left` is less than / greater than / neither relative to
   * `right`.
   */
  def comparator(left: Expression, right: Expression): Expression = {
    val same = Literal(0)
    val after = Literal(1)
    val before = Literal(-1)

    // Assemble the nested conditionals from the innermost branch outwards.
    val compareValues =
      If(LessThan(left, right), before, If(GreaterThan(left, right), after, same))
    val onlyRightNull = If(IsNull(right), before, compareValues)
    val onlyLeftNull = If(IsNull(left), after, onlyRightNull)
    If(And(IsNull(left), IsNull(right)), same, onlyLeftNull)
  }

  /** Unresolved two-argument lambda `(left, right) -> comparator(left, right)`. */
  val defaultComparator: LambdaFunction = {
    val lhs = UnresolvedNamedLambdaVariable(Seq("left"))
    val rhs = UnresolvedNamedLambdaVariable(Seq("right"))
    val body = comparator(lhs, rhs)
    LambdaFunction(body, Seq(lhs, rhs))
  }
}

/**
* Filters entries in a map using the provided function.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,16 +363,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS)

checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2))

checkEvaluation(ArraySort(a0), Seq(1, 2, 3))
checkEvaluation(ArraySort(a1), Seq[Integer]())
checkEvaluation(ArraySort(a2), Seq("a", "b"))
checkEvaluation(ArraySort(a3), Seq("a", "b", null))
checkEvaluation(ArraySort(a4), Seq(d1, d2))
checkEvaluation(ArraySort(a5), Seq(null, null))
checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2)))
checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2))
checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2))
}

test("Array contains") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding)
}

def arraySort(expr: Expression): Expression = {
  // Delegate to the two-argument overload, passing ArraySort.comparator as the function.
  val builtIn: (Expression, Expression) => Expression = ArraySort.comparator
  arraySort(expr, builtIn)
}

def arraySort(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
  // Both lambda arguments take the array's element type and nullability.
  val ArrayType(elemType, elemNullable) = expr.dataType
  val lambda = createLambda(elemType, elemNullable, elemType, elemNullable, f)
  ArraySort(expr, lambda).bind(validateBinding)
}

def filter(expr: Expression, f: Expression => Expression): Expression = {
val ArrayType(et, cn) = expr.dataType
ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding)
Expand Down Expand Up @@ -162,6 +171,47 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq("[1, 3, 5]", null, "[4, 6]"))
}

test("ArraySort") {
  // Flat inputs: ints, empty, strings (with and without a null), decimals, all-null.
  val ints = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
  val empty = Literal.create(Seq[Integer](), ArrayType(IntegerType))
  val strings = Literal.create(Seq("b", "a"), ArrayType(StringType))
  val stringsWithNull = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
  val d1 = new Decimal().set(10)
  val d2 = new Decimal().set(100)
  val decimals = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0)))
  val nulls = Literal.create(Seq(null, null), ArrayType(NullType))

  // Nested inputs: array of structs, array of arrays, array of arrays of structs.
  val structArrayType = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
  val structs = Literal.create(Seq(create_row(2), create_row(1)), structArrayType)

  val aa1 = Array[java.lang.Integer](1, 2)
  val aa2 = Array[java.lang.Integer](3, null, 4)
  val nestedInts = Literal.create(Seq(aa2, aa1), ArrayType(ArrayType(IntegerType)))

  val aas1 = Array(create_row(1))
  val aas2 = Array(create_row(2))
  val nestedStructs = Literal.create(Seq(aas2, aas1), ArrayType(structArrayType))

  // Comparator supplied by ArraySort.comparator.
  checkEvaluation(arraySort(ints), Seq(1, 2, 3))
  checkEvaluation(arraySort(empty), Seq[Integer]())
  checkEvaluation(arraySort(strings), Seq("a", "b"))
  checkEvaluation(arraySort(stringsWithNull), Seq("a", "b", null))
  checkEvaluation(arraySort(decimals), Seq(d1, d2))
  checkEvaluation(arraySort(nulls), Seq(null, null))
  checkEvaluation(arraySort(structs), Seq(create_row(1), create_row(2)))
  checkEvaluation(arraySort(nestedInts), Seq(aa1, aa2))
  checkEvaluation(arraySort(nestedStructs), Seq(aas1, aas2))

  // Negating the comparator result reverses the order.
  val reversed: (Expression, Expression) => Expression =
    (l, r) => UnaryMinus(ArraySort.comparator(l, r))

  checkEvaluation(arraySort(ints, reversed), Seq(3, 2, 1))
  checkEvaluation(arraySort(stringsWithNull, reversed), Seq(null, "b", "a"))
  checkEvaluation(arraySort(decimals, reversed), Seq(d2, d1))
}

test("MapFilter") {
def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val MapType(kt, vt, vcn) = expr.dataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3334,7 +3334,7 @@ object functions {
* @group collection_funcs
* @since 2.4.0
*/
def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) }
def array_sort(e: Column): Column = withExpr {
  // One-argument constructor, which supplies ArraySort.defaultComparator as the function.
  new ArraySort(e.expr)
}

/**
* Remove all elements that equal to element from the given array.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,86 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
)
}

test("array_sort with lambda functions") {
  // Names of the temporary UDFs registered below. They are always dropped in the `finally`
  // block so that a failing assertion cannot leak them into other tests.
  val udfNames = Seq("fAsc", "fDesc", "fString", "fStringLength")

  try {
    // Ascending integer comparator.
    spark.udf.register("fAsc", (x: Int, y: Int) => {
      if (x < y) -1
      else if (x == y) 0
      else 1
    })

    // Descending integer comparator.
    spark.udf.register("fDesc", (x: Int, y: Int) => {
      if (x < y) 1
      else if (x == y) 0
      else -1
    })

    // Descending string comparator that orders nulls last.
    spark.udf.register("fString", (x: String, y: String) => {
      if (x == null && y == null) 0
      else if (x == null) 1
      else if (y == null) -1
      else if (x < y) 1
      else if (x == y) 0
      else -1
    })

    // Ascending-by-length string comparator that orders nulls last.
    spark.udf.register("fStringLength", (x: String, y: String) => {
      if (x == null && y == null) 0
      else if (x == null) 1
      else if (y == null) -1
      else if (x.length < y.length) -1
      else if (x.length == y.length) 0
      else 1
    })

    val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a")
    checkAnswer(
      df1.selectExpr("array_sort(a, (x, y) -> fAsc(x, y))"),
      Seq(
        Row(Seq(1, 2, 2, 3, 5)))
    )

    checkAnswer(
      df1.selectExpr("array_sort(a, (x, y) -> fDesc(x, y))"),
      Seq(
        Row(Seq(5, 3, 2, 2, 1)))
    )

    val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a")
    checkAnswer(
      df2.selectExpr("array_sort(a, (x, y) -> fString(x, y))"),
      Seq(
        Row(Seq("dc", "bc", "ab")))
    )

    val df3 = Seq(Array[String]("a", "abcd", "abc")).toDF("a")
    checkAnswer(
      df3.selectExpr("array_sort(a, (x, y) -> fStringLength(x, y))"),
      Seq(
        Row(Seq("a", "abc", "abcd")))
    )

    // Sort nested arrays by their size.
    val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2, 1, 4),
      Array(1, 2)), "x")).toDF("a", "b")
    checkAnswer(
      df4.selectExpr("array_sort(a, (x, y) -> fAsc(cardinality(x), cardinality(y)))"),
      Seq(
        Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4))))
    )

    // A null element ends up last with the fString comparator.
    val df5 = Seq(Array[String]("bc", null, "ab", "dc")).toDF("a")
    checkAnswer(
      df5.selectExpr("array_sort(a, (x, y) -> fString(x, y))"),
      Seq(
        Row(Seq("dc", "bc", "ab", null)))
    )
  } finally {
    // IF EXISTS keeps cleanup from failing (and masking the real error) when a
    // registration above did not complete.
    udfNames.foreach(name => spark.sql(s"drop temporary function if exists $name"))
  }
}

test("sort_array/array_sort functions") {
val df = Seq(
(Array[Int](2, 1, 3), Array("b", "c", "a")),
Expand Down Expand Up @@ -383,7 +463,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {

assert(intercept[AnalysisException] {
df3.selectExpr("array_sort(a)").collect()
}.getMessage().contains("only supports array input"))
}.getMessage().contains("argument 1 requires array type, however, '`a`' is of string type"))
}

def testSizeOfArray(sizeOfNull: Any): Unit = {
Expand Down