python/pyspark/sql/tests.py (9 additions & 0 deletions)

@@ -360,6 +360,15 @@ def test_broadcast_in_udf(self):
        [res] = self.spark.sql("SELECT MYUDF('')").collect()
        self.assertEqual("", res[0])

    def test_udf_with_filter_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql.functions import udf, col
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a < 2, BooleanType())
        sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
Contributor: does this test fail before this PR?

Member Author: Nope. This case works well.

        self.assertEqual(sel.collect(), [Row(key=1, value='1')])

    def test_udf_with_aggregate_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql.functions import udf, col, sum
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{FilterExec, SparkPlan}


/**
@@ -111,7 +111,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
  }

  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
    case plan: SparkPlan => extract(plan)
    case plan: SparkPlan =>
      val newPlan = extract(plan)
Member Author (@gatorsmile, Dec 7, 2016): extract is a recursive function. That is why I did not move the following logic into extract, for performance reasons.

      if (newPlan != plan) {
        // A BatchEvalPythonExec was built, so push FilterExec predicates
        // through BatchEvalPythonExec
        PushPredicateThroughBatchEvalPython.apply(newPlan)
      } else {
        plan
      }
  }

  /**
@@ -166,3 +174,40 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
    }
  }
}

// This rule is to push deterministic predicates through BatchEvalPythonExec
object PushPredicateThroughBatchEvalPython extends Rule[SparkPlan] with PredicateHelper {
Member Author: Most of this code comes from the optimizer rule PushDownPredicate. Not sure whether we should combine them; this rule is for SparkPlan.
Contributor: Having a predicate-pushdown rule for SparkPlan sounds bad. Can we try to do this in extract()? For example:

    val splittedFilter = trySplitFilter(plan)
    val newChildren = splittedFilter.children.map { child =>
    }
Member Author: Good idea! The new commit does it.
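A hedged sketch of what that split inside extract() could look like, building on the reviewer's trySplitFilter suggestion above. Everything here is illustrative: the helper's name comes from the suggestion, the PythonUDF check and the assumption that the enclosing rule mixes in PredicateHelper are the editor's, and this is not necessarily what the follow-up commit actually does.

// Illustrative only: split FilterExec(cond, child) into two filters so that the leading
// deterministic conjuncts without a PythonUDF sit directly above `child`, below the
// BatchEvalPythonExec that extract() will insert. Assumes PredicateHelper is mixed in.
private def trySplitFilter(plan: SparkPlan): SparkPlan = plan match {
  case filter: FilterExec =>
    // Keep only the deterministic prefix of the conjunction as push-down candidates.
    val (candidates, containingNonDeterministic) =
      splitConjunctivePredicates(filter.condition).span(_.deterministic)
    // Among those, push the conjuncts that do not reference a Python UDF.
    val (pushDown, rest) = candidates.partition(_.find(_.isInstanceOf[PythonUDF]).isEmpty)
    val stayUp = rest ++ containingNonDeterministic
    if (pushDown.nonEmpty) {
      val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
      if (stayUp.isEmpty) newChild else FilterExec(stayUp.reduceLeft(And), newChild)
    } else {
      filter
    }
  case other => other
}

With a shape like this, the deterministic non-UDF conjuncts would already sit in a FilterExec below the point where extract() inserts BatchEvalPythonExec, so a separate SparkPlan-level push-down rule would not be needed.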

  def apply(plan: SparkPlan): SparkPlan = plan transform {
    case filter @ FilterExec(_, child: BatchEvalPythonExec)
        if child.expressions.forall(_.deterministic) =>
      pushDownPredicate(filter, child.child) { predicate =>
        child.withNewChildren(Seq(FilterExec(predicate, child.child)))
      }
  }

  private def pushDownPredicate(
      filter: FilterExec,
      grandchild: SparkPlan)(insertFilter: Expression => SparkPlan): SparkPlan = {
    // Only push down the predicates that are deterministic and whose referenced attributes
    // all come from grandchild.
    val (candidates, containingNonDeterministic) =
      splitConjunctivePredicates(filter.condition).span(_.deterministic)
Member: nit: indentation?


    val (pushDown, rest) = candidates.partition { cond =>
      cond.references.subsetOf(grandchild.outputSet)
    }

    val stayUp = rest ++ containingNonDeterministic

    if (pushDown.nonEmpty) {
      val newChild = insertFilter(pushDown.reduceLeft(And))
      if (stayUp.nonEmpty) {
        FilterExec(stayUp.reduceLeft(And), newChild)
      } else {
        newChild
      }
    } else {
      filter
    }
  }
}
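To make the span/partition distinction in pushDownPredicate above concrete, here is a tiny self-contained illustration in plain Scala (the Pred case class and the example predicates are made up for this sketch, not Spark code). Only the leading run of deterministic conjuncts is eligible for push-down; a conjunct that appears after a non-deterministic one stays where it is, since moving it would change which rows the non-deterministic predicate is evaluated on.

// Standalone illustration of span(_.deterministic) vs. a plain partition.
object SpanVsPartitionExample extends App {
  // A stand-in for Catalyst expressions, just for this sketch.
  case class Pred(sql: String, deterministic: Boolean)

  val conjuncts = Seq(
    Pred("a > 1", deterministic = true),
    Pred("rand() > 0.5", deterministic = false),
    Pred("b < 3", deterministic = true))

  // span keeps only the deterministic *prefix* as push-down candidates:
  val (candidates, containingNonDeterministic) = conjuncts.span(_.deterministic)
  println(candidates)                  // List(Pred(a > 1,true))
  println(containingNonDeterministic)  // List(Pred(rand() > 0.5,false), Pred(b < 3,true))

  // A plain partition would also pull out "b < 3", reordering it in front of rand(),
  // which could change the rows the non-deterministic predicate sees.
  println(conjuncts.partition(_.deterministic)._1)  // List(Pred(a > 1,true), Pred(b < 3,true))
}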
sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala (new file)

@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, In}
import org.apache.spark.sql.execution.{FilterExec, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.BooleanType

class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.newProductEncoder
  import testImplicits.localSeqToDatasetHolder
Member: nit: indentation?


  override def beforeAll(): Unit = {
    super.beforeAll()
    spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)
  }

  override def afterAll(): Unit = {
    spark.sessionState.functionRegistry.dropFunction("dummyPythonUDF")
    super.afterAll()
  }

  test("Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(And(_: AttributeReference, _: AttributeReference), _) => f
      case b: BatchEvalPythonExec => b
      case f @ FilterExec(_: In, _) => f
Member Author: The physical plan has a few hidden nodes that are not shown in the Explain output, which is why I did not compare the result against an expected tree structure; see the sketch after this test.

    }
    assert(qualifiedPlanNodes.size == 3)
  }
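A hedged aside on the comment above: one way to surface those hidden nodes is to walk the executed plan directly instead of reading explain() output. The wrapper names below (WholeStageCodegenExec, InputAdapter) are assumptions about typical Spark 2.x plans, not something this PR asserts.

// Hedged sketch: print every concrete node class in the executed plan. Wrapper nodes such as
// WholeStageCodegenExec and InputAdapter (assumed names) appear in this traversal even though
// explain() output hides them, which is why the tests count collected nodes rather than
// comparing against a full expected tree.
df.queryExecution.executedPlan.foreach(node => println(node.getClass.getSimpleName))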

test("Nested Python UDF: push down deterministic FilterExec predicates") {
val df = Seq(("Hello", 4)).toDF("a", "b")
.where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
case f @ FilterExec(_: AttributeReference, _) => f
case b: BatchEvalPythonExec => b
case f @ FilterExec(_: In, _) => f
}
assert(qualifiedPlanNodes.size == 4)
}

test("Python UDF: no push down on non-deterministic FilterExec predicates") {
val df = Seq(("Hello", 4)).toDF("a", "b")
.where("dummyPythonUDF(a) and rand() > 3")
val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
case f: FilterExec => f
case b: BatchEvalPythonExec => b
}
assert(qualifiedPlanNodes.size == 2)
}

test("Python UDF refers to the attributes from more than one child") {
Member Author: This test case is not directly related to this PR. In the future, we need to add more Scala-side unit tests for BatchEvalPythonExec to improve test coverage.

val df = Seq(("Hello", 4)).toDF("a", "b")
val df2 = Seq(("Hello", 4)).toDF("c", "d")
val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")

val e = intercept[RuntimeException] {
joinDF.queryExecution.executedPlan
}.getMessage
assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child")
.forall(e.contains))
}
}

// This Python UDF is a dummy, just for testing; it cannot actually be executed.
class DummyUDF extends PythonFunction(
  command = Array[Byte](),
  envVars = Map("" -> "").asJava,
  pythonIncludes = ArrayBuffer("").asJava,
  pythonExec = "",
  pythonVer = "",
  broadcastVars = null,
  accumulator = null)

class MyDummyPythonUDF
  extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, dataType = BooleanType)