Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,16 @@
# If Arrow version requirement is not satisfied, skip related tests.
_pyarrow_requirement_message = _exception_message(e)

_test_not_compiled_message = None
try:
from pyspark.sql.utils import require_test_compiled
require_test_compiled()
except Exception as e:
_test_not_compiled_message = _exception_message(e)

_have_pandas = _pandas_requirement_message is None
_have_pyarrow = _pyarrow_requirement_message is None
_test_compiled = _test_not_compiled_message is None

from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
Expand Down Expand Up @@ -3367,6 +3375,47 @@ def test_ignore_column_of_all_nulls(self):
finally:
shutil.rmtree(path)

# SPARK-24721
@unittest.skipIf(not _test_compiled, _test_not_compiled_message)
def test_datasource_with_udf(self):
from pyspark.sql.functions import udf, lit, col

path = tempfile.mkdtemp()
shutil.rmtree(path)

try:
self.spark.range(1).write.mode("overwrite").format('csv').save(path)
filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
datasource_df = self.spark.read \
.format("org.apache.spark.sql.sources.SimpleScanSource") \
.option('from', 0).option('to', 1).load().toDF('i')
datasource_v2_df = self.spark.read \
.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
.load().toDF('i', 'j')

c1 = udf(lambda x: x + 1, 'int')(lit(1))
c2 = udf(lambda x: x + 1, 'int')(col('i'))

f1 = udf(lambda x: False, 'boolean')(lit(1))
f2 = udf(lambda x: False, 'boolean')(col('i'))

for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c1)
expected = df.withColumn('c', lit(2))
self.assertEquals(expected.collect(), result.collect())

for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c2)
expected = df.withColumn('c', col('i') + 1)
self.assertEquals(expected.collect(), result.collect())

for df in [filesource_df, datasource_df, datasource_v2_df]:
for f in [f1, f2]:
result = df.filter(f)
self.assertEquals(0, result.count())
finally:
shutil.rmtree(path)

def test_repr_behaviors(self):
import re
pattern = re.compile(r'^ *\|', re.MULTILINE)
Expand Down Expand Up @@ -5269,6 +5318,51 @@ def f3(x):

self.assertEquals(expected.collect(), df1.collect())

# SPARK-24721
@unittest.skipIf(not _test_compiled, _test_not_compiled_message)
def test_datasource_with_udf(self):
# Same as SQLTests.test_datasource_with_udf, but with Pandas UDF
# This needs to a separate test because Arrow dependency is optional
import pandas as pd
import numpy as np
from pyspark.sql.functions import pandas_udf, lit, col

path = tempfile.mkdtemp()
shutil.rmtree(path)

try:
self.spark.range(1).write.mode("overwrite").format('csv').save(path)
filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
datasource_df = self.spark.read \
.format("org.apache.spark.sql.sources.SimpleScanSource") \
.option('from', 0).option('to', 1).load().toDF('i')
datasource_v2_df = self.spark.read \
.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
.load().toDF('i', 'j')

c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1))
c2 = pandas_udf(lambda x: x + 1, 'int')(col('i'))

f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i'))

for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c1)
expected = df.withColumn('c', lit(2))
self.assertEquals(expected.collect(), result.collect())

for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c2)
expected = df.withColumn('c', col('i') + 1)
self.assertEquals(expected.collect(), result.collect())

for df in [filesource_df, datasource_df, datasource_v2_df]:
for f in [f1, f2]:
result = df.filter(f)
self.assertEquals(0, result.count())
finally:
shutil.rmtree(path)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
Expand Down
19 changes: 19 additions & 0 deletions python/pyspark/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,25 @@ def require_minimum_pyarrow_version():
"your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))


def require_test_compiled():
""" Raise Exception if test classes are not compiled
"""
import os
import glob
try:
spark_home = os.environ['SPARK_HOME']
except KeyError:
raise RuntimeError('SPARK_HOME is not defined in environment')

test_class_path = os.path.join(
spark_home, 'sql', 'core', 'target', '*', 'test-classes')
paths = glob.glob(test_class_path)

