python/pyspark/sql/functions.py (15 additions, 0 deletions)
@@ -2080,6 +2080,21 @@ def size(col):
return Column(sc._jvm.functions.size(_to_java_column(col)))


@since(2.4)
def array_max(col):
"""
Collection function: returns the maximum value of the array.

:param col: name of column or expression

>>> df = spark.createDataFrame([([2, 1, 3],),([None, 10, -1],)], ['data'])
Review comment (Member): quick nit: ",(" -> ", ("
    >>> df.select(array_max(df.data).alias('max')).collect()
    [Row(max=3), Row(max=10)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.array_max(_to_java_column(col)))


@since(1.5)
def sort_array(col, asc=True):
"""
@@ -408,6 +408,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
expression[ArrayMax]("array_max"),
CreateStruct.registryEntry,

// misc functions
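Registering ArrayMax here is what exposes array_max to SQL. As a quick, hedged sketch of exercising the newly registered function (assuming an active SparkSession named spark built with this patch applied), the expected result matches the ExpressionDescription example further down:

    // Sketch only: call the newly registered SQL function.
    // Assumes `spark` is an active SparkSession that includes this patch.
    spark.sql("SELECT array_max(array(1, 20, null, 3)) AS m").show()
    // Expected, per the ExpressionDescription example below:
    // +---+
    // |  m|
    // +---+
    // | 20|
    // +---+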
@@ -676,11 +676,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
-      |if (!${eval.isNull} && (${ev.isNull} ||
-      |  ${ctx.genGreater(dataType, eval.value, ev.value)})) {
-      |  ${ev.isNull} = false;
-      |  ${ev.value} = ${eval.value};
-      |}
+      |${ctx.reassignIfGreater(dataType, ev, eval)}
""".stripMargin
)

@@ -697,6 +697,23 @@ class CodegenContext {
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
}

/**
* Generates code for updating `partialResult` if `item` is greater than it.
*
* @param dataType data type of the expressions
* @param partialResult `ExprCode` representing the partial result which has to be updated
* @param item `ExprCode` representing the new expression to evaluate for the result
*/
def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
s"""
|if (!${item.isNull} && (${partialResult.isNull} ||
| ${genGreater(dataType, item.value, partialResult.value)})) {
| ${partialResult.isNull} = false;
| ${partialResult.value} = ${item.value};
|}
""".stripMargin
}

/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
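For intuition about the reassignIfGreater helper added above, here is a rough, hypothetical rendering of the Java snippet it emits once the placeholders are filled in. The variable names (max, elem) are invented for illustration, and the plain > assumes genGreater's shortcut for primitive types; for other types it falls back to the (genComp(...)) > 0 form shown above:

    // Hypothetical illustration only, not produced verbatim by this patch:
    // approximate shape of the emitted Java snippet for a primitive result type,
    // with placeholder names `max` (partial result) and `elem` (item).
    object ReassignIfGreaterSketch {
      val emitted: String =
        """
          |if (!elem_isNull && (max_isNull || elem_value > max_value)) {
          |  max_isNull = false;
          |  max_value = elem_value;
          |}
        """.stripMargin

      def main(args: Array[String]): Unit = println(emitted)
    }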
@@ -21,7 +21,7 @@ import java.util.Comparator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._

/**
@@ -287,3 +287,61 @@ case class ArrayContains(left: Expression, right: Expression)

override def prettyName: String = "array_contains"
}


/**
* Returns the maximum value in the array.
*/
@ExpressionDescription(
usage = "_FUNC_(array) - Returns the maximum value in the array.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 20, null, 3));
20
""", since = "2.4.0")
case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

override def nullable: Boolean =
child.nullable || child.dataType.asInstanceOf[ArrayType].containsNull
Review comment (Member): This should always be true because the array might be empty?


override def foldable: Boolean = child.foldable
Review comment (Contributor): The same line of code is in UnaryExpression.


override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
val javaType = CodeGenerator.javaType(dataType)
val i = ctx.freshName("i")
val item = ExprCode("",
isNull = StatementValue(s"${childGen.value}.isNullAt($i)", "boolean"),
value = StatementValue(CodeGenerator.getValue(childGen.value, dataType, i), javaType))
ev.copy(code =
s"""
|${childGen.code}
|boolean ${ev.isNull} = true;
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
Review comment (Member): Do we need to use the MIN value for each data type instead of the default value? If we perform this operation against (-10, -100, -1000), I think we would get -1 as a result.

Review comment (Member): Never mind, isNull is used for assigning the initial value, so the default value is never actually compared.

|if (!${childGen.isNull}) {
| for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
| ${ctx.reassignIfGreater(dataType, ev, item)}
| }
|}
""".stripMargin)
}

override protected def nullSafeEval(input: Any): Any = {
var max: Any = null
input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
if (item != null && (max == null || ordering.gt(item, max))) {
max = item
}
)
max
}

override def dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
Review comment (Member): We should also check if dt is orderable.

Review comment (Contributor, author): I added the check in the checkInputDataTypes method, thanks.

(A sketch of what such a check could look like follows the class definition below.)

case _ => throw new IllegalStateException("array_max accepts only arrays.")
}
}
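Following up on the review exchange above about orderability: a minimal sketch of what the added checkInputDataTypes override could look like, reusing the TypeUtils helper pulled in by the updated import at the top of this file. The structure is a guess at the shape, not the author's exact code, and checkForOrderingExpr is assumed to be the existing TypeUtils utility for this purpose:

    // Hedged sketch (not necessarily the final implementation): inside ArrayMax,
    // first run the inherited input-type check (the input must be an ArrayType),
    // then verify that the element type has an ordering.
    override def checkInputDataTypes(): TypeCheckResult = {
      val defaultCheck = super.checkInputDataTypes()
      if (defaultCheck.isSuccess) {
        TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
      } else {
        defaultCheck
      }
    }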
@@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}

test("Array max") {
checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10)
checkEvaluation(
ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc")
checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null)
checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null)
checkEvaluation(
ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
}
}
sql/core/src/main/scala/org/apache/spark/sql/functions.scala (8 additions, 0 deletions)
@@ -3300,6 +3300,14 @@
*/
def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }

/**
* Returns the maximum value in the array.
*
* @group collection_funcs
* @since 2.4.0
*/
def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }

/**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
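A quick usage sketch for the new Scala array_max API defined above, mirroring the PySpark doctest; this is illustrative only and assumes a local SparkSession:

    // Hedged usage sketch for the Scala array_max function added above.
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.array_max

    val spark = SparkSession.builder().master("local[*]").appName("array_max demo").getOrCreate()
    import spark.implicits._

    val df = Seq(
      Seq[Option[Int]](Some(2), Some(1), Some(3)),
      Seq[Option[Int]](None, Some(10), Some(-1))
    ).toDF("data")

    df.select(array_max($"data").alias("max")).show()
    // Expected output, mirroring the PySpark doctest above:
    // +---+
    // |max|
    // +---+
    // |  3|
    // | 10|
    // +---+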
@@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("array_max function") {
val df = Seq(
Seq[Option[Int]](Some(1), Some(3), Some(2)),
Seq.empty[Option[Int]],
Seq[Option[Int]](None),
Seq[Option[Int]](None, Some(1), Some(-100))
).toDF("a")

val answer = Seq(Row(3), Row(null), Row(null), Row(1))

checkAnswer(df.select(array_max(df("a"))), answer)
checkAnswer(df.selectExpr("array_max(a)"), answer)
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {