diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index 886ed5002584..285feeb0d14d 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -19,6 +19,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import org.apache.iceberg.spark.source.SparkTable
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.Alias
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -36,23 +37,7 @@ import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.LeftAnti
 import org.apache.spark.sql.catalyst.plans.LeftOuter
 import org.apache.spark.sql.catalyst.plans.RightOuter
-import org.apache.spark.sql.catalyst.plans.logical.AppendData
-import org.apache.spark.sql.catalyst.plans.logical.DeleteAction
-import org.apache.spark.sql.catalyst.plans.logical.Filter
-import org.apache.spark.sql.catalyst.plans.logical.HintInfo
-import org.apache.spark.sql.catalyst.plans.logical.InsertAction
-import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.catalyst.plans.logical.JoinHint
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.logical.MergeAction
-import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable
-import org.apache.spark.sql.catalyst.plans.logical.MergeRows
-import org.apache.spark.sql.catalyst.plans.logical.NO_BROADCAST_HASH
-import org.apache.spark.sql.catalyst.plans.logical.NoStatsUnaryNode
-import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
-import org.apache.spark.sql.catalyst.plans.logical.UpdateAction
-import org.apache.spark.sql.catalyst.plans.logical.WriteDelta
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoIcebergTable, MergeRows, NO_BROADCAST_HASH, NoStatsUnaryNode, Project, ReplaceData, UpdateAction, View, WriteDelta}
 import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
 import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.connector.expressions.NamedReference
@@ -82,7 +67,6 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand {
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None)
         if m.resolved && m.aligned && matchedActions.isEmpty && notMatchedActions.size == 1 =>
-
       EliminateSubqueryAliases(aliasedTable) match {
         case r: DataSourceV2Relation =>
           // NOT MATCHED conditions may only refer to columns in source so they can be pushed down
@@ -112,7 +96,6 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand {
 
     case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None)
         if m.resolved && m.aligned && matchedActions.isEmpty =>
-
       EliminateSubqueryAliases(aliasedTable) match {
         case r: DataSourceV2Relation =>
           // when there are no MATCHED actions, use a left anti join to remove any matching rows
@@ -144,25 +127,44 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand {
           throw new AnalysisException(s"$p is not an Iceberg table")
       }
 
     case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None)
         if m.resolved && m.aligned =>
-
       EliminateSubqueryAliases(aliasedTable) match {
         case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _) =>
-          val operation = buildRowLevelOperation(tbl, MERGE)
-          val table = RowLevelOperationTable(tbl, operation)
-          val rewritePlan = operation match {
-            case _: SupportsDelta =>
-              buildWriteDeltaPlan(r, table, source, cond, matchedActions, notMatchedActions)
-            case _ =>
-              buildReplaceDataPlan(r, table, source, cond, matchedActions, notMatchedActions)
-          }
-
-          m.copy(rewritePlan = Some(rewritePlan))
-
-        case p =>
-          throw new AnalysisException(s"$p is not an Iceberg table")
+          rewriteIcebergRelation(m, r, tbl)
+        case p: View =>
+          val relations = p.children.collect {
+            case r: DataSourceV2Relation if r.table.isInstanceOf[SparkTable] => r
+          }
+          if (relations.size == 1) {
+            rewriteIcebergRelation(
+              m,
+              relations.head,
+              relations.head.table.asInstanceOf[SupportsRowLevelOperations])
+          } else {
+            throw new AnalysisException(s"$p is not an Iceberg table")
+          }
       }
   }
 
+  private def rewriteIcebergRelation(
+      m: MergeIntoIcebergTable,
+      r: DataSourceV2Relation,
+      tbl: SupportsRowLevelOperations): MergeIntoIcebergTable = {
+    val operation = buildRowLevelOperation(tbl, MERGE)
+    val table = RowLevelOperationTable(tbl, operation)
+    val rewritePlan = operation match {
+      case _: SupportsDelta =>
+        buildWriteDeltaPlan(r, table, m.sourceTable, m.mergeCondition, m.matchedActions, m.notMatchedActions)
+      case _ =>
+        buildReplaceDataPlan(r, table, m.sourceTable, m.mergeCondition, m.matchedActions, m.notMatchedActions)
+    }
+    m.copy(rewritePlan = Some(rewritePlan))
+  }
+
   // build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions)
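Note: the new case p: View branch above fires when the MERGE target is a temporary view whose only child relation is an Iceberg DataSourceV2Relation over a SparkTable. A minimal sketch of that scenario, modeled on the testHadoopTables test added later in this patch; the table path and view names are placeholders, not part of the change:

    val spark = org.apache.spark.sql.SparkSession.active  // assumes an active session
    import spark.implicits._

    // "/tmp/iceberg-table" stands in for an existing path-based (HadoopTables) Iceberg table
    spark.read.format("iceberg").load("/tmp/iceberg-table").createOrReplaceTempView("target")
    Seq(1, 2).toDF("id").createOrReplaceTempView("source")

    // "target" resolves to a View wrapping DataSourceV2Relation(SparkTable), which is the
    // plan shape the new branch matches
    spark.sql(
      """MERGE INTO target t USING source s
        |ON t.id = s.id
        |WHEN MATCHED THEN UPDATE SET t.id = s.id + 1""".stripMargin)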
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala
index 0339d8bff833..6c9e1f579916 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala
@@ -20,6 +20,7 @@
 package org.apache.spark.sql.catalyst.parser.extensions
 
 import java.util.Locale
+
 import org.antlr.v4.runtime._
 import org.antlr.v4.runtime.atn.PredictionMode
 import org.antlr.v4.runtime.misc.Interval
@@ -27,7 +28,7 @@ import org.antlr.v4.runtime.misc.ParseCancellationException
 import org.antlr.v4.runtime.tree.TerminalNodeImpl
 import org.apache.iceberg.common.DynConstructors
 import org.apache.iceberg.spark.Spark3Util
-import org.apache.iceberg.spark.source.SparkTable
+import org.apache.iceberg.spark.source.{SparkBatchQueryScan, SparkFilesScan, SparkTable}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.FunctionIdentifier
@@ -57,6 +58,10 @@ import org.apache.spark.sql.types.StructType
 import scala.jdk.CollectionConverters._
 import scala.util.Try
 
+import org.apache.iceberg.hadoop.HadoopTables
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+
 class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface {
 
   import IcebergSparkSqlExtensionsParser._
@@ -162,8 +167,12 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserI
       case tableCatalog: TableCatalog =>
         Try(tableCatalog.loadTable(catalogAndIdentifier.identifier))
           .map(isIcebergTable)
-          .getOrElse(false)
-
+          .getOrElse(SparkSession.active.table(s"${multipartIdent.mkString(".")}").queryExecution
+            .executedPlan.collect {
+              case BatchScanExec(_, scan, _) if scan.isInstanceOf[SparkBatchQueryScan] =>
+                val ht = new HadoopTables(SparkSession.active.sparkContext.hadoopConfiguration)
+                ht.exists(scan.asInstanceOf[SparkBatchQueryScan].tableScan().table().location())
+            }.contains(true))
       case _ =>
         false
     }
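The fallback added to isIcebergTable above runs the identifier as a query and then probes HadoopTables for the location of any Iceberg batch scan it finds. A sketch of the same check factored into a helper, under the same assumptions as the patch (an active SparkSession and the public tableScan() accessor introduced on SparkBatchQueryScan below); the helper name is hypothetical:

    private def isHadoopBackedIcebergPlan(multipartIdent: Seq[String]): Boolean = {
      val session = SparkSession.active
      val plan = session.table(multipartIdent.mkString(".")).queryExecution.executedPlan
      plan.collect {
        case BatchScanExec(_, scan: SparkBatchQueryScan, _) =>
          val tables = new HadoopTables(session.sparkContext.hadoopConfiguration)
          tables.exists(scan.tableScan().table().location())
      }.contains(true)
    }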
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/planning/RewrittenRowLevelCommand.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/planning/RewrittenRowLevelCommand.scala
index 4cc7a7bf2f96..90a38877be1c 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/planning/RewrittenRowLevelCommand.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/planning/RewrittenRowLevelCommand.scala
@@ -66,7 +66,6 @@ object RewrittenRowLevelCommand {
         case _ => None
       }
-
     case _ => None
   }
 }
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceData.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceData.scala
index 3bf726ffb719..7a52acc3a5bd 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceData.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceData.scala
@@ -52,7 +52,6 @@ case class ReplaceData(
     // they will be discarded after the logical write is built in the optimizer
     // metadata columns may be needed to request a correct distribution or ordering
     // but are not passed back to the data source during writes
-    table.skipSchemaResolution ||
       (dataInput.size == table.output.size && dataInput.zip(table.output).forall {
         case (inAttr, outAttr) =>
           val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
index 4fbf8a523a54..0a796546450d 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
@@ -184,7 +184,12 @@ case class MergeRowsExec(
     }
 
     rowIterator
-      .map(processFunc)
+      .map(row => {
+        println(s"input: $row")
+        val o = processFunc(row)
+        println(s"output: $o")
+        o
+      })
       .filter(row => row != null)
   }
 }
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
index 4e89b9a1c243..3f6424d9ca03 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
@@ -64,7 +64,6 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with Predic
       case r: DataSourceV2Relation if r.table eq table =>
         DataSourceV2ScanRelation(r, scan, toOutputAttrs(scan.readSchema(), r))
     }
-
     command.withNewRewritePlan(newRewritePlan)
   }
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java
index edc3944e69fe..15017ce1f8cf 100644
--- a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java
@@ -19,6 +19,7 @@
 
 package org.apache.iceberg.spark.extensions;
 
+import java.io.File;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
@@ -29,13 +30,11 @@
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
-import org.apache.iceberg.AssertHelpers;
-import org.apache.iceberg.DataFile;
-import org.apache.iceberg.RowLevelOperationMode;
-import org.apache.iceberg.Snapshot;
-import org.apache.iceberg.SnapshotSummary;
-import org.apache.iceberg.Table;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.iceberg.*;
 import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.hadoop.HadoopTables;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
@@ -43,19 +42,20 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
 import org.apache.iceberg.spark.SparkSQLProperties;
+import org.apache.iceberg.spark.SparkSchemaUtil;
 import org.apache.spark.SparkException;
 import org.apache.spark.sql.AnalysisException;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.catalyst.plans.logical.ReplaceData;
+import org.apache.spark.sql.execution.SparkPlan;
 import org.apache.spark.sql.internal.SQLConf;
+import org.apache.spark.sql.types.StructType;
 import org.hamcrest.CoreMatchers;
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Assume;
-import org.junit.BeforeClass;
-import org.junit.Test;
+import org.junit.*;
 
 import static org.apache.iceberg.DataOperations.OVERWRITE;
 import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE;
@@ -69,6 +69,9 @@ import static org.apache.iceberg.TableProperties.UPDATE_MODE_DEFAULT;
 import static org.apache.spark.sql.functions.lit;
 
+import scala.collection.Iterator;
+import scala.collection.JavaConverters.*;
+
 public abstract class TestUpdate extends SparkRowLevelOperationsTestBase {
 
   public TestUpdate(String catalogName, String implementation, Map<String, String> config,
@@ -137,6 +140,30 @@ public void testUpdateWithAlias() {
         sql("SELECT * FROM %s", tableName));
   }
 
+  @Test
+  public void testHadoopTables() throws Exception {
+    List<Integer> ids = Lists.newArrayListWithCapacity(2);
+    for (int id = 1; id <= 2; id++) {
+      ids.add(id);
+    }
+    Dataset<Row> df = spark.createDataset(ids, Encoders.INT())
+        .withColumnRenamed("value", "id");
+    HadoopTables ht = new HadoopTables(spark.sparkContext().hadoopConfiguration());
+    Schema tableSchema = SparkSchemaUtil.convert(df.schema());
+    File dir = java.nio.file.Files.createTempDirectory("TestUpdate").toFile();
+    FileUtils.forceDeleteOnExit(dir);
+    String path = dir.getAbsolutePath();
+    ht.create(tableSchema, path);
+    df.write().format("iceberg").mode("overwrite").save(path);
+    Dataset<Row> tableDF = spark.read().format("iceberg").load(path);
+    tableDF.createOrReplaceTempView("target");
+    df.createOrReplaceTempView("source");
+    spark.sql("select * from source").show();
+    sql("MERGE INTO target USING source ON target.id = source.id " +
+        "WHEN MATCHED THEN UPDATE SET target.id = source.id + 1");
+    spark.read().format("iceberg").load(path).show();
+  }
+
   @Test
   public void testUpdateAlignsAssignments() {
     createAndInitTable("id INT, c1 INT, c2 INT");
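The MERGE in testHadoopTables rewrites ids 1 and 2 to 2 and 3. A sketch of how that result could be verified rather than only printed, written in Scala for brevity; path refers to the temporary directory created in the test:

    val merged = spark.read.format("iceberg").load(path)
      .orderBy("id")
      .collect()
      .map(_.getInt(0))
    assert(merged.sameElements(Array(2, 3)))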
diff --git a/spark/v3.2/spark-extensions/table/._SUCCESS.crc b/spark/v3.2/spark-extensions/table/._SUCCESS.crc
new file mode 100644
index 000000000000..3b7b044936a8
Binary files /dev/null and b/spark/v3.2/spark-extensions/table/._SUCCESS.crc differ
diff --git a/spark/v3.2/spark-extensions/table/.part-00000-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet.crc b/spark/v3.2/spark-extensions/table/.part-00000-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet.crc
new file mode 100644
index 000000000000..72989152a2bb
Binary files /dev/null and b/spark/v3.2/spark-extensions/table/.part-00000-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet.crc differ
diff --git a/spark/v3.2/spark-extensions/table/.part-00001-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet.crc b/spark/v3.2/spark-extensions/table/.part-00001-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet.crc
new file mode 100644
index 000000000000..85ecbb30bb6a
Binary files /dev/null and b/spark/v3.2/spark-extensions/table/.part-00001-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet.crc differ
diff --git a/spark/v3.2/spark-extensions/table/_SUCCESS b/spark/v3.2/spark-extensions/table/_SUCCESS
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/spark/v3.2/spark-extensions/table/part-00000-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet b/spark/v3.2/spark-extensions/table/part-00000-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet
new file mode 100644
index 000000000000..47c9bf5a7de1
Binary files /dev/null and b/spark/v3.2/spark-extensions/table/part-00000-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet differ
diff --git a/spark/v3.2/spark-extensions/table/part-00001-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet b/spark/v3.2/spark-extensions/table/part-00001-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet
new file mode 100644
index 000000000000..c884140ac8d6
Binary files /dev/null and b/spark/v3.2/spark-extensions/table/part-00001-cfeeb10f-708b-47a3-9bc2-cbcbc66c325b-c000.snappy.parquet differ
diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java
index 651a411ebd7b..8f729e973856 100644
--- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java
+++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java
@@ -59,7 +59,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering {
+public class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering {
 
   private static final Logger LOG = LoggerFactory.getLogger(SparkBatchQueryScan.class);
 
@@ -121,6 +121,10 @@ private List<FileScanTask> files() {
     return files;
   }
 
+  public TableScan tableScan() {
+    return scan;
+  }
+
   @Override
   protected List<CombinedScanTask> tasks() {
     if (tasks == null) {
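Making SparkBatchQueryScan public and exposing tableScan() lets code outside org.apache.iceberg.spark.source reach the underlying Iceberg table from a physical plan, which is what the parser fallback earlier in this patch relies on. A minimal sketch of such a consumer; df is a placeholder DataFrame that scans an Iceberg table:

    import org.apache.iceberg.spark.source.SparkBatchQueryScan
    import org.apache.spark.sql.execution.datasources.v2.BatchScanExec

    // collects the location of every Iceberg table scanned by the plan
    val locations = df.queryExecution.executedPlan.collect {
      case b: BatchScanExec if b.scan.isInstanceOf[SparkBatchQueryScan] =>
        b.scan.asInstanceOf[SparkBatchQueryScan].tableScan().table().location()
    }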