if len(paths) == 0:
raise RuntimeError(
"%s doesn't exist. Spark sql test classes are not compiled." % test_class_path)


class ForeachBatchFunction(object):
"""
This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {

/** A sequence of rules that will be applied in order to the physical plan before execution. */
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
python.ExtractPythonUDFs,
PlanSubqueries(sparkSession),
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}

class SparkOptimizer(
catalog: SessionCatalog,
Expand All @@ -31,7 +31,8 @@ class SparkOptimizer(

override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
Batch("Extract Python UDFs", Once,
Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks weird to add this rule in our optimizer batch. We need at least some comments to explain the reason in the code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but we already have ExtractPythonUDFFromAggregate here...

Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++
postHocOptimizationBatches :+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class SparkPlanner(
override def strategies: Seq[Strategy] =
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ (
PythonEvals ::
DataSourceV2Strategy ::
FileSourceStrategy ::
DataSourceStrategy(conf) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableS
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -517,6 +518,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

/**
* Strategy to convert EvalPython logical operator to physical operator.
*/
object PythonEvals extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ArrowEvalPython(udfs, output, child) =>
ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil
case BatchEvalPython(udfs, output, child) =>
BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
case _ =>
Nil
}
}

object BasicOperators extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types.StructType
Expand Down Expand Up @@ -57,7 +58,13 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
}

/**
* A physical plan that evaluates a [[PythonUDF]],
* A logical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode

/**
* A physical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec(udfs, output, child) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}

/**
* A logical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode

/**
* A physical plan that evaluates a [[PythonUDF]]
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}


/**
Expand Down Expand Up @@ -93,7 +92,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
* This has the limitation that the input to the Python UDF is not allowed include attributes from
* multiple child operators.
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {

private type EvalType = Int
private type EvalTypeChecker = EvalType => Boolean
Expand Down Expand Up @@ -132,14 +131,14 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
expressions.flatMap(collectEvaluableUDFs)
}

def apply(plan: SparkPlan): SparkPlan = plan transformUp {
case plan: SparkPlan => extract(plan)
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case plan: LogicalPlan => extract(plan)
}

/**
* Extract all the PythonUDFs from the current operator and evaluate them before the operator.
*/
private def extract(plan: SparkPlan): SparkPlan = {
private def extract(plan: LogicalPlan): LogicalPlan = {
val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
// ignore the PythonUDF that come from second/third aggregate, which is not used
.filter(udf => udf.references.subsetOf(plan.inputSet))
Expand All @@ -151,7 +150,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val prunedChildren = plan.children.map { child =>
val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq
if (allNeededOutput.length != child.output.length) {
ProjectExec(allNeededOutput, child)
Project(allNeededOutput, child)
} else {
child
}
Expand Down Expand Up @@ -180,9 +179,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
_.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
) match {
case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child)
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child)
case _ =>
throw new AnalysisException(
"Expected either Scalar Pandas UDFs or Batched UDFs but got both")
Expand All @@ -209,7 +208,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val newPlan = extract(rewritten)
if (newPlan.output != plan.output) {
// Trim away the new UDF value if it was only used for filtering or something.
ProjectExec(plan.output, newPlan)
Project(plan.output, newPlan)
} else {
newPlan
}
Expand All @@ -218,15 +217,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

// Split the original FilterExec to two FilterExecs. Only push down the first few predicates
// that are all deterministic.
private def trySplitFilter(plan: SparkPlan): SparkPlan = {
private def trySplitFilter(plan: LogicalPlan): LogicalPlan = {
plan match {
case filter: FilterExec =>
case filter: Filter =>
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
if (pushDown.nonEmpty) {
val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
val newChild = Filter(pushDown.reduceLeft(And), filter.child)
Filter((rest ++ nonDeterministic).reduceLeft(And), newChild)
} else {
filter
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._

class DefaultSource extends SimpleScanSource

// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
// tests still pass.
class SimpleScanSource extends RelationProvider {
override def createRelation(
sqlContext: SQLContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProv
}
}


// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
// tests still pass.
class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {

class ReadSupport extends SimpleReadSupport {
Expand Down