diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
index 8ab0dc7072664..5c8c7cf420d65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
@@ -41,8 +41,18 @@ object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with SQLConfHelpe
   private def partitioning(plan: LogicalPlan) = plan.transformDown {
     case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportPartitioning, _, None, _) =>
       val catalystPartitioning = scan.outputPartitioning() match {
-        case kgp: KeyGroupedPartitioning => sequenceToOption(kgp.keys().map(
-          V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog)))
+        case kgp: KeyGroupedPartitioning =>
+          val partitioning = sequenceToOption(
+            kgp.keys().map(V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog)))
+          if (partitioning.isEmpty) {
+            None
+          } else {
+            if (partitioning.get.forall(p => p.references.subsetOf(d.outputSet))) {
+              partitioning
+            } else {
+              None
+            }
+          }
         case _: UnknownPartitioning => None
         case p => throw new IllegalArgumentException("Unsupported data source V2 partitioning " +
           "type: " + p.getClass.getSimpleName)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
index 9b90ee43657f5..8454b9f85ecdd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
@@ -216,4 +216,20 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       .withColumn("right_all", struct($"right.*"))
     checkAnswer(dfQuery, Row(1, "a", "b", Row(1, "a"), Row(1, "b")))
   }
+
+  test("SPARK-40429: Only set KeyGroupedPartitioning when the referenced column is in the output") {
+    withTable(tbl) {
+      sql(s"CREATE TABLE $tbl (id bigint, data string) PARTITIONED BY (id)")
+      sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+      checkAnswer(
+        spark.table(tbl).select("index", "_partition"),
+        Seq(Row(0, "3"), Row(0, "2"), Row(0, "1"))
+      )
+
+      checkAnswer(
+        spark.table(tbl).select("id", "index", "_partition"),
+        Seq(Row(3, 0, "3"), Row(2, 0, "2"), Row(1, 0, "1"))
+      )
+    }
+  }
 }
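
Note (not part of the patch): the nested if/else the patch adds is the standard "keep an Option only if a predicate holds" pattern, equivalent to a single Option.filter. Below is a minimal, self-contained sketch of that guard, using plain Sets of column names as stand-ins for catalyst expressions and for d.outputSet; the names here are illustrative, not Spark's API.

    object KeyGroupedGuardSketch extends App {
      // Stand-in for the scan's output attribute set (d.outputSet in the rule),
      // e.g. the columns of SELECT index, _partition in the new test.
      val outputSet: Set[String] = Set("index", "_partition")

      // Keep the translated partitioning keys only when every key's references
      // are a subset of the output; otherwise fall back to None.
      def guard(partitioning: Option[Seq[Set[String]]]): Option[Seq[Set[String]]] =
        partitioning.filter(_.forall(refs => refs.subsetOf(outputSet)))

      // A key over `id` is dropped when `id` is absent from the output, which is
      // exactly the case that crashed before SPARK-40429.
      println(guard(Some(Seq(Set("id")))))          // None
      println(guard(Some(Seq(Set("_partition")))))  // Some(List(Set(_partition)))
    }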