diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
index 52b27d53e1a7..4e89b9a1c243 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
@@ -20,13 +20,17 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 
 object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -39,16 +43,12 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with Predic
       val table = relation.table.asRowLevelOperationTable
       val scanBuilder = table.newScanBuilder(relation.options)
 
-      val filters = command.condition.toSeq
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output)
-      val (_, normalizedFiltersWithoutSubquery) =
-        normalizedFilters.partition(SubqueryExpression.hasSubquery)
-
-      val (pushedFilters, remainingFilters) = PushDownUtils.pushFilters(
-        scanBuilder, normalizedFiltersWithoutSubquery)
+      val (pushedFilters, remainingFilters) = command.condition match {
+        case Some(cond) => pushFilters(cond, scanBuilder, relation.output)
+        case None => (Nil, Nil)
+      }
 
-      val (scan, output) = PushDownUtils.pruneColumns(
-        scanBuilder, relation, relation.output, Seq.empty)
+      val (scan, output) = PushDownUtils.pruneColumns(scanBuilder, relation, relation.output, Nil)
 
       logInfo(
         s"""
@@ -68,6 +68,20 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with Predic
       command.withNewRewritePlan(newRewritePlan)
   }
 
+  private def pushFilters(
+      cond: Expression,
+      scanBuilder: ScanBuilder,
+      tableAttrs: Seq[AttributeReference]): (Seq[Filter], Seq[Expression]) = {
+
+    val tableAttrSet = AttributeSet(tableAttrs)
+    val filters = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet))
+    val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, tableAttrs)
+    val (_, normalizedFiltersWithoutSubquery) =
+      normalizedFilters.partition(SubqueryExpression.hasSubquery)
+
+    PushDownUtils.pushFilters(scanBuilder, normalizedFiltersWithoutSubquery)
+  }
+
   private def toOutputAttrs(
       schema: StructType,
       relation: DataSourceV2Relation): Seq[AttributeReference] = {
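For context on the Scala change above: the rule no longer hands the entire command condition to PushDownUtils.pushFilters. The new pushFilters helper first splits the condition into conjuncts and keeps only those whose references are a subset of the target table's attributes, so predicates that touch the source relation (or both sides) are left for the join and never offered to the scan builder. The snippet below is a minimal, self-contained sketch of that selection step; it uses a hypothetical Expr model instead of Catalyst's expression classes, and it omits the normalization and subquery partitioning that the real helper still performs.

```scala
// Hypothetical stand-in for Catalyst expressions, only for illustration.
object StaticPushDownSketch {

  sealed trait Expr { def references: Set[String] }
  case class Attr(name: String) extends Expr { def references: Set[String] = Set(name) }
  case class Lit(value: Any) extends Expr { def references: Set[String] = Set.empty }
  case class EqualTo(left: Expr, right: Expr) extends Expr {
    def references: Set[String] = left.references ++ right.references
  }
  case class In(child: Expr, values: Seq[Any]) extends Expr {
    def references: Set[String] = child.references
  }
  case class LessThan(left: Expr, right: Expr) extends Expr {
    def references: Set[String] = left.references ++ right.references
  }
  case class And(left: Expr, right: Expr) extends Expr {
    def references: Set[String] = left.references ++ right.references
  }

  // Analogue of PredicateHelper.splitConjunctivePredicates: flatten nested ANDs.
  def splitConjuncts(e: Expr): Seq[Expr] = e match {
    case And(l, r) => splitConjuncts(l) ++ splitConjuncts(r)
    case other     => Seq(other)
  }

  // Keep only conjuncts whose references fall entirely within the target table's
  // attributes, mirroring the _.references.subsetOf(tableAttrSet) check; anything
  // that mentions the source relation stays with the join.
  def staticallyPushable(cond: Expr, tableAttrs: Set[String]): Seq[Expr] =
    splitConjuncts(cond).filter(_.references.subsetOf(tableAttrs))
}
```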
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index 23ba7addf99b..6537e31c6757 100644
--- a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -32,6 +32,9 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.iceberg.AssertHelpers;
 import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.SnapshotSummary;
+import org.apache.iceberg.Table;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
@@ -75,6 +78,52 @@ public void removeTables() {
     sql("DROP TABLE IF EXISTS source");
   }
 
+  @Test
+  public void testMergeWithStaticPredicatePushDown() {
+    createAndInitTable("id BIGINT, dep STRING");
+
+    sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"software\" }\n" +
+        "{ \"id\": 11, \"dep\": \"software\" }\n" +
+        "{ \"id\": 1, \"dep\": \"hr\" }");
+
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    Snapshot snapshot = table.currentSnapshot();
+    String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP);
+    Assert.assertEquals("Must have 2 files before MERGE", "2", dataFilesCount);
+
+    createOrReplaceView("source",
+        "{ \"id\": 1, \"dep\": \"finance\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }");
+
+    // disable dynamic pruning and rely only on static predicate pushdown
+    withSQLConf(ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"), () -> {
+      sql("MERGE INTO %s t USING source " +
+          "ON t.id == source.id AND t.dep IN ('software') AND source.id < 10 " +
+          "WHEN MATCHED AND source.id = 1 THEN " +
+          "  UPDATE SET dep = source.dep " +
+          "WHEN NOT MATCHED THEN " +
+          "  INSERT (dep, id) VALUES (source.dep, source.id)", tableName);
+    });
+
+    table.refresh();
+
+    Snapshot mergeSnapshot = table.currentSnapshot();
+    String deletedDataFilesCount = mergeSnapshot.summary().get(SnapshotSummary.DELETED_FILES_PROP);
+    Assert.assertEquals("Must overwrite only 1 file", "1", deletedDataFilesCount);
+
+    ImmutableList<Object[]> expectedRows = ImmutableList.of(
+        row(1L, "finance"),   // updated
+        row(1L, "hr"),        // kept
+        row(2L, "hardware"),  // new
+        row(11L, "software")  // kept
+    );
+    assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+  }
+
   @Test
   public void testMergeIntoEmptyTargetInsertAllNonMatchingRows() {
     createAndInitTable("id INT, dep STRING");
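Relating the test above to the selection logic: of the three conjuncts in the ON clause, only t.dep IN ('software') references the target table alone, so it is the only predicate that can be pushed into the target scan even with dynamic partition pruning disabled; that is what lets the MERGE rewrite just the single file in the 'software' partition and satisfy the DELETED_FILES_PROP assertion. The sketch below reuses the hypothetical StaticPushDownSketch model from the earlier snippet (attribute names are assumptions, not Catalyst output) to walk through that classification.

```scala
import StaticPushDownSketch._

object TestConditionSketch extends App {
  // ON t.id == source.id AND t.dep IN ('software') AND source.id < 10
  val cond = And(
    And(
      EqualTo(Attr("t.id"), Attr("source.id")),
      In(Attr("t.dep"), Seq("software"))),
    LessThan(Attr("source.id"), Lit(10)))

  // Attributes of the target (partitioned) table.
  val targetAttrs = Set("t.id", "t.dep")

  // The equality references both sides and source.id < 10 references only the
  // source, so neither is pushed; only the IN predicate reaches the scan.
  println(staticallyPushable(cond, targetAttrs))
  // expected: List(In(Attr(t.dep),List(software)))
}
```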