From 42fd8ebeeef8811996b62e3666e3b0e7ab674e90 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Fri, 22 Dec 2023 10:49:15 +0800
Subject: [PATCH 1/4] connect_unpivot_plan_id

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 7 +++++--
 .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 4 +++-
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 94f6d3346265..28855f8d248b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -912,7 +912,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

       // TypeCoercionBase.UnpivotCoercion determines valueType
       // and casts values once values are set and resolved
-      case Unpivot(Some(ids), Some(values), aliases, variableColumnName, valueColumnNames, child) =>
+      case up @ Unpivot(Some(ids), Some(values), aliases,
+          variableColumnName, valueColumnNames, child) =>

         def toString(values: Seq[NamedExpression]): String =
           values.map(v => v.name).mkString("_")
@@ -938,7 +939,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         val output = (ids.map(_.toAttribute) :+ variableAttr) ++ valueAttrs

         // expand the unpivot expressions
-        Expand(exprs, output, child)
+        val expand = Expand(exprs, output, child)
+        expand.copyTagsFrom(up)
+        expand
     }
   }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index c5e98683c749..36940fbc04cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -223,7 +223,9 @@ abstract class TypeCoercionBase {
         }
       )

-      up.copy(values = Some(values))
+      val newUnpivot = up.copy(values = Some(values))
+      newUnpivot.copyTagsFrom(up)
+      newUnpivot
   }
 }

From 427e8f0e1d3786d6c663f213bdef71c47db7dbb1 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Fri, 22 Dec 2023 11:22:36 +0800
Subject: [PATCH 2/4] init

---
 python/pyspark/sql/tests/test_dataframe.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 692cf77d9afb..17658d5378a7 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -963,6 +963,16 @@ def test_unpivot_negative(self):
         ):
             df.unpivot("id", ["int", "str"], "var", "val").collect()

+    def test_unpivot_groupby(self):
+        df = self.spark.createDataFrame(
+            [(1, 11, 1.1), (2, 12, 1.2)],
+            ["id", "int", "double"],
+        )
+        self.assertEqual(
+            df.unpivot("id", ["int", "double"], "var", "val").groupBy("id").count().count(),
+            2,
+        )
+
     def test_observe(self):
         # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
         from pyspark.sql import Observation

From ce6108a66d91732b40f6a5b1b4a93ad9fc8c9603 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Fri, 22 Dec 2023 11:31:54 +0800
Subject: [PATCH 3/4] add test

---
 python/pyspark/sql/tests/test_dataframe.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 17658d5378a7..c5c7dae4aeec 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -963,13 +963,13 @@ def test_unpivot_negative(self):
         ):
             df.unpivot("id", ["int", "str"], "var", "val").collect()

-    def test_unpivot_groupby(self):
+    def test_melt_groupby(self):
         df = self.spark.createDataFrame(
-            [(1, 11, 1.1), (2, 12, 1.2)],
-            ["id", "int", "double"],
+            [(1, 2, 3, 4, 5, 6)],
+            ["f1", "f2", "label", "pred", "model_version", "ts"],
         )
         self.assertEqual(
-            df.unpivot("id", ["int", "double"], "var", "val").groupBy("id").count().count(),
+            df.melt("model_version", ["label", "f2"], "f1", "f2").groupby("f1").count().count(),
             2,
         )

From ae4a2f36311de18466310b2f8b80b148cc51196c Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Fri, 22 Dec 2023 20:06:55 +0800
Subject: [PATCH 4/4] more safe

---
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 28855f8d248b..84b4e3211c53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -897,13 +897,18 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           ids.forall(_.isInstanceOf[AttributeReference]) =>
         val idAttrs = AttributeSet(up.ids.get)
         val values = up.child.output.filterNot(idAttrs.contains)
-        up.copy(values = Some(values.map(Seq(_))))
+        val newUnpivot = up.copy(values = Some(values.map(Seq(_))))
+        newUnpivot.copyTagsFrom(up)
+        newUnpivot
+
       case up @ Unpivot(None, Some(values), _, _, _, _)
           if up.childrenResolved && values.forall(_.forall(_.resolved)) &&
             values.forall(_.forall(_.isInstanceOf[AttributeReference])) =>
         val valueAttrs = AttributeSet(up.values.get.flatten)
         val ids = up.child.output.filterNot(valueAttrs.contains)
-        up.copy(ids = Some(ids))
+        val newUnpivot = up.copy(ids = Some(ids))
+        newUnpivot.copyTagsFrom(up)
+        newUnpivot

       case up: Unpivot if !up.childrenResolved || !up.ids.exists(_.forall(_.resolved)) ||
          !up.values.exists(_.nonEmpty) || !up.values.exists(_.forall(_.forall(_.resolved))) ||
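
Context for the patch series above: the analyzer rules rewrite a tagged Unpivot node into new nodes (an Expand, or a copied Unpivot), and copyTagsFrom makes those new nodes keep the original tree tags, including the plan_id tag that Spark Connect uses to resolve columns like df["id"] against a specific DataFrame. Below is a minimal PySpark sketch of the user-visible scenario, mirroring the test added in PATCH 2; the sc://localhost Connect endpoint is an assumption for illustration, not part of the patches.

# Sketch of the scenario the patches address (assumed Connect endpoint).
from pyspark.sql import SparkSession

# Connect to a Spark Connect server; the URL here is an assumption.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df = spark.createDataFrame([(1, 11, 1.1), (2, 12, 1.2)], ["id", "int", "double"])

# unpivot() melts the "int" and "double" columns into ("var", "val") rows, and
# groupBy("id") must then resolve "id" against the rewritten plan. Copying tags
# from the original Unpivot keeps the plan_id that this resolution relies on.
rows = df.unpivot("id", ["int", "double"], "var", "val").groupBy("id").count().count()
assert rows == 2  # two distinct ids, as asserted in test_unpivot_groupby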