Skip to content

Commit 8637205

Browse files
Ngone51 authored and HyukjinKwon committed
[SPARK-34319][SQL] Resolve duplicate attributes for FlatMapCoGroupsInPandas/MapInPandas
### What changes were proposed in this pull request? Resolve duplicate attributes for `FlatMapCoGroupsInPandas` and `MapInPandas`. ### Why are the changes needed? When performing self-join on top of `FlatMapCoGroupsInPandas`, analysis can fail because of conflicting attributes. For example, ```scala df = spark.createDataFrame([(1, 1)], ("column", "value")) row = df.groupby("ColUmn").cogroup( df.groupby("COLUMN") ).applyInPandas(lambda r, l: r + l, "column long, value long") row.join(row).show() ``` error: ```scala ... Conflicting attributes: column#163321L,value#163322L ;; 'Join Inner :- FlatMapCoGroupsInPandas [ColUmn#163312L], [COLUMN#163312L], <lambda>(column#163312L, value#163313L, column#163312L, value#163313L), [column#163321L, value#163322L] : :- Project [ColUmn#163312L, column#163312L, value#163313L] : : +- LogicalRDD [column#163312L, value#163313L], false : +- Project [COLUMN#163312L, column#163312L, value#163313L] : +- LogicalRDD [column#163312L, value#163313L], false +- FlatMapCoGroupsInPandas [ColUmn#163312L], [COLUMN#163312L], <lambda>(column#163312L, value#163313L, column#163312L, value#163313L), [column#163321L, value#163322L] :- Project [ColUmn#163312L, column#163312L, value#163313L] : +- LogicalRDD [column#163312L, value#163313L], false +- Project [COLUMN#163312L, column#163312L, value#163313L] +- LogicalRDD [column#163312L, value#163313L], false ... ``` ### Does this PR introduce _any_ user-facing change? Yes, queries like the above example no longer fail. ### How was this patch tested? Added unit tests. Closes #31429 from Ngone51/fix-conflcting-attrs-of-FlatMapCoGroupsInPandas. Lead-authored-by: yi.wu <[email protected]> Co-authored-by: wuyi <[email protected]> Signed-off-by: HyukjinKwon <[email protected]> (cherry picked from commit e9362c2) Signed-off-by: HyukjinKwon <[email protected]>
1 parent a8e8ff6 commit 8637205

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

python/pyspark/sql/tests/test_pandas_cogrouped_map.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,18 @@ def test_case_insensitive_grouping_column(self):
209209
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
210210
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())
211211

212+
# Regression test for SPARK-34319 (cogrouped map variant).
# Cogroups a one-row DataFrame with itself (grouping keys spelled "ColUmn" and
# "COLUMN"), applies a pandas function that adds the two cogrouped inputs, then
# joins the resulting DataFrame with itself without a join condition and checks
# the first row of the join against Row(column=2, value=2).
def test_self_join(self):
213+
# SPARK-34319: self-join with FlatMapCoGroupsInPandas
214+
df = self.spark.createDataFrame([(1, 1)], ("column", "value"))
215+
216+
row = df.groupby("ColUmn").cogroup(
217+
df.groupby("COLUMN")
218+
).applyInPandas(lambda r, l: r + l, "column long, value long")
219+
220+
row = row.join(row).first()
221+
222+
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
223+
212224
@staticmethod
213225
def _test_with_key(left, right, isLeft):
214226

python/pyspark/sql/tests/test_pandas_map.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ def func(iterator):
117117
expected = df.collect()
118118
self.assertEquals(actual, expected)
119119

120+
# Regression test for SPARK-34319 (mapInPandas variant).
# df2 is df1 passed through mapInPandas with a function that returns its input
# iterator unchanged, so df2 joined with itself is expected to yield the same
# rows as df1 joined with itself; both results are sorted before comparison.
def test_self_join(self):
121+
# SPARK-34319: self-join with MapInPandas
122+
df1 = self.spark.range(10)
123+
df2 = df1.mapInPandas(lambda iter: iter, 'id long')
124+
actual = df2.join(df2).collect()
125+
expected = df1.join(df1).collect()
126+
self.assertEqual(sorted(actual), sorted(expected))
127+
120128

121129
if __name__ == "__main__":
122130
from pyspark.sql.tests.test_pandas_map import *

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,14 @@ class Analyzer(
11981198
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
11991199
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
12001200

1201+
case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
1202+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1203+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1204+
1205+
case oldVersion @ MapInPandas(_, output, _)
1206+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1207+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1208+
12011209
case oldVersion: Generate
12021210
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
12031211
val newOutput = oldVersion.generatorOutput.map(_.newInstance())

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,48 @@ class AnalysisSuite extends AnalysisTest with Matchers {
610610
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
611611
}
612612

613+
// Regression test for SPARK-34319. The title describes the original failure; the body
// asserts that analysis now succeeds.
// Builds a FlatMapCoGroupsInPandas node (held in the local `flatMapGroupsInPandas`) whose
// output attributes come from the PythonUDF's struct return type (single LongType field
// `a`), over two projections of `a` from testRelation and testRelation2. The same node
// instance is then placed on both sides of an inner join (no condition) under the aliases
// temp0 and temp1, and a projection of temp0.a and temp1.a must analyze successfully.
test("SPARK-34319: analysis fails on self-join with FlatMapCoGroupsInPandas") {
614+
val pythonUdf = PythonUDF("pyUDF", null,
615+
StructType(Seq(StructField("a", LongType))),
616+
Seq.empty,
617+
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
618+
true)
619+
val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
620+
val project1 = Project(Seq(UnresolvedAttribute("a")), testRelation)
621+
val project2 = Project(Seq(UnresolvedAttribute("a")), testRelation2)
622+
val flatMapGroupsInPandas = FlatMapCoGroupsInPandas(
623+
Seq(UnresolvedAttribute("a")),
624+
Seq(UnresolvedAttribute("a")),
625+
pythonUdf,
626+
output,
627+
project1,
628+
project2)
629+
val left = SubqueryAlias("temp0", flatMapGroupsInPandas)
630+
val right = SubqueryAlias("temp1", flatMapGroupsInPandas)
631+
val join = Join(left, right, Inner, None, JoinHint.NONE)
632+
assertAnalysisSuccess(
633+
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
634+
}
635+
636+
// Regression test for SPARK-34319. The title describes the original failure; the body
// asserts that analysis now succeeds.
// Builds a MapInPandas node whose output attributes come from the PythonUDF's struct
// return type (single LongType field `a`), over a projection of `a` from testRelation.
// The same node instance is then placed on both sides of an inner join (no condition)
// under the aliases temp0 and temp1, and a projection of temp0.a and temp1.a must
// analyze successfully.
test("SPARK-34319: analysis fails on self-join with MapInPandas") {
637+
val pythonUdf = PythonUDF("pyUDF", null,
638+
StructType(Seq(StructField("a", LongType))),
639+
Seq.empty,
640+
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
641+
true)
642+
val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
643+
val project = Project(Seq(UnresolvedAttribute("a")), testRelation)
644+
val mapInPandas = MapInPandas(
645+
pythonUdf,
646+
output,
647+
project)
648+
val left = SubqueryAlias("temp0", mapInPandas)
649+
val right = SubqueryAlias("temp1", mapInPandas)
650+
val join = Join(left, right, Inner, None, JoinHint.NONE)
651+
assertAnalysisSuccess(
652+
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
653+
}
654+
613655
test("SPARK-24488 Generator with multiple aliases") {
614656
assertAnalysisSuccess(
615657
listRelation.select(Explode($"list").as("first_alias").as("second_alias")))

0 commit comments

Comments
 (0)