@@ -45,7 +45,8 @@ object PythonUDF {
}

/**
* A serialized version of a Python lambda function.
* A serialized version of a Python lambda function. This is a special expression, which needs a
* dedicated physical operator to execute it, and thus can't be pushed down to data sources.
*/
case class PythonUDF(
name: String,
@@ -1143,6 +1143,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
case _: Repartition => true
case _: ScriptTransformation => true
case _: Sort => true
case _: BatchEvalPython => true
Contributor:
For my benefit, would you mind explaining what canPushThrough defines? Are these the nodes that a projection and/or filter can be pushed through?

Contributor Author:
This defines the nodes that we can push filters through.
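To make that concrete, here is a minimal, self-contained sketch in plain Scala (toy names, not the real Catalyst classes): canPushThrough acts as an allow-list of node types that a Filter may move below, and this PR adds the Python-eval nodes to that list. In the real rule the pushed predicate must additionally be deterministic and reference only columns the child produces.

object CanPushThroughDemo extends App {
  sealed trait Node
  case class BatchEval(childColumns: Set[String]) extends Node // stands in for BatchEvalPython
  case object Opaque extends Node                              // a node filters must not cross

  // Rough analogue of the allow-list: only certain node types let a Filter pass below them.
  def canPushThrough(n: Node): Boolean = n match {
    case _: BatchEval => true
    case _            => false
  }

  def filterPlacement(n: Node): String =
    if (canPushThrough(n)) "pushed below the node, closer to the scan"
    else "stays above the node"

  println(filterPlacement(BatchEval(Set("a", "b")))) // pushed below the node, closer to the scan
  println(filterPlacement(Opaque))                   // stays above the node
}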

case _: ArrowEvalPython => true
case _ => false
}

@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}

/**
* FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame.
@@ -38,3 +38,30 @@ case class FlatMapGroupsInPandas(
*/
override val producedAttributes = AttributeSet(output)
}

trait BaseEvalPython extends UnaryNode {
Member:
Is producedAttributes missing from this? Previously, BatchEvalPython and ArrowEvalPython had it defined.

Contributor Author (@cloud-fan, May 23, 2019):
This is a problem I want to address later. I think producedAttributes makes no sense. It's only used to define missingInput, but we can override references to do the same thing.

More specifically, if references is wrongly implemented, column pruning will be broken. If producedAttributes is not implemented, nothing serious will happen.
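As a rough illustration of the relationship described here (plain Sets rather than AttributeSet, paraphrasing rather than quoting the QueryPlan source): missingInput is roughly the node's references minus its children's output minus producedAttributes, so producedAttributes only feeds that consistency check, while references is what the pruning and pushdown rules actually consult.

object MissingInputDemo extends App {
  val references         = Set("a", "udf_result") // what the node's expressions read
  val childOutput        = Set("a", "b")          // what the child node provides
  val producedAttributes = Set("udf_result")      // what the node itself generates

  // Roughly: missingInput = references -- inputSet -- producedAttributes
  val missingInput = references -- childOutput -- producedAttributes
  println(missingInput) // Set() -> nothing is unaccounted for, so the plan looks consistent
}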


def udfs: Seq[PythonUDF]

def resultAttrs: Seq[Attribute]

override def output: Seq[Attribute] = child.output ++ resultAttrs
Contributor Author:
To work with the ColumnPruning rule, the python-eval node should be able to dynamically update its output when the child's output is updated.
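A minimal, self-contained sketch (toy classes, not Catalyst) of why deriving output from child.output matters: when ColumnPruning swaps in a pruned child, the eval node's output shrinks automatically instead of keeping a stale, stored attribute list.

object DerivedOutputDemo extends App {
  trait Node { def output: Seq[String] }
  case class Scan(output: Seq[String]) extends Node
  case class EvalPython(child: Node, resultAttrs: Seq[String]) extends Node {
    def output: Seq[String] = child.output ++ resultAttrs // derived from the child, never stored
  }

  val before = EvalPython(Scan(Seq("a", "b", "c")), Seq("udf_result"))
  println(before.output) // List(a, b, c, udf_result)

  // A ColumnPruning-style rewrite replaces the child with a pruned scan;
  // the eval node's output updates for free.
  val after = before.copy(child = Scan(Seq("a")))
  println(after.output) // List(a, udf_result)
}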


override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
Contributor Author (@cloud-fan, May 22, 2019):
To work with the ColumnPruning and PushDownPredicate rules, we must correctly implement the references method; resultAttrs are definitely not references.

Member:
If references only covers the references in udfs, will some output attributes from the child that aren't referred to by the udfs be pruned from BaseEvalPython?

Contributor Author:
Yea, and this is "column pruning".
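A small, self-contained illustration of that pruning decision (plain collections, not the real ColumnPruning rule): a child column survives only if the eval node's udfs or the operators above the node still reference it.

object ColumnPruningDemo extends App {
  val childOutput      = Seq("a", "b", "c")     // columns produced by the scan
  val udfReferences    = Set("a")               // what the Python UDFs read
  val parentReferences = Set("a", "udf_result") // what the operators above still need

  // Columns worth keeping below the eval node: referenced by the UDFs or by the parent.
  val keep = childOutput.filter(c => udfReferences(c) || parentReferences(c))
  println(keep) // List(a) -> b and c are pruned away and never read from the files
}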

}

/**
* A logical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPython(
udfs: Seq[PythonUDF],
resultAttrs: Seq[Attribute],
child: LogicalPlan) extends BaseEvalPython

/**
* A logical plan that evaluates a [[PythonUDF]] with Apache Arrow.
*/
case class ArrowEvalPython(
udfs: Seq[PythonUDF],
resultAttrs: Seq[Attribute],
child: LogicalPlan) extends BaseEvalPython
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.optimizer.{ColumnPruning, Optimizer, PushDownPredicate, RemoveNoopOperators}
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.SchemaPruning
import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
@@ -32,14 +32,21 @@ class SparkOptimizer(
override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
Batch("Extract Python UDFs", Once,
Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+
ExtractPythonUDFFromAggregate,
ExtractPythonUDFs,
// The eval-python node may be between Project/Filter and the scan node, which breaks
// column pruning and filter push-down. Here we rerun the related optimizer rules.
ColumnPruning,
PushDownPredicate,
RemoveNoopOperators) :+
Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
Batch("Schema Pruning", Once, SchemaPruning)) ++
postHocOptimizationBatches :+
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)

override def nonExcludableRules: Seq[String] =
super.nonExcludableRules :+ ExtractPythonUDFFromAggregate.ruleName
override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
ExtractPythonUDFFromAggregate.ruleName :+
ExtractPythonUDFs.ruleName
Member:
Is ExtractPythonUDFs newly added to nonExcludableRules? Is it also part of the fix, or should it just be there anyway?

Contributor Author:
it should be there. We can do it in another PR, but since I'm touching this file, I just fixed it.

Member:
Looks good. Just out of curiosity.
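For context on this exchange: optimizer rules can be disabled by name through spark.sql.optimizer.excludedRules, and anything in nonExcludableRules is applied regardless. Excluding ExtractPythonUDFs would leave PythonUDF expressions with no physical operator to run them, hence the addition. A hedged usage sketch follows; the excluded rule name is just an example of a rule that is normally excludable.

import org.apache.spark.sql.SparkSession

object ExcludedRulesDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("excluded-rules-demo").getOrCreate()

  // The rule name below is only an example; any excludable rule can be listed here.
  spark.conf.set(
    "spark.sql.optimizer.excludedRules",
    "org.apache.spark.sql.catalyst.optimizer.ConstantFolding")

  // Listing ExtractPythonUDFs here would have no effect: rules in nonExcludableRules are
  // applied anyway (Spark logs a warning instead of dropping them).
  spark.stop()
}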


/**
* Optimization batches that are executed before the regular optimization batches (also before
@@ -23,7 +23,6 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types.StructType
@@ -57,21 +56,11 @@ private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int)
}
}

/**
* A logical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPython(
udfs: Seq[PythonUDF],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
}

/**
* A physical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec(udfs, output, child) {
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec(udfs, resultAttrs, child) {

private val batchSize = conf.arrowMaxRecordsPerBatch
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
@@ -25,25 +25,14 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}

/**
* A logical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPython(
Member:
+1 for moving out

Contributor Author:
I had to move them out because I need to access them in PushDownPredicate, which is in the catalyst module.

udfs: Seq[PythonUDF],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
}

/**
* A physical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec(udfs, output, child) {
case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec(udfs, resultAttrs, child) {

protected override def evaluate(
funcs: Seq[ChainedPythonFunctions],
@@ -57,10 +57,12 @@ import org.apache.spark.util.Utils
* there should be always some rows buffered in the socket or Python process, so the pulling from
* RowQueue ALWAYS happened after pushing into it.
*/
abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
abstract class EvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
extends UnaryExecNode {

override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
override def output: Seq[Attribute] = child.output ++ resultAttrs

override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))

private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
@@ -158,21 +158,9 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
// If there aren't any, we are done.
plan
} else {
val inputsForPlan = plan.references ++ plan.outputSet
val prunedChildren = plan.children.map { child =>
val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq
if (allNeededOutput.length != child.output.length) {
Project(allNeededOutput, child)
} else {
child
}
}
val planWithNewChildren = plan.withNewChildren(prunedChildren)

val attributeMap = mutable.HashMap[PythonUDF, Expression]()
val splitFilter = trySplitFilter(planWithNewChildren)
// Rewrite the child that has the input required for the UDF
val newChildren = splitFilter.children.map { child =>
val newChildren = plan.children.map { child =>
// Pick the UDF we are going to evaluate
val validUdfs = udfs.filter { udf =>
// Check to make sure that the UDF can be evaluated with only the input of this child.
@@ -191,9 +179,9 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
_.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
) match {
case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child)
ArrowEvalPython(vectorizedUdfs, resultAttrs, child)
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child)
BatchEvalPython(plainUdfs, resultAttrs, child)
case _ =>
throw new AnalysisException(
"Expected either Scalar Pandas UDFs or Batched UDFs but got both")
@@ -211,7 +199,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
}

val rewritten = splitFilter.withNewChildren(newChildren).transformExpressions {
val rewritten = plan.withNewChildren(newChildren).transformExpressions {
case p: PythonUDF if attributeMap.contains(p) =>
attributeMap(p)
}
@@ -226,22 +214,4 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
}
}
}

// Split the original FilterExec to two FilterExecs. Only push down the first few predicates
// that are all deterministic.
private def trySplitFilter(plan: LogicalPlan): LogicalPlan = {
Contributor:
Can you explain a little why this is no longer needed?

Contributor Author:
Quoting from the PR description:

There are some hacks in the ExtractPythonUDFs rule, to duplicate the column pruning and filter pushdown logic. However, it has some bugs as demonstrated in the new test case(only column pruning is broken). This PR removes the hacks and re-apply the column pruning and filter pushdown rules explicitly.
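To make the quoted point concrete, a minimal sketch in plain Scala (toy data, not the real PushDownPredicate code) of the split that the removed trySplitFilter below used to do by hand and that the generic rule now performs: partition the Filter's conjuncts, push the deterministic ones that only touch the child's columns below the eval node, and keep the rest above it.

object SplitFilterDemo extends App {
  case class Conjunct(text: String, references: Set[String], deterministic: Boolean = true)

  val childColumns = Set("a", "b") // produced by the scan under BatchEvalPython
  val conjuncts = Seq(
    Conjunct("a > 1", Set("a")),
    Conjunct("udf_result", Set("udf_result"))) // the extracted Python UDF's boolean output

  val (pushDown, stayUp) = conjuncts.partition(c =>
    c.deterministic && c.references.subsetOf(childColumns))

  println(pushDown.map(_.text)) // List(a > 1)      -> becomes a Filter below BatchEvalPython
  println(stayUp.map(_.text))   // List(udf_result) -> stays in a Filter above BatchEvalPython
}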

plan match {
case filter: Filter =>
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
if (pushDown.nonEmpty) {
val newChild = Filter(pushDown.reduceLeft(And), filter.child)
Filter((rest ++ nonDeterministic).reduceLeft(And), newChild)
} else {
filter
}
case o => o
}
}
}
@@ -17,13 +17,12 @@

package org.apache.spark.sql.execution.python

import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSQLContext

class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
import testImplicits.newProductEncoder
import testImplicits.localSeqToDatasetHolder
import testImplicits._

val batchedPythonUDF = new MyDummyPythonUDF
val scalarPandasUDF = new MyDummyScalarPandasUDF
@@ -88,5 +87,40 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
assert(pythonEvalNodes.size == 2)
assert(arrowEvalNodes.size == 2)
}

test("Python UDF should not break column pruning/filter pushdown") {
withTempPath { f =>
spark.range(10).select($"id".as("a"), $"id".as("b"))
.write.parquet(f.getCanonicalPath)
val df = spark.read.parquet(f.getCanonicalPath)

withClue("column pruning") {
val query = df.filter(batchedPythonUDF($"a")).select($"a")

val pythonEvalNodes = collectBatchExec(query.queryExecution.executedPlan)
assert(pythonEvalNodes.length == 1)

val scanNodes = query.queryExecution.executedPlan.collect {
case scan: FileSourceScanExec => scan
}
assert(scanNodes.length == 1)
assert(scanNodes.head.output.map(_.name) == Seq("a"))
}

withClue("filter pushdown") {
val query = df.filter($"a" > 1 && batchedPythonUDF($"a"))
val pythonEvalNodes = collectBatchExec(query.queryExecution.executedPlan)
assert(pythonEvalNodes.length == 1)

val scanNodes = query.queryExecution.executedPlan.collect {
case scan: FileSourceScanExec => scan
}
assert(scanNodes.length == 1)
// 'a is not null and 'a > 1
assert(scanNodes.head.dataFilters.length == 2)
assert(scanNodes.head.dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a"))
}
}
}
}