@@ -441,6 +441,7 @@ object FunctionRegistry {
expression[ArrayRemove]("array_remove"),
expression[ArrayDistinct]("array_distinct"),
expression[ArrayTransform]("transform"),
expression[MapFilter]("map_filter"),
CreateStruct.registryEntry,

// misc functions
@@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.concurrent.atomic.AtomicReference

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

/**
@@ -123,7 +125,10 @@ trait HigherOrderFunction extends Expression {
}
}

trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
/**
* Trait for functions having as input one argument and one function.
*/
trait UnaryHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
Member:
I like this trait but I'm not sure whether we can say "Unary"HigherOrderFunction for this.

Member:
Btw, how about defining nullSafeEval for input in this trait like UnaryExpression? (nullInputSafeEval?)
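
A minimal sketch of what such a helper might look like inside this trait (the name nullSafeInputEval is hypothetical and the final signature is up to the PR; it just mirrors how UnaryExpression short-circuits on a null child):

  // Hypothetical helper: evaluate `input` once and only call `f` when the
  // result is non-null, analogous to UnaryExpression's null handling.
  protected def nullSafeInputEval(row: InternalRow)(f: Any => Any): Any = {
    val value = input.eval(row)
    if (value == null) null else f(value)
  }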

Contributor Author:
I called it Unary as it takes one input and one function. Honestly, I can't think of a better name without becoming very verbose; if you have a better suggestion, I am happy to follow it. I will add the nullSafeEval, thanks!

Member:
cc @hvanhovell for the naming?

Contributor:
We use the term Unary a lot and this is different from the other uses. The name should convey a HigherOrderFunction that only uses a single (lambda) function right? The only thing I can come up with is SingleHigherOrderFunction. Simple would probably also be fine.


def input: Expression

@@ -135,9 +140,15 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {

def expectingFunctionType: AbstractDataType = AnyDataType

@transient lazy val functionForEval: Expression = functionsForEval.head
}

trait ArrayBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction {
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)
}

trait MapBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)
}

/**
@@ -157,7 +168,7 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
case class ArrayTransform(
input: Expression,
function: Expression)
extends ArrayBasedHigherOrderFunction with CodegenFallback {
extends ArrayBasedUnaryHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

@@ -210,3 +221,66 @@ case class ArrayTransform(

override def prettyName: String = "transform"
}

/**
* Filters entries in a map using the provided function.
*/
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v);
[1 -> 0, 3 -> -1]
""",
since = "2.4.0")
case class MapFilter(
input: Expression,
function: Expression)
extends MapBasedUnaryHigherOrderFunction with CodegenFallback {

@transient val (keyType, valueType, valueContainsNull) = input.dataType match {
Member:
Maybe this should be a function in object MapBasedUnaryHigherOrderFunction, so we can use it in other map-based higher-order functions, just like ArrayBasedHigherOrderFunction.elementArgumentType.

case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
case _ =>
val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
(kType, vType, vContainsNull)
}
Member:
How about extracting this to an object MapBasedUnaryHigherOrderFunction, like the array-based one? We'll need this in other map-based ones.

Member:
Sorry, I meant something like:

object MapBasedUnaryHigherOrderFunction {

  def keyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = {
    dt match {
      case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
      case _ =>
        val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
        (kType, vType, vContainsNull)
    }
  }
}

...

case class MapFilter( ... ) {
  ...
  @transient val (keyType, valueType, valueContainsNull) =
    MapBasedUnaryHigherOrderFunction.keyValueArgumentType(input.dataType)
  ...
}

Member:
Hmm, is there something wrong with introducing an object to hold util methods?

Member:
How about the following (sketched below):

  1. rename ArrayBasedHigherOrderFunction object to HigherOrderFunction
  2. rename elementArgumentType method to arrayElementArgumentType
  3. move keyValueArgumentType to HigherOrderFunction object and rename to mapKeyValueArgumentType
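
A rough sketch of what that reorganization might look like (the argument and return shapes are assumed from the snippets above; the final code in the PR may differ):

object HigherOrderFunction {

  // Element type and containsNull of an array argument, falling back to the
  // default concrete ArrayType when the input is still unresolved.
  def arrayElementArgumentType(dt: DataType): (DataType, Boolean) = dt match {
    case ArrayType(elementType, containsNull) => (elementType, containsNull)
    case _ =>
      val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
      (elementType, containsNull)
  }

  // Key type, value type and valueContainsNull of a map argument, with the
  // same fallback for unresolved inputs.
  def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match {
    case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
    case _ =>
      val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
      (kType, vType, vContainsNull)
  }
}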

Contributor Author:
Oh, sorry, I hadn't read your comment carefully; now I see what you meant. Yes, I agree with unifying them in a helper object. I am updating accordingly. Thanks.


@transient lazy val (keyVar, valueVar) = {
val args = function.asInstanceOf[LambdaFunction].arguments
(args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable])
}

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
function match {
case LambdaFunction(_, _, _) =>
Contributor:
Is this pattern matching necessary? If so, shouldn't ArrayFilter use it as well?

Contributor Author:
Right, I am removing it, thanks.

copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
}
}

override def nullable: Boolean = input.nullable

override def eval(input: InternalRow): Any = {
val m = this.input.eval(input).asInstanceOf[MapData]
if (m == null) {
null
} else {
val retKeys = new mutable.ListBuffer[Any]
val retValues = new mutable.ListBuffer[Any]
Member:
I'm just curious: is ListBuffer better than ArrayBuffer here? If so, should we rewrite ArrayFilter as well?

Contributor Author:
I think it is better, as here we are always appending (and then creating an array from the result). Appending a value is always O(1) for ListBuffer, while for ArrayBuffer it is O(1) only while the underlying allocated array is longer than the current number of elements plus one, and O(n) otherwise (since it has to allocate a new array and copy the old one). As the initial length of the underlying array in ArrayBuffer is 16, this means that for outputs with more than 16 elements ListBuffer saves at least one copy.

Contributor Author:
But I just checked that in ArrayFilter you initialize it with the number of incoming elements. So I think there is no difference in performance: using an upper bound on the number of output elements guarantees no copy is performed.
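
For reference, a hedged sketch of that pre-sized alternative as it would read in this eval, assuming m.numElements() is used as the upper bound (as described for ArrayFilter above):

  // Pre-sizing with an upper bound means appending never triggers a growth
  // copy, so performance matches the ListBuffer approach.
  val retKeys = new mutable.ArrayBuffer[Any](m.numElements())
  val retValues = new mutable.ArrayBuffer[Any](m.numElements())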

m.foreach(keyType, valueType, (k, v) => {
keyVar.value.set(k)
valueVar.value.set(v)
if (functionForEval.eval(input).asInstanceOf[Boolean]) {
retKeys += k
retValues += v
}
})
ArrayBasedMapData(retKeys.toArray, retValues.toArray)
}
}

override def dataType: DataType = input.dataType

override def expectingFunctionType: AbstractDataType = BooleanType

override def prettyName: String = "map_filter"
}
@@ -94,4 +94,53 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)),
Seq("[1, 3, 5]", null, "[4, 6]"))
}

test("MapFilter") {
def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val mt = expr.dataType.asInstanceOf[MapType]
MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f))
}
val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

val kGreaterThanV: (Expression, Expression) => Expression = (k, v) => k > v

checkEvaluation(mapFilter(mii0, kGreaterThanV), Map(1 -> 0, 3 -> -1))
checkEvaluation(mapFilter(mii1, kGreaterThanV), Map())
checkEvaluation(mapFilter(miin, kGreaterThanV), null)

val valueNull: (Expression, Expression) => Expression = (_, v) => v.isNull
Member:
nit: valueIsNull?


checkEvaluation(mapFilter(mii0, valueNull), Map())
checkEvaluation(mapFilter(mii1, valueNull), Map(1 -> null, 3 -> null))
checkEvaluation(mapFilter(miin, valueNull), null)

val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0),
MapType(StringType, IntegerType, valueContainsNull = false))
val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null),
MapType(StringType, IntegerType, valueContainsNull = true))
val msin = Literal.create(null, MapType(StringType, IntegerType, valueContainsNull = false))

val isLengthOfKey: (Expression, Expression) => Expression = (k, v) => Length(k) === v

checkEvaluation(mapFilter(msi0, isLengthOfKey), Map("abcdf" -> 5, "" -> 0))
checkEvaluation(mapFilter(msi1, isLengthOfKey), Map("abcdf" -> 5))
checkEvaluation(mapFilter(msin, isLengthOfKey), null)

val mia0 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)),
MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false))
val mia1 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)),
MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true))
val mian = Literal.create(
null, MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false))

val customFunc: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3

checkEvaluation(mapFilter(mia0, customFunc), Map(1 -> Seq(0, 1, 2)))
checkEvaluation(mapFilter(mia1, customFunc), Map(1 -> Seq(0, 1, 2)))
checkEvaluation(mapFilter(mian, customFunc), null)
}
}
@@ -1800,6 +1800,52 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
}

test("map_filter") {
val dfInts = Seq(
Map(1 -> 10, 2 -> 20, 3 -> 30),
Map(1 -> -1, 2 -> -2, 3 -> -3),
Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m")

checkAnswer(dfInts.selectExpr(
"map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"),
Seq(
Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()),
Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)),
Row(Map(1 -> 10), Map(3 -> -3))))

val dfComplex = Seq(
Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))),
Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m")

checkAnswer(dfComplex.selectExpr(
"map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"),
Seq(
Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))),
Row(Map(), Map(2 -> Seq(-2, -2)))))

// Invalid use cases
val df = Seq(
(Map(1 -> "a"), 1),
(Map.empty[Int, String], 2),
(null, 3)
).toDF("s", "i")

val ex1 = intercept[AnalysisException] {
df.selectExpr("map_filter(s, (x, y, z) -> x + y + z)")
}
assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match"))

val ex2 = intercept[AnalysisException] {
df.selectExpr("map_filter(s, x -> x)")
}
assert(ex2.getMessage.contains("The number of lambda function arguments '1' does not match"))

val ex3 = intercept[AnalysisException] {
df.selectExpr("map_filter(i, (k, v) -> k > v)")
}
assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type"))
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {