9 changes: 9 additions & 0 deletions python/pyspark/sql/tests.py
@@ -343,6 +343,15 @@ def test_broadcast_in_udf(self):
         [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
         self.assertEqual("", res[0])
 
+    def test_udf_with_aggregate_function(self):
+        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        from pyspark.sql.functions import udf, col
+        from pyspark.sql.types import BooleanType
+
+        my_filter = udf(lambda a: a == 1, BooleanType())
+        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
+        self.assertEqual(sel.collect(), [Row(key=1)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.sqlCtx.read.json(rdd)
sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -74,6 +74,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+    python.ExtractPythonUDFs,
     PlanSubqueries(sqlContext),
     EnsureRequirements(sqlContext.conf),
     CollapseCodegenStages(sqlContext.conf),
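For context: `preparations` is the list of `Rule[SparkPlan]` rewrites that `QueryExecution` applies, in order, to the physical plan before execution; registering `python.ExtractPythonUDFs` here is what moves UDF extraction out of the analyzer. A minimal sketch of what such a rule looks like (hypothetical rule name, against the Spark 2.0-era catalyst API):

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical preparation rule, shown only to illustrate the hook:
// a Rule[SparkPlan] is just a SparkPlan => SparkPlan rewrite.
object NoopPreparation extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
    case p => p // match and rewrite physical operators here
  }
}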
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -392,8 +392,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
-      case e @ python.EvaluatePython(udfs, child, _) =>
-        python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
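With extraction running on the physical plan, `EvaluatePython` no longer exists as a logical node (see its removal below), so the planner needs no strategy case for it; `BatchPythonEvaluation` is now created directly by `ExtractPythonUDFs`. For readers unfamiliar with planner strategies, a minimal sketch (hypothetical strategy, assuming the Spark 2.0-era `org.apache.spark.sql.Strategy` alias used for `experimental.extraStrategies`):

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical strategy: maps a logical operator to candidate physical
// plans, returning Nil when it does not apply.
object NoopStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
}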
sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -35,30 +35,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-/**
- * Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple.
- */
-case class EvaluatePython(
-    udfs: Seq[PythonUDF],
-    child: LogicalPlan,
-    resultAttribute: Seq[AttributeReference])
-  extends logical.UnaryNode {
-
-  def output: Seq[Attribute] = child.output ++ resultAttribute
-
-  // References should not include the produced attribute.
-  override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
-}
-
-
 object EvaluatePython {
-  def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
-    val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
-      AttributeReference(s"pythonUDF$i", u.dataType)()
-    }
-    new EvaluatePython(udfs, child, resultAttrs)
-  }
-
   def takeAndServe(df: DataFrame, n: Int): Int = {
     registerPicklers()
     df.withNewExecutionId {
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.python
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.SparkPlan
 
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * This has the limitation that the input to the Python UDF is not allowed to include attributes
  * from multiple child operators.
  */
-private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
 
   private def hasPythonUDF(e: Expression): Boolean = {
     e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -54,49 +54,61 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
     case e => e.children.flatMap(collectEvaluatableUDF)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    // Skip EvaluatePython nodes.
-    case plan: EvaluatePython => plan
+  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
+    case plan: SparkPlan => extract(plan)
+  }
 
-    case plan: LogicalPlan if plan.resolved =>
-      // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved)
-      if (udfs.isEmpty) {
-        // If there aren't any, we are done.
-        plan
-      } else {
-        val attributeMap = mutable.HashMap[PythonUDF, Expression]()
-        // Rewrite the child that has the input required for the UDF
-        val newChildren = plan.children.map { child =>
-          // Pick the UDF we are going to evaluate
-          val validUdfs = udfs.filter { case udf =>
-            // Check to make sure that the UDF can be evaluated with only the input of this child.
-            udf.references.subsetOf(child.outputSet)
-          }
-          if (validUdfs.nonEmpty) {
-            val evaluation = EvaluatePython(validUdfs, child)
-            attributeMap ++= validUdfs.zip(evaluation.resultAttribute)
-            evaluation
-          } else {
-            child
-          }
-        }
-        // Other cases are disallowed as they are ambiguous or would require a cartesian
-        // product.
-        udfs.filterNot(attributeMap.contains).foreach { udf =>
-          if (udf.references.subsetOf(plan.inputSet)) {
-            sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-          } else {
-            sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-          }
-        }
-
-        // Trim away the new UDF value if it was only used for filtering or something.
-        logical.Project(
-          plan.output,
-          plan.transformExpressions {
-            case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
-          }.withNewChildren(newChildren))
-      }
-  }
+  /**
+   * Extract all the PythonUDFs from the current operator.
+   */
+  def extract(plan: SparkPlan): SparkPlan = {
+    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+    if (udfs.isEmpty) {
+      // If there aren't any, we are done.
+      plan
+    } else {
+      val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+      // Rewrite the child that has the input required for the UDF
+      val newChildren = plan.children.map { child =>
+        // Pick the UDF we are going to evaluate
+        val validUdfs = udfs.filter { case udf =>
+          // Check to make sure that the UDF can be evaluated with only the input of this child.
+          udf.references.subsetOf(child.outputSet)
+        }
+        if (validUdfs.nonEmpty) {
+          val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+            AttributeReference(s"pythonUDF$i", u.dataType)()
+          }
+          val evaluation = BatchPythonEvaluation(validUdfs, child.output ++ resultAttrs, child)
+          attributeMap ++= validUdfs.zip(resultAttrs)
+          evaluation
+        } else {
+          child
+        }
+      }
+      // Other cases are disallowed as they are ambiguous or would require a cartesian
+      // product.
+      udfs.filterNot(attributeMap.contains).foreach { udf =>
+        if (udf.references.subsetOf(plan.inputSet)) {
+          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+        } else {
+          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
+        }
+      }
+
+      val rewritten = plan.transformExpressions {
+        case p: PythonUDF if attributeMap.contains(p) =>
+          attributeMap(p)
+      }.withNewChildren(newChildren)
+
+      // extract remaining python UDFs recursively
+      val newPlan = extract(rewritten)
+      if (newPlan.output != plan.output) {
+        // Trim away the new UDF value if it was only used for filtering or something.
+        execution.Project(plan.output, newPlan)
+      } else {
+        newPlan
+      }
+    }
+  }
 }
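To make the rewrite concrete, here is how `extract` transforms the plan for the query in the new PySpark test (a Python-UDF filter above a distinct). The shape follows the rule's logic above; it is an illustrative sketch with abbreviated node names and a hypothetical pyUDF, not verbatim planner output:

// Before extraction:                After extraction:
//   Filter(pyUDF(key))                Project(key)            <- trims pythonUDF0
//     +- Aggregate(distinct key)        +- Filter(pythonUDF0)
//                                          +- BatchPythonEvaluation(pyUDF(key),
//                                               output = [key, pythonUDF0])
//                                             +- Aggregate(distinct key)

Because the rewritten Filter's output ([key, pythonUDF0]) differs from the original output ([key]), the final branch of extract wraps it in an execution.Project that drops the temporary UDF result attribute.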
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
 import org.apache.spark.sql.types.DataType
 
@@ -30,7 +29,7 @@ case class PythonUDF(
     func: PythonFunction,
     dataType: DataType,
     children: Seq[Expression])
-  extends Expression with Unevaluable with NonSQLExpression with Logging {
+  extends Expression with Unevaluable with NonSQLExpression {
 
   override def toString: String = s"$name(${children.mkString(", ")})"
 
sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -64,7 +64,6 @@ private[sql] class SessionState(ctx: SQLContext) {
   lazy val analyzer: Analyzer = {
     new Analyzer(catalog, functionRegistry, conf) {
       override val extendedResolutionRules =
-        python.ExtractPythonUDFs ::
         PreInsertCastAndRename ::
         DataSourceAnalysis ::
         (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -60,7 +60,6 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
       catalog.OrcConversions ::
       catalog.CreateTables ::
       catalog.PreInsertionCasts ::
-      python.ExtractPythonUDFs ::
       PreInsertCastAndRename ::
       DataSourceAnalysis ::
       (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)