Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ object FunctionRegistry {
expression[ArrayRemove]("array_remove"),
expression[ArrayDistinct]("array_distinct"),
expression[ArrayTransform]("transform"),
expression[MapFilter]("map_filter"),
expression[ArrayFilter]("filter"),
expression[ArrayAggregate]("aggregate"),
CreateStruct.registryEntry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -133,7 +133,10 @@ trait HigherOrderFunction extends Expression {
}
}

trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
/**
* Trait for functions having as input one argument and one function.
*/
trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {

def input: Expression

Expand All @@ -145,9 +148,38 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu

def expectingFunctionType: AbstractDataType = AnyDataType

@transient lazy val functionForEval: Expression = functionsForEval.head

/**
* Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
* in order to save null-check code.
*/
protected def nullSafeEval(inputRow: InternalRow, input: Any): Any =
sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval")

override def eval(inputRow: InternalRow): Any = {
val value = input.eval(inputRow)
if (value == null) {
null
} else {
nullSafeEval(inputRow, value)
}
}
}

trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)
}

@transient lazy val functionForEval: Expression = functionsForEval.head
trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)

def keyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match {
case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
case _ =>
val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
(kType, vType, vContainsNull)
}
}

object ArrayBasedHigherOrderFunction {
Expand Down Expand Up @@ -179,7 +211,7 @@ object ArrayBasedHigherOrderFunction {
case class ArrayTransform(
input: Expression,
function: Expression)
extends ArrayBasedHigherOrderFunction with CodegenFallback {
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

Expand All @@ -205,29 +237,77 @@ case class ArrayTransform(
(elementVar, indexVar)
}

override def eval(input: InternalRow): Any = {
val arr = this.input.eval(input).asInstanceOf[ArrayData]
if (arr == null) {
null
} else {
val f = functionForEval
val result = new GenericArrayData(new Array[Any](arr.numElements))
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (indexVar.isDefined) {
indexVar.get.value.set(i)
}
result.update(i, f.eval(input))
i += 1
override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = {
val arr = inputValue.asInstanceOf[ArrayData]
val f = functionForEval
val result = new GenericArrayData(new Array[Any](arr.numElements))
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (indexVar.isDefined) {
indexVar.get.value.set(i)
}
result
result.update(i, f.eval(inputRow))
i += 1
}
result
}

override def prettyName: String = "transform"
}

/**
* Filters entries in a map using the provided function.
*/
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v);
[1 -> 0, 3 -> -1]
""",
since = "2.4.0")
case class MapFilter(
input: Expression,
function: Expression)
extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

@transient lazy val (keyVar, valueVar) = {
val args = function.asInstanceOf[LambdaFunction].arguments
(args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable])
}

@transient val (keyType, valueType, valueContainsNull) = keyValueArgumentType(input.dataType)

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
}

override def nullable: Boolean = input.nullable

override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
val m = value.asInstanceOf[MapData]
val f = functionForEval
val retKeys = new mutable.ListBuffer[Any]
val retValues = new mutable.ListBuffer[Any]
m.foreach(keyType, valueType, (k, v) => {
keyVar.value.set(k)
valueVar.value.set(v)
if (f.eval(inputRow).asInstanceOf[Boolean]) {
retKeys += k
retValues += v
}
})
ArrayBasedMapData(retKeys.toArray, retValues.toArray)
}

override def dataType: DataType = input.dataType

override def expectingFunctionType: AbstractDataType = BooleanType

override def prettyName: String = "map_filter"
}

/**
* Filters the input array using the given lambda function.
*/
Expand All @@ -242,7 +322,7 @@ case class ArrayTransform(
case class ArrayFilter(
input: Expression,
function: Expression)
extends ArrayBasedHigherOrderFunction with CodegenFallback {
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

Expand All @@ -257,23 +337,19 @@ case class ArrayFilter(

@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function

override def eval(input: InternalRow): Any = {
val arr = this.input.eval(input).asInstanceOf[ArrayData]
if (arr == null) {
null
} else {
val f = functionForEval
val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (f.eval(input).asInstanceOf[Boolean]) {
buffer += elementVar.value.get
}
i += 1
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
val arr = value.asInstanceOf[ArrayData]
val f = functionForEval
val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (f.eval(inputRow).asInstanceOf[Boolean]) {
buffer += elementVar.value.get
}
new GenericArrayData(buffer)
i += 1
}
new GenericArrayData(buffer)
}

override def prettyName: String = "filter"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,55 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq("[1, 3, 5]", null, "[4, 6]"))
}

test("MapFilter") {
def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val mt = expr.dataType.asInstanceOf[MapType]
MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f))
}
val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

val kGreaterThanV: (Expression, Expression) => Expression = (k, v) => k > v

checkEvaluation(mapFilter(mii0, kGreaterThanV), Map(1 -> 0, 3 -> -1))
checkEvaluation(mapFilter(mii1, kGreaterThanV), Map())
checkEvaluation(mapFilter(miin, kGreaterThanV), null)

val valueIsNull: (Expression, Expression) => Expression = (_, v) => v.isNull

checkEvaluation(mapFilter(mii0, valueIsNull), Map())
checkEvaluation(mapFilter(mii1, valueIsNull), Map(1 -> null, 3 -> null))
checkEvaluation(mapFilter(miin, valueIsNull), null)

val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0),
MapType(StringType, IntegerType, valueContainsNull = false))
val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null),
MapType(StringType, IntegerType, valueContainsNull = true))
val msin = Literal.create(null, MapType(StringType, IntegerType, valueContainsNull = false))

val isLengthOfKey: (Expression, Expression) => Expression = (k, v) => Length(k) === v

checkEvaluation(mapFilter(msi0, isLengthOfKey), Map("abcdf" -> 5, "" -> 0))
checkEvaluation(mapFilter(msi1, isLengthOfKey), Map("abcdf" -> 5))
checkEvaluation(mapFilter(msin, isLengthOfKey), null)

val mia0 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)),
MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false))
val mia1 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)),
MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true))
val mian = Literal.create(
null, MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false))

val customFunc: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3

checkEvaluation(mapFilter(mia0, customFunc), Map(1 -> Seq(0, 1, 2)))
checkEvaluation(mapFilter(mia1, customFunc), Map(1 -> Seq(0, 1, 2)))
checkEvaluation(mapFilter(mian, customFunc), null)
}

test("ArrayFilter") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1800,6 +1800,52 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
}

test("map_filter") {
val dfInts = Seq(
Map(1 -> 10, 2 -> 20, 3 -> 30),
Map(1 -> -1, 2 -> -2, 3 -> -3),
Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m")

checkAnswer(dfInts.selectExpr(
"map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"),
Seq(
Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()),
Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)),
Row(Map(1 -> 10), Map(3 -> -3))))

val dfComplex = Seq(
Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))),
Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m")

checkAnswer(dfComplex.selectExpr(
"map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"),
Seq(
Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))),
Row(Map(), Map(2 -> Seq(-2, -2)))))

// Invalid use cases
val df = Seq(
(Map(1 -> "a"), 1),
(Map.empty[Int, String], 2),
(null, 3)
).toDF("s", "i")

val ex1 = intercept[AnalysisException] {
df.selectExpr("map_filter(s, (x, y, z) -> x + y + z)")
}
assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match"))

val ex2 = intercept[AnalysisException] {
df.selectExpr("map_filter(s, x -> x)")
}
assert(ex2.getMessage.contains("The number of lambda function arguments '1' does not match"))

val ex3 = intercept[AnalysisException] {
df.selectExpr("map_filter(i, (k, v) -> k > v)")
}
assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type"))
}

test("filter function - array for primitive type not containing null") {
val df = Seq(
Seq(1, 9, 8, 7),
Expand Down