RowLevelCommandScanRelationPushDown.scala
@@ -20,13 +20,17 @@
 package org.apache.spark.sql.execution.datasources.v2

 import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType

 object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -39,16 +43,12 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     val table = relation.table.asRowLevelOperationTable
     val scanBuilder = table.newScanBuilder(relation.options)

-    val filters = command.condition.toSeq
-    val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output)
-    val (_, normalizedFiltersWithoutSubquery) =
-      normalizedFilters.partition(SubqueryExpression.hasSubquery)
-
-    val (pushedFilters, remainingFilters) = PushDownUtils.pushFilters(
-      scanBuilder, normalizedFiltersWithoutSubquery)
+    val (pushedFilters, remainingFilters) = command.condition match {
+      case Some(cond) => pushFilters(cond, scanBuilder, relation.output)
+      case None => (Nil, Nil)
+    }

-    val (scan, output) = PushDownUtils.pruneColumns(
-      scanBuilder, relation, relation.output, Seq.empty)
+    val (scan, output) = PushDownUtils.pruneColumns(scanBuilder, relation, relation.output, Nil)
Member:
Nit: is it necessary to change `Seq.empty` => `Nil`?

Contributor Author:
I changed it so that it can fit on one line, just like the new filter pushdown logic.

Contributor:
Hi @aokolnychyi @szehon-ho, any reason for not passing pushedFilters here instead of Nil?

Contributor:
Actually nvm, this is only to prune columns.

     logInfo(
       s"""
@@ -68,6 +68,20 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     command.withNewRewritePlan(newRewritePlan)
   }

+  private def pushFilters(
+      cond: Expression,
+      scanBuilder: ScanBuilder,
+      tableAttrs: Seq[AttributeReference]): (Seq[Filter], Seq[Expression]) = {
+
+    val tableAttrSet = AttributeSet(tableAttrs)
+    val filters = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet))
Contributor:
Was this what was preventing pushdown before? We weren't filtering out expressions that referenced columns outside of the table?

Contributor Author:
Yes, we did not split the condition into parts before and did not remove filters that referenced both tables.

Contributor Author:
@szehon-ho ^^

This comment provides a bit more info to answer your question above.
We treated `t.id = s.id AND t.dep IN ('hr')` as a single predicate that couldn't be converted, as it referenced both tables. Instead, we now split it into parts and convert whatever we can (i.e. `t.dep IN ('hr')` in this case). A standalone sketch of this splitting appears after this file's diff.

+    val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, tableAttrs)
+    val (_, normalizedFiltersWithoutSubquery) =
+      normalizedFilters.partition(SubqueryExpression.hasSubquery)
+
+    PushDownUtils.pushFilters(scanBuilder, normalizedFiltersWithoutSubquery)
+  }
+
   private def toOutputAttrs(
       schema: StructType,
       relation: DataSourceV2Relation): Seq[AttributeReference] = {
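To make the thread above concrete, here is a small editorial sketch (not part of the PR) of what the new pushFilters helper does with a mixed MERGE condition such as `t.id = s.id AND t.dep IN ('hr')`. The `SplitSketch` object and the attribute names are hypothetical stand-ins, and Catalyst's expression DSL is used for brevity:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, PredicateHelper}

// Hypothetical stand-ins for t.id, t.dep (target table) and s.id (source table).
object SplitSketch extends PredicateHelper {
  def demo(): Seq[Expression] = {
    val tId  = 'id.long
    val tDep = 'dep.string
    val sId  = 'sId.long

    // The ON condition: t.id = s.id AND t.dep IN ('hr')
    val cond: Expression = (tId === sId) && tDep.in("hr")

    val tableAttrSet = AttributeSet(Seq(tId, tDep))

    // Old behaviour: `cond` was offered to the source as a single predicate;
    // it references s.id, so nothing could be converted to a source filter.
    // New behaviour: split the conjunction and keep only the parts whose
    // references are a subset of the target table's attributes.
    splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet))
    // => only the `dep IN ('hr')` part survives and can be pushed down
  }
}
```

The surviving partition predicate is exactly what lets the MERGE test below verify that only one file is rewritten.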
@@ -32,6 +32,9 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.iceberg.AssertHelpers;
 import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.SnapshotSummary;
+import org.apache.iceberg.Table;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
@@ -75,6 +78,52 @@ public void removeTables() {
sql("DROP TABLE IF EXISTS source");
}

+  @Test
+  public void testMergeWithStaticPredicatePushDown() {
+    createAndInitTable("id BIGINT, dep STRING");
+
+    sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"software\" }\n" +
+        "{ \"id\": 11, \"dep\": \"software\" }\n" +
+        "{ \"id\": 1, \"dep\": \"hr\" }");
+
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    Snapshot snapshot = table.currentSnapshot();
+    String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP);
+    Assert.assertEquals("Must have 2 files before MERGE", "2", dataFilesCount);
+
+    createOrReplaceView("source",
+        "{ \"id\": 1, \"dep\": \"finance\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }");
+
+    // disable dynamic pruning and rely only on static predicate pushdown
+    withSQLConf(ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"), () -> {
+      sql("MERGE INTO %s t USING source " +
+          "ON t.id == source.id AND t.dep IN ('software') AND source.id < 10 " +
+          "WHEN MATCHED AND source.id = 1 THEN " +
+          "  UPDATE SET dep = source.dep " +
+          "WHEN NOT MATCHED THEN " +
+          "  INSERT (dep, id) VALUES (source.dep, source.id)", tableName);
+    });
+
+    table.refresh();
+
+    Snapshot mergeSnapshot = table.currentSnapshot();
+    String deletedDataFilesCount = mergeSnapshot.summary().get(SnapshotSummary.DELETED_FILES_PROP);
+    Assert.assertEquals("Must overwrite only 1 file", "1", deletedDataFilesCount);
Contributor:
Other tests use the listener to check the expressions that were pushed down directly. Should we do that in this test?

Contributor Author:
I think I missed that. Could you point me to an example?


+    ImmutableList<Object[]> expectedRows = ImmutableList.of(
+        row(1L, "finance"),  // updated
+        row(1L, "hr"),       // kept
+        row(2L, "hardware"), // new
+        row(11L, "software") // kept
+    );
+    assertEquals("Output should match", expectedRows,
+        sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+  }
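On the listener question in the thread above: below is a hedged, spark-shell-style sketch of the general approach, assuming a live SparkSession named `spark`. The `org.apache.spark.sql.util.QueryExecutionListener` API is real, but this is not the helper the reviewer refers to, and the string-based assertion is only one coarse way to check that the partition predicate reached the scan:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

val spark: SparkSession = SparkSession.builder().getOrCreate()

// Remember the most recent successful query execution.
@volatile var lastQe: Option[QueryExecution] = None

spark.listenerManager.register(new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    lastQe = Some(qe)
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
})

// ... run the MERGE statement here ...

// The scan node's string form typically lists its pushed-down filters, so a
// coarse check can look for the statically pushed partition predicate.
assert(lastQe.exists(_.executedPlan.toString.contains("dep IN")))
```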

   @Test
   public void testMergeIntoEmptyTargetInsertAllNonMatchingRows() {
     createAndInitTable("id INT, dep STRING");