@@ -441,6 +441,7 @@ object FunctionRegistry {
expression[ArrayRemove]("array_remove"),
expression[ArrayDistinct]("array_distinct"),
expression[ArrayTransform]("transform"),
expression[MapFilter]("map_filter"),
expression[ArrayFilter]("filter"),
expression[ArrayAggregate]("aggregate"),
CreateStruct.registryEntry,
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

/**
@@ -133,7 +133,10 @@ trait HigherOrderFunction extends Expression {
}
}

-trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
+/**
+ * Trait for functions having as input one argument and one function.
+ */
+trait UnaryHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
Member: I like this trait but I'm not sure whether we can say "Unary"HigherOrderFunction for this.

Member: Btw, how about defining nullSafeEval for input in this trait like UnaryExpression? (nullInputSafeEval?)

Contributor Author: I called it Unary as it takes one input and one function. Honestly, I can't think of a better name without becoming very verbose; if you have a better suggestion, I am happy to follow it. I will add the nullSafeEval, thanks!

Member: cc @hvanhovell for the naming?

Contributor: We use the term Unary a lot, and this is different from the other uses. The name should convey a HigherOrderFunction that only uses a single (lambda) function, right? The only thing I can come up with is SingleHigherOrderFunction. Simple would probably also be fine.


def input: Expression

@@ -145,9 +148,31 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {

def expectingFunctionType: AbstractDataType = AnyDataType

@transient lazy val functionForEval: Expression = functionsForEval.head

/**
* Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
* in order to save null-check code.
*/
protected def nullSafeEval(inputRow: InternalRow, input: Any): Any =
sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval")

override def eval(inputRow: InternalRow): Any = {
val value = input.eval(inputRow)
if (value == null) {
null
} else {
nullSafeEval(inputRow, value)
}
}
}
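The eval/nullSafeEval pair just defined follows the UnaryExpression-style pattern requested in the thread above: the trait runs the null check on the input once, and concrete functions override only nullSafeEval. A stripped-down sketch of just that control flow, detached from Spark's Expression API (the trait and method names below are invented for illustration):

// Illustration only: the shape of the null-check delegation, outside Spark.
trait NullSafeEvalSketch {
  def evalInput(): Any                         // stands in for input.eval(inputRow)
  protected def nullSafeEval(value: Any): Any  // subclasses implement the non-null case
  final def eval(): Any = {
    val value = evalInput()
    if (value == null) null else nullSafeEval(value)  // single, shared null check
  }
}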

trait ArrayBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction {
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)
}

trait MapBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)
}

object ArrayBasedHigherOrderFunction {
@@ -179,7 +204,7 @@ object ArrayBasedHigherOrderFunction {
case class ArrayTransform(
input: Expression,
function: Expression)
-  extends ArrayBasedHigherOrderFunction with CodegenFallback {
+  extends ArrayBasedUnaryHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

@@ -205,29 +230,82 @@ case class ArrayTransform(
(elementVar, indexVar)
}

-  override def eval(input: InternalRow): Any = {
-    val arr = this.input.eval(input).asInstanceOf[ArrayData]
-    if (arr == null) {
-      null
-    } else {
-      val f = functionForEval
-      val result = new GenericArrayData(new Array[Any](arr.numElements))
-      var i = 0
-      while (i < arr.numElements) {
-        elementVar.value.set(arr.get(i, elementVar.dataType))
-        if (indexVar.isDefined) {
-          indexVar.get.value.set(i)
-        }
-        result.update(i, f.eval(input))
-        i += 1
+  override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = {
+    val arr = inputValue.asInstanceOf[ArrayData]
+    val f = functionForEval
+    val result = new GenericArrayData(new Array[Any](arr.numElements))
+    var i = 0
+    while (i < arr.numElements) {
+      elementVar.value.set(arr.get(i, elementVar.dataType))
+      if (indexVar.isDefined) {
+        indexVar.get.value.set(i)
+      }
       }
-      result
+      result.update(i, f.eval(inputRow))
+      i += 1
     }
+    result
   }

override def prettyName: String = "transform"
}

/**
* Filters entries in a map using the provided function.
*/
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v);
[1 -> 0, 3 -> -1]
""",
since = "2.4.0")
case class MapFilter(
input: Expression,
function: Expression)
extends MapBasedUnaryHigherOrderFunction with CodegenFallback {

@transient val (keyType, valueType, valueContainsNull) = input.dataType match {
Member: Maybe this should be a function in object MapBasedUnaryHigherOrderFunction, so we can use it in other map-based higher-order functions, just like ArrayBasedHigherOrderFunction.elementArgumentType.

case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
case _ =>
val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
(kType, vType, vContainsNull)
}
Member: How about extracting this to an object MapBasedUnaryHigherOrderFunction like the array-based one? We'll need this in other map-based ones.

Member: Sorry, I meant something like:

object MapBasedUnaryHigherOrderFunction {

  def keyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = {
    dt match {
      case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
      case _ =>
        val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
        (kType, vType, vContainsNull)
    }
  }
}

...

case class MapFilter( ... ) {
  ...
  @transient val (keyType, valueType, valueContainsNull) =
    MapBasedUnaryHigherOrderFunction.keyValueArgumentType(input.dataType)
  ...
}

Member: Hmm, is there something wrong with introducing an object to hold util methods?

Member: How about:

  1. rename ArrayBasedHigherOrderFunction object to HigherOrderFunction
  2. rename elementArgumentType method to arrayElementArgumentType
  3. move keyValueArgumentType to HigherOrderFunction object and rename to mapKeyValueArgumentType

Contributor Author: Oh, sorry, I hadn't read your comment carefully; now I see what you meant. Yes, I agree with unifying them in a helper object. I am updating accordingly. Thanks.

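Taken together, the three steps above would leave a single helper object shared by the array- and map-based traits. A rough sketch of where that lands (not part of this diff; the array method mirrors the existing ArrayBasedHigherOrderFunction.elementArgumentType and the map method mirrors the keyValueArgumentType suggestion, so treat the exact names and signatures as assumptions until the PR is updated):

object HigherOrderFunction {

  // Element type and nullability of the array argument, falling back to the default
  // concrete ArrayType when the input is still unresolved.
  def arrayElementArgumentType(dt: DataType): (DataType, Boolean) = dt match {
    case ArrayType(elementType, containsNull) => (elementType, containsNull)
    case _ =>
      val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
      (elementType, containsNull)
  }

  // Key type, value type, and value nullability of the map argument, with the same fallback.
  def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match {
    case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
    case _ =>
      val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
      (kType, vType, vContainsNull)
  }
}

MapFilter would then declare @transient val (keyType, valueType, valueContainsNull) = HigherOrderFunction.mapKeyValueArgumentType(input.dataType) instead of the inline match below.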

@transient lazy val (keyVar, valueVar) = {
val args = function.asInstanceOf[LambdaFunction].arguments
(args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable])
}

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
}

override def nullable: Boolean = input.nullable

override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
val m = value.asInstanceOf[MapData]
val f = functionForEval
val retKeys = new mutable.ListBuffer[Any]
val retValues = new mutable.ListBuffer[Any]
m.foreach(keyType, valueType, (k, v) => {
keyVar.value.set(k)
valueVar.value.set(v)
if (f.eval(inputRow).asInstanceOf[Boolean]) {
retKeys += k
retValues += v
}
})
ArrayBasedMapData(retKeys.toArray, retValues.toArray)
}

override def dataType: DataType = input.dataType

override def expectingFunctionType: AbstractDataType = BooleanType

override def prettyName: String = "map_filter"
}
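For context on how the new expression surfaces to users once registered in FunctionRegistry above, here is a quick usage sketch from the SQL side; it assumes a SparkSession named spark built from a branch containing this patch, and the expected result follows the ExpressionDescription example:

// Usage sketch only; `spark` is assumed to be a SparkSession with this patch applied.
val filtered = spark.sql(
  "SELECT map_filter(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v) AS m")
filtered.show(truncate = false)
// Expected: a single-column row whose map keeps only the entries where the key is
// greater than the value, i.e. 1 -> 0 and 3 -> -1.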

/**
* Filters the input array using the given lambda function.
*/
@@ -242,7 +320,7 @@ case class ArrayTransform(
case class ArrayFilter(
input: Expression,
function: Expression)
-  extends ArrayBasedHigherOrderFunction with CodegenFallback {
+  extends ArrayBasedUnaryHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

@@ -257,23 +335,19 @@ case class ArrayFilter(

@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function

-  override def eval(input: InternalRow): Any = {
-    val arr = this.input.eval(input).asInstanceOf[ArrayData]
-    if (arr == null) {
-      null
-    } else {
-      val f = functionForEval
-      val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
-      var i = 0
-      while (i < arr.numElements) {
-        elementVar.value.set(arr.get(i, elementVar.dataType))
-        if (f.eval(input).asInstanceOf[Boolean]) {
-          buffer += elementVar.value.get
-        }
-        i += 1
+  override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
+    val arr = value.asInstanceOf[ArrayData]
+    val f = functionForEval
+    val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
+    var i = 0
+    while (i < arr.numElements) {
+      elementVar.value.set(arr.get(i, elementVar.dataType))
+      if (f.eval(inputRow).asInstanceOf[Boolean]) {
+        buffer += elementVar.value.get
       }
-      new GenericArrayData(buffer)
+      i += 1
     }
+    new GenericArrayData(buffer)
   }

override def prettyName: String = "filter"
@@ -121,6 +121,55 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq("[1, 3, 5]", null, "[4, 6]"))
}

test("MapFilter") {
def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val mt = expr.dataType.asInstanceOf[MapType]
MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f))
}
val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

val kGreaterThanV: (Expression, Expression) => Expression = (k, v) => k > v

checkEvaluation(mapFilter(mii0, kGreaterThanV), Map(1 -> 0, 3 -> -1))
checkEvaluation(mapFilter(mii1, kGreaterThanV), Map())
checkEvaluation(mapFilter(miin, kGreaterThanV), null)

val valueIsNull: (Expression, Expression) => Expression = (_, v) => v.isNull

checkEvaluation(mapFilter(mii0, valueIsNull), Map())
checkEvaluation(mapFilter(mii1, valueIsNull), Map(1 -> null, 3 -> null))
checkEvaluation(mapFilter(miin, valueIsNull), null)

val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0),
MapType(StringType, IntegerType, valueContainsNull = false))
val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null),
MapType(StringType, IntegerType, valueContainsNull = true))
val msin = Literal.create(null, MapType(StringType, IntegerType, valueContainsNull = false))

val isLengthOfKey: (Expression, Expression) => Expression = (k, v) => Length(k) === v

checkEvaluation(mapFilter(msi0, isLengthOfKey), Map("abcdf" -> 5, "" -> 0))
checkEvaluation(mapFilter(msi1, isLengthOfKey), Map("abcdf" -> 5))
checkEvaluation(mapFilter(msin, isLengthOfKey), null)

val mia0 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)),
MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false))
val mia1 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)),
MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true))
val mian = Literal.create(
null, MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false))

val customFunc: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3

checkEvaluation(mapFilter(mia0, customFunc), Map(1 -> Seq(0, 1, 2)))
checkEvaluation(mapFilter(mia1, customFunc), Map(1 -> Seq(0, 1, 2)))
checkEvaluation(mapFilter(mian, customFunc), null)
}

test("ArrayFilter") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
@@ -1800,6 +1800,52 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
}

test("map_filter") {
val dfInts = Seq(
Map(1 -> 10, 2 -> 20, 3 -> 30),
Map(1 -> -1, 2 -> -2, 3 -> -3),
Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m")

checkAnswer(dfInts.selectExpr(
"map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"),
Seq(
Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()),
Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)),
Row(Map(1 -> 10), Map(3 -> -3))))

val dfComplex = Seq(
Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))),
Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m")

checkAnswer(dfComplex.selectExpr(
"map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"),
Seq(
Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))),
Row(Map(), Map(2 -> Seq(-2, -2)))))

// Invalid use cases
val df = Seq(
(Map(1 -> "a"), 1),
(Map.empty[Int, String], 2),
(null, 3)
).toDF("s", "i")

val ex1 = intercept[AnalysisException] {
df.selectExpr("map_filter(s, (x, y, z) -> x + y + z)")
}
assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match"))

val ex2 = intercept[AnalysisException] {
df.selectExpr("map_filter(s, x -> x)")
}
assert(ex2.getMessage.contains("The number of lambda function arguments '1' does not match"))

val ex3 = intercept[AnalysisException] {
df.selectExpr("map_filter(i, (k, v) -> k > v)")
}
assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type"))
}

test("filter function - array for primitive type not containing null") {
val df = Seq(
Seq(1, 9, 8, 7),