diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 692cf77d9afb..c5c7dae4aeec 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -963,6 +963,16 @@ def test_unpivot_negative(self):
         ):
             df.unpivot("id", ["int", "str"], "var", "val").collect()
 
+    def test_melt_groupby(self):
+        df = self.spark.createDataFrame(
+            [(1, 2, 3, 4, 5, 6)],
+            ["f1", "f2", "label", "pred", "model_version", "ts"],
+        )
+        self.assertEqual(
+            df.melt("model_version", ["label", "f2"], "f1", "f2").groupby("f1").count().count(),
+            2,
+        )
+
     def test_observe(self):
         # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
         from pyspark.sql import Observation
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 94f6d3346265..84b4e3211c53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -897,13 +897,18 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           ids.forall(_.isInstanceOf[AttributeReference]) =>
         val idAttrs = AttributeSet(up.ids.get)
         val values = up.child.output.filterNot(idAttrs.contains)
-        up.copy(values = Some(values.map(Seq(_))))
+        val newUnpivot = up.copy(values = Some(values.map(Seq(_))))
+        newUnpivot.copyTagsFrom(up)
+        newUnpivot
+
       case up @ Unpivot(None, Some(values), _, _, _, _) if up.childrenResolved &&
           values.forall(_.forall(_.resolved)) &&
           values.forall(_.forall(_.isInstanceOf[AttributeReference])) =>
         val valueAttrs = AttributeSet(up.values.get.flatten)
         val ids = up.child.output.filterNot(valueAttrs.contains)
-        up.copy(ids = Some(ids))
+        val newUnpivot = up.copy(ids = Some(ids))
+        newUnpivot.copyTagsFrom(up)
+        newUnpivot
 
       case up: Unpivot if !up.childrenResolved || !up.ids.exists(_.forall(_.resolved)) ||
           !up.values.exists(_.nonEmpty) || !up.values.exists(_.forall(_.forall(_.resolved))) ||
@@ -912,7 +917,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
 
       // TypeCoercionBase.UnpivotCoercion determines valueType
       // and casts values once values are set and resolved
-      case Unpivot(Some(ids), Some(values), aliases, variableColumnName, valueColumnNames, child) =>
+      case up @ Unpivot(Some(ids), Some(values), aliases,
+          variableColumnName, valueColumnNames, child) =>
         def toString(values: Seq[NamedExpression]): String =
           values.map(v => v.name).mkString("_")
 
@@ -938,7 +944,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         val output = (ids.map(_.toAttribute) :+ variableAttr) ++ valueAttrs
 
         // expand the unpivot expressions
-        Expand(exprs, output, child)
+        val expand = Expand(exprs, output, child)
+        expand.copyTagsFrom(up)
+        expand
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index c5e98683c749..36940fbc04cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -223,7 +223,9 @@ abstract class TypeCoercionBase {
         }
       )
 
-      up.copy(values = Some(values))
+      val newUnpivot = up.copy(values = Some(values))
+      newUnpivot.copyTagsFrom(up)
+      newUnpivot
     }
   }
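
Context for the pattern applied above (not part of the patch): a Catalyst rule that rebuilds a node via `copy()` or a constructor produces a node with an empty tag map, so any `TreeNodeTag` set on the original (for example the plan id that Spark Connect uses to resolve DataFrame column references) is silently dropped unless the rule calls `copyTagsFrom`. Below is a minimal sketch of that mechanism; the `PLAN_ID` tag here is a hypothetical stand-in defined locally rather than Spark's internal plan-id tag:

```scala
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

object TagDemo {
  // Hypothetical tag, mirroring the kind of metadata (e.g. a plan id)
  // that gets attached to logical plan nodes.
  val PLAN_ID = TreeNodeTag[Long]("plan_id")

  def main(args: Array[String]): Unit = {
    val original = LocalRelation()
    original.setTagValue(PLAN_ID, 42L)

    // A node built fresh by a rewrite starts with no tags at all.
    val rewritten = LocalRelation()
    assert(rewritten.getTagValue(PLAN_ID).isEmpty)

    // copyTagsFrom carries the original node's tags onto the replacement.
    rewritten.copyTagsFrom(original)
    assert(rewritten.getTagValue(PLAN_ID).contains(42L))
  }
}
```

This is presumably also why the new `test_melt_groupby` test exercises a `groupby` immediately after `melt`: resolving `f1` against the melted plan depends on the nodes produced during analysis still carrying the tags of the `Unpivot` they replaced.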