40 changes: 38 additions & 2 deletions python/pyspark/sql/tests/test_pandas_cogrouped_map.py
@@ -18,8 +18,9 @@
import unittest
import sys

from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType, StructType, StructField
from pyspark.sql.functions import array, explode, col, lit, udf, sum, rand, pandas_udf, \
PandasUDFType
from pyspark.sql.types import DoubleType, StructType, StructField, Row
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
@@ -193,6 +194,41 @@ def test_wrong_args(self):
left.groupby('id').cogroup(right.groupby('id')) \
.applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))

def test_case_insensitive_grouping_column(self):
# SPARK-31915: case-insensitive grouping column should work.
df1 = self.spark.createDataFrame([(1, 1)], ("column", "value"))

row = df1.groupby("ColUmn").cogroup(
df1.groupby("COLUMN")
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())

df2 = self.spark.createDataFrame([(1, 1)], ("column", "value"))

row = df1.groupby("ColUmn").cogroup(
df2.groupby("COLUMN")
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())

def test_nondeterministic_grouping_column(self):
def my_pandas_udf(key, left, right):
assert left.score.iloc[0] == right.score.iloc[0] == key
return left

df = self.spark.createDataFrame([[1], [3], [5]], ["column"])
df = df.select(col("column"), rand(seed=42).alias("score"))
df.groupby(rand(seed=42)).cogroup(
df.groupby(rand(seed=42))
).applyInPandas(
my_pandas_udf, schema="column integer, score float").count()

df2 = self.spark.createDataFrame([[1], [3], [5]], ["column"])
df2 = df2.select(col("column"), rand(seed=42).alias("score"))
df.groupby(rand(seed=42)).cogroup(
df2.groupby(rand(seed=42))
).applyInPandas(
my_pandas_udf, schema="column integer, score float").count()

@staticmethod
def _test_with_key(left, right, isLeft):

22 changes: 21 additions & 1 deletion python/pyspark/sql/tests/test_pandas_grouped_map.py
@@ -24,7 +24,7 @@

from pyspark.sql import Row
from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType, \
window
window, rand
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
@@ -587,6 +587,26 @@ def f(key, pdf):
# Check that all group and window_range values from udf matched expected
self.assertTrue(all([r[0] for r in result]))

def test_case_insensitive_grouping_column(self):
# SPARK-31915: case-insensitive grouping column should work.
def my_pandas_udf(pdf):
return pdf.assign(score=0.5)

df = self.spark.createDataFrame([[1, 1]], ["column", "score"])
row = df.groupby('COLUMN').applyInPandas(
my_pandas_udf, schema="column integer, score float").first()
self.assertEquals(row.asDict(), Row(column=1, score=0.5).asDict())

def test_nondeterministic_grouping_column(self):
def my_pandas_udf(key, pdf):
assert pdf.score.iloc[0] == key
return pdf

df = self.spark.createDataFrame([[1], [3], [5]], ["column"])
df = df.select(col("column"), rand(seed=42).alias("score"))
df.groupby(rand(seed=42)).applyInPandas(
my_pandas_udf, schema="column integer, score float").count()


if __name__ == "__main__":
from pyspark.sql.tests.test_pandas_grouped_map import *
@@ -1374,12 +1374,10 @@ class Analyzer(
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) =>
val leftRes = leftAttributes
.map(x => resolveExpressionBottomUp(x, left).asInstanceOf[Attribute])
val rightRes = rightAttributes
.map(x => resolveExpressionBottomUp(x, right).asInstanceOf[Attribute])
f.copy(leftAttributes = leftRes, rightAttributes = rightRes)
case f @ FlatMapCoGroupsInPandas(leftExprs, rightExprs, _, _, left, right) =>
f.copy(
leftExprs = leftExprs.map(x => resolveExpressionBottomUp(x, left)),
rightExprs = rightExprs.map(x => resolveExpressionBottomUp(x, right)))
// intersect/except will be rewritten to join at the beginning of optimizer. Here we need to
// deduplicate the right side plan, so that we won't produce an invalid self-join later.
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
@@ -2755,11 +2753,23 @@ class Analyzer(
case f: Filter => f

case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) =>
val nondeterToAttr = getNondeterToAttr(a.groupingExpressions)
val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child)
a.transformExpressions { case e =>
nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
}.copy(child = newChild)
projectGroupingExprs(a, a.groupingExpressions)

case f: FlatMapGroupsInPandas if f.groupingExprs.exists(!_.deterministic) =>
Comment from the PR author:

So, basically this is for the case when grouping expressions are non-deterministic:

== Physical Plan ==
FlatMapGroupsInPandas [_nondeterministic#14], my_pandas_udf(column#4L, score#6), [column#12, score#13]
+- *(2) Sort [_nondeterministic#14 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(_nondeterministic#14, 200), true, [id=#19]
      +- *(1) Project [column#4L, score#6, rand(42) AS _nondeterministic#14]  # <--- projected here so the non-deterministic expression is evaluated only once.
...
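
For comparison, here is a minimal Scala sketch (illustrative only, assuming a local SparkSession) of the pre-existing Aggregate path that already gets this treatment; the PR extends the same pull-out-into-Project handling to the pandas plan nodes shown above.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

val spark = SparkSession.builder().master("local[1]").appName("nondet-grouping").getOrCreate()
import spark.implicits._

val df = Seq(1, 3, 5).toDF("column")

// Grouping by a non-deterministic expression: the analyzer pulls rand(42) out into a
// Project as a _nondeterministic column, so the expression is evaluated only once per row
// and the grouping key matches what downstream operators see.
df.groupBy(rand(42)).count().explain()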

projectGroupingExprs(f, f.groupingExprs)

case a: FlatMapCoGroupsInPandas if (a.leftExprs ++ a.rightExprs).exists(!_.deterministic) =>
val leftNondeterToAttr = getNondeterToAttr(a.leftExprs)
val leftNewChild = Project(a.left.output ++ leftNondeterToAttr.values, a.left)
val rightNondeterToAttr = getNondeterToAttr(a.rightExprs)
val rightNewChild = Project(a.right.output ++ rightNondeterToAttr.values, a.right)
a.copy(
leftExprs = a.leftExprs.map(
e => leftNondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)),
rightExprs = a.rightExprs.map(
e => rightNondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)),
left = leftNewChild,
right = rightNewChild)

// Don't touch collect metrics. Top-level metrics are not supported (check analysis will fail)
// and we want to retain them inside the aggregate functions.
@@ -2777,6 +2787,14 @@
Project(p.output, newPlan.withNewChildren(newChild :: Nil))
}

private def projectGroupingExprs(p: LogicalPlan, exprs: Seq[Expression]): LogicalPlan = {
val nondeterToAttr = getNondeterToAttr(exprs)
val newChild = Project(p.children.head.output ++ nondeterToAttr.values, p.children.head)
p.transformExpressions { case e =>
nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
}.withNewChildren(newChild :: Nil)
}

private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = {
exprs.filterNot(_.deterministic).flatMap { expr =>
val leafNondeterministic = expr.collect { case n: Nondeterministic => n }
@@ -619,7 +619,8 @@ trait CheckAnalysis extends PredicateHelper {

case o if o.expressions.exists(!_.deterministic) &&
!o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
!o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] =>
!o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] &&
!o.isInstanceOf[FlatMapGroupsInPandas] && !o.isInstanceOf[FlatMapCoGroupsInPandas] =>
// The rule above is used to check Aggregate operator.
failAnalysis(
s"""nondeterministic expressions are only allowed in
@@ -23,14 +23,18 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, toPrettySQL}
import org.apache.spark.sql.types._

object NamedExpression {
private val curId = new java.util.concurrent.atomic.AtomicLong()
private[expressions] val jvmId = UUID.randomUUID()
def newExprId: ExprId = ExprId(curId.getAndIncrement(), jvmId)
def unapply(expr: NamedExpression): Option[(String, DataType)] = Some((expr.name, expr.dataType))
def fromExpression(expr: Expression): NamedExpression = expr match {
case ne: NamedExpression => ne
case _: Expression => Alias(expr, toPrettySQL(expr))()
}
Comment from the PR author:

I will send another PR to use this in other places, for example,

val namedDistinctExpressions = distinctExpressions.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}

val aliasedExprs = aggregateExprs.map {
case ne: NamedExpression => ne
case e => Alias(e, e.toString)()
}

private[this] def alias(expr: Expression): NamedExpression = expr match {
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] =>
UnresolvedAlias(a, Some(Column.generateAlias))
case expr: Expression => Alias(expr, toPrettySQL(expr))()
}

and possibly at:

case expr: Expression => Alias(expr, toPrettySQL(expr))()

I can also leave this util out for now if anyone is not sure about it.
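
As a short sketch (not part of this PR), the duplicated pattern above could collapse onto the new helper roughly like this, where exprs stands in for whatever expression sequence a call site needs to name:

import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}

// Hand-rolled wrapping, as it appears at several call sites today.
def nameByHand(exprs: Seq[Expression]): Seq[NamedExpression] = exprs.map {
  case ne: NamedExpression => ne
  case other => Alias(other, other.toString)()
}

// Equivalent wrapping via the new helper, which aliases with toPrettySQL instead of toString.
def nameViaHelper(exprs: Seq[Expression]): Seq[NamedExpression] =
  exprs.map(NamedExpression.fromExpression)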

Comment from @TJX2014 (contributor), Jun 7, 2020:

I find that org.apache.spark.sql.Dataset#groupBy(cols: Column*) is invoked through py4j instead of groupBy(col1: String, cols: String*). Would it be possible to change the parameters sent from the Python side so that groupBy(col1: String, cols: String*) is invoked instead? That might also be helpful for this JIRA :-)
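
For reference, a runnable Scala sketch (assuming a local SparkSession and the default spark.sql.caseSensitive=false) of the two Dataset#groupBy overloads being contrasted here:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[1]").appName("groupby-overloads").getOrCreate()
import spark.implicits._

val df = Seq((1, 1)).toDF("column", "value")

// String overload: groupBy(col1: String, cols: String*). Names are resolved against
// the schema, so "COLUMN" matches "column" under the case-insensitive resolver.
df.groupBy("COLUMN").count().show()

// Column overload: groupBy(cols: Column*). Per the comment above, this is the overload
// PySpark currently reaches through py4j.
df.groupBy(col("COLUMN")).count().show()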

Reply from the PR author:

Yeah, let me take a look separately in a separate JIRA.

}

/**
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre
* This is used by DataFrame.groupby().apply().
*/
case class FlatMapGroupsInPandas(
groupingAttributes: Seq[Attribute],
groupingExprs: Seq[Expression],
functionExpr: Expression,
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
@@ -56,8 +56,8 @@ case class MapInPandas(
* This is used by DataFrame.groupby().cogroup().apply().
*/
case class FlatMapCoGroupsInPandas(
leftAttributes: Seq[Attribute],
rightAttributes: Seq[Attribute],
leftExprs: Seq[Expression],
rightExprs: Seq[Expression],
functionExpr: Expression,
output: Seq[Attribute],
left: LogicalPlan,
@@ -541,17 +541,8 @@ class RelationalGroupedDataset protected[sql](
"Must pass a grouped map udf")
require(expr.dataType.isInstanceOf[StructType],
s"The returnType of the udf must be a ${StructType.simpleString}")

val groupingNamedExpressions = groupingExprs.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}
val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
val child = df.logicalPlan
val project = Project(groupingNamedExpressions ++ child.output, child)
val output = expr.dataType.asInstanceOf[StructType].toAttributes
val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project)

val plan = FlatMapGroupsInPandas(groupingExprs, expr, output, df.logicalPlan)
Dataset.ofRows(df.sparkSession, plan)
}

@@ -572,28 +563,9 @@
"Must pass a cogrouped map udf")
require(expr.dataType.isInstanceOf[StructType],
s"The returnType of the udf must be a ${StructType.simpleString}")

val leftGroupingNamedExpressions = groupingExprs.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}

val rightGroupingNamedExpressions = r.groupingExprs.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}

val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute)
val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute)

val leftChild = df.logicalPlan
val rightChild = r.df.logicalPlan

val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)
val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)

val output = expr.dataType.asInstanceOf[StructType].toAttributes
val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right)
val plan = FlatMapCoGroupsInPandas(
groupingExprs, r.groupingExprs, expr, output, df.logicalPlan, r.df.logicalPlan)
Dataset.ofRows(df.sparkSession, plan)
}

@@ -608,10 +608,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.MapPartitionsInRWithArrowExec(
f, p, b, is, ot, planLater(child)) :: Nil
case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
case logical.FlatMapCoGroupsInPandas(leftGroup, rightGroup, func, output, left, right) =>
execution.python.FlatMapGroupsInPandasExec(
grouping.map(NamedExpression.fromExpression), func, output, planLater(child)) :: Nil
case logical.FlatMapCoGroupsInPandas(leftExprs, rightExprs, func, output, left, right) =>
execution.python.FlatMapCoGroupsInPandasExec(
leftGroup, rightGroup, func, output, planLater(left), planLater(right)) :: Nil
leftExprs.map(NamedExpression.fromExpression),
rightExprs.map(NamedExpression.fromExpression),
func, output, planLater(left), planLater(right)) :: Nil
case logical.MapInPandas(func, output, child) =>
execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>
@@ -48,8 +48,8 @@ import org.apache.spark.sql.util.ArrowUtils
* is left as future work.
*/
case class FlatMapCoGroupsInPandasExec(
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
leftGroupingExprs: Seq[NamedExpression],
rightGroupingExprs: Seq[NamedExpression],
func: Expression,
output: Seq[Attribute],
left: SparkPlan,
@@ -60,42 +60,49 @@ case class FlatMapCoGroupsInPandasExec(
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
private val pandasFunction = func.asInstanceOf[PythonUDF].func
private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
private val inputAttrs =
func.asInstanceOf[PythonUDF].children.map(_.asInstanceOf[NamedExpression])
private val leftAttrs = left.output.filter(e => inputAttrs.exists(_.semanticEquals(e)))
private val rightAttrs = right.output.filter(e => inputAttrs.exists(_.semanticEquals(e)))

override def producedAttributes: AttributeSet = AttributeSet(output)

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] = {
val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup)
val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup)
val leftDist =
if (leftGroupingExprs.isEmpty) AllTuples else ClusteredDistribution(leftGroupingExprs)
val rightDist =
if (rightGroupingExprs.isEmpty) AllTuples else ClusteredDistribution(rightGroupingExprs)
leftDist :: rightDist :: Nil
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
leftGroup
.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
leftGroupingExprs
.map(SortOrder(_, Ascending)) :: rightGroupingExprs.map(SortOrder(_, Ascending)) :: Nil
}

override protected def doExecute(): RDD[InternalRow] = {

val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup)
val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup)
val (leftDedup, leftArgOffsets) = resolveArgOffsets(leftAttrs, leftGroupingExprs)
val (rightDedup, rightArgOffsets) = resolveArgOffsets(rightAttrs, rightGroupingExprs)

// Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty
left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else {

val leftGrouped = groupAndProject(leftData, leftGroup, left.output, leftDedup)
val rightGrouped = groupAndProject(rightData, rightGroup, right.output, rightDedup)
val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
val leftGrouped = groupAndProject(leftData, leftGroupingExprs, left.output, leftDedup)
val rightGrouped = groupAndProject(rightData, rightGroupingExprs, right.output, rightDedup)
val data = new CoGroupedIterator(
leftGrouped, rightGrouped, leftGroupingExprs.map(_.toAttribute))
.map { case (_, l, r) => (l, r) }

val runner = new CoGroupedArrowPythonRunner(
chainedFunc,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
Array(leftArgOffsets ++ rightArgOffsets),
StructType.fromAttributes(leftDedup),
StructType.fromAttributes(rightDedup),
StructType.fromAttributes(leftDedup.map(_.toAttribute)),
StructType.fromAttributes(rightDedup.map(_.toAttribute)),
sessionLocalTimeZone,
pythonRunnerConf)

@@ -46,7 +46,7 @@ import org.apache.spark.sql.util.ArrowUtils
* is left as future work.
*/
case class FlatMapGroupsInPandasExec(
groupingAttributes: Seq[Attribute],
groupingExprs: Seq[NamedExpression],
func: Expression,
output: Seq[Attribute],
child: SparkPlan)
@@ -56,38 +56,39 @@ case class FlatMapGroupsInPandasExec(
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
private val pandasFunction = func.asInstanceOf[PythonUDF].func
private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
private val inputAttrs =
func.asInstanceOf[PythonUDF].children.map(_.asInstanceOf[NamedExpression])

override def producedAttributes: AttributeSet = AttributeSet(output)

override def outputPartitioning: Partitioning = child.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] = {
if (groupingAttributes.isEmpty) {
if (groupingExprs.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(groupingAttributes) :: Nil
ClusteredDistribution(groupingExprs) :: Nil
}
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(groupingAttributes.map(SortOrder(_, Ascending)))
Seq(groupingExprs.map(SortOrder(_, Ascending)))

override protected def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute()

val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes)

val (dedupExprs, argOffsets) = resolveArgOffsets(inputAttrs, groupingExprs)
// Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {

val data = groupAndProject(iter, groupingAttributes, child.output, dedupAttributes)
val data = groupAndProject(iter, groupingExprs, child.output, dedupExprs)
.map { case (_, x) => x }

val runner = new ArrowPythonRunner(
chainedFunc,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
Array(argOffsets),
StructType.fromAttributes(dedupAttributes),
StructType.fromAttributes(dedupExprs.map(_.toAttribute)),
sessionLocalTimeZone,
pythonRunnerConf)
