RewriteMergeIntoTable.scala
@@ -19,6 +19,7 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.iceberg.spark.source.SparkTable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -36,23 +37,7 @@ import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.LeftAnti
import org.apache.spark.sql.catalyst.plans.LeftOuter
import org.apache.spark.sql.catalyst.plans.RightOuter
import org.apache.spark.sql.catalyst.plans.logical.AppendData
import org.apache.spark.sql.catalyst.plans.logical.DeleteAction
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.catalyst.plans.logical.HintInfo
import org.apache.spark.sql.catalyst.plans.logical.InsertAction
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.plans.logical.JoinHint
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.MergeAction
import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable
import org.apache.spark.sql.catalyst.plans.logical.MergeRows
import org.apache.spark.sql.catalyst.plans.logical.NO_BROADCAST_HASH
import org.apache.spark.sql.catalyst.plans.logical.NoStatsUnaryNode
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
import org.apache.spark.sql.catalyst.plans.logical.UpdateAction
import org.apache.spark.sql.catalyst.plans.logical.WriteDelta
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoIcebergTable, MergeRows, NO_BROADCAST_HASH, NoStatsUnaryNode, Project, ReplaceData, UpdateAction, View, WriteDelta}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.NamedReference
@@ -82,7 +67,6 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None)
if m.resolved && m.aligned && matchedActions.isEmpty && notMatchedActions.size == 1 =>

EliminateSubqueryAliases(aliasedTable) match {
case r: DataSourceV2Relation =>
// NOT MATCHED conditions may only refer to columns in source so they can be pushed down
@@ -112,7 +96,6 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand {

case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None)
if m.resolved && m.aligned && matchedActions.isEmpty =>

EliminateSubqueryAliases(aliasedTable) match {
case r: DataSourceV2Relation =>
// when there are no MATCHED actions, use a left anti join to remove any matching rows
@@ -144,25 +127,44 @@
throw new AnalysisException(s"$p is not an Iceberg table")
}

case m @ MergeIntoIcebergTable(aliasedTable, source, cond, matchedActions, notMatchedActions, None)
case m @ MergeIntoIcebergTable(aliasedTable,source, cond, matchedActions, notMatchedActions, None)
if m.resolved && m.aligned =>

EliminateSubqueryAliases(aliasedTable) match {
case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _) =>
val operation = buildRowLevelOperation(tbl, MERGE)
val table = RowLevelOperationTable(tbl, operation)
val rewritePlan = operation match {
case _: SupportsDelta =>
buildWriteDeltaPlan(r, table, source, cond, matchedActions, notMatchedActions)
case _ =>
buildReplaceDataPlan(r, table, source, cond, matchedActions, notMatchedActions)
rewriteIcebergRelation(m, r, tbl)
case p: View =>
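// the view must resolve to exactly one Iceberg (SparkTable) relation; anything else falls through to the error below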
val relations = p.children.collect { case r: DataSourceV2Relation if r.table.isInstanceOf[SparkTable] =>
r
}
val icebergTableView = relations.nonEmpty && relations.size == 1
if (icebergTableView) {
val newM = rewriteIcebergRelation(
m,
relations.head,
relations.head.table.asInstanceOf[SupportsRowLevelOperations])
newM
} else {
throw new AnalysisException(s"$p is not an Iceberg table")
}

m.copy(rewritePlan = Some(rewritePlan))

case p =>
throw new AnalysisException(s"$p is not an Iceberg table")
}

}

private def rewriteIcebergRelation(
m: MergeIntoIcebergTable,
r: DataSourceV2Relation,
tbl: SupportsRowLevelOperations): MergeIntoIcebergTable = {
val operation = buildRowLevelOperation(tbl, MERGE)
val table = RowLevelOperationTable(tbl, operation)
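// delta-capable sources get a WriteDelta plan; everything else rewrites whole groups of data via ReplaceData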
val rewritePlan = operation match {
case _: SupportsDelta =>
buildWriteDeltaPlan(r, table, m.sourceTable, m.mergeCondition, m.matchedActions, m.notMatchedActions)
case _ =>
buildReplaceDataPlan(r, table, m.sourceTable, m.mergeCondition, m.matchedActions, m.notMatchedActions)
}
m.copy(rewritePlan = Some(rewritePlan))
}

// build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions)
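The net effect of the new View branch is that MERGE INTO can now target a temporary view that resolves to a single Iceberg relation, such as a path-based table created through HadoopTables. A minimal usage sketch in Scala, mirroring the new test further down; the path, the view names, and sourceDF are illustrative assumptions:

val targetDF = spark.read.format("iceberg").load("/tmp/warehouse/target") // path-based Iceberg table
targetDF.createOrReplaceTempView("target")
sourceDF.createOrReplaceTempView("source") // sourceDF: any DataFrame with a matching id column
spark.sql(
  """MERGE INTO target USING source
    |ON target.id = source.id
    |WHEN MATCHED THEN UPDATE SET target.id = source.id + 1""".stripMargin)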
IcebergSparkSqlExtensionsParser.scala
@@ -20,14 +20,15 @@
package org.apache.spark.sql.catalyst.parser.extensions

import java.util.Locale

import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.misc.ParseCancellationException
import org.antlr.v4.runtime.tree.TerminalNodeImpl
import org.apache.iceberg.common.DynConstructors
import org.apache.iceberg.spark.Spark3Util
import org.apache.iceberg.spark.source.SparkTable
import org.apache.iceberg.spark.source.{SparkBatchQueryScan, SparkFilesScan, SparkTable}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.FunctionIdentifier
@@ -57,6 +58,10 @@ import org.apache.spark.sql.types.StructType
import scala.jdk.CollectionConverters._
import scala.util.Try

import org.apache.iceberg.hadoop.HadoopTables
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec

class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface {

import IcebergSparkSqlExtensionsParser._
@@ -162,8 +167,12 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface {
case tableCatalog: TableCatalog =>
Try(tableCatalog.loadTable(catalogAndIdentifier.identifier))
.map(isIcebergTable)
.getOrElse(false)
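// fallback: when loadTable fails (e.g. the name is a temp view), plan the identifier and check whether its scan reads an existing Iceberg table location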

.getOrElse(SparkSession.active.table(s"${multipartIdent.mkString(".")}").queryExecution
.executedPlan.collect {
case BatchScanExec(_, scan, _) if scan.isInstanceOf[SparkBatchQueryScan] =>
val ht = new HadoopTables(SparkSession.active.sparkContext.hadoopConfiguration)
ht.exists(scan.asInstanceOf[SparkBatchQueryScan].tableScan().table().location())
}.contains(true))
case _ =>
false
}
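The fallback above decides whether a multipart identifier refers to an Iceberg table by building the physical plan for that name, locating a SparkBatchQueryScan in it, and probing the scanned table's location. A minimal sketch of that location probe on its own, assuming an active SparkSession; the path is illustrative:

import org.apache.iceberg.hadoop.HadoopTables

val ht = new HadoopTables(spark.sparkContext.hadoopConfiguration)
// true only if Iceberg table metadata exists at this filesystem location
val isIcebergLocation = ht.exists("/tmp/warehouse/target")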
RewrittenRowLevelCommand.scala
@@ -66,7 +66,6 @@ object RewrittenRowLevelCommand {
case _ =>
None
}

case _ =>
None
}
ReplaceData.scala
@@ -52,7 +52,6 @@ case class ReplaceData(
// they will be discarded after the logical write is built in the optimizer
// metadata columns may be needed to request a correct distribution or ordering
// but are not passed back to the data source during writes

table.skipSchemaResolution || (dataInput.size == table.output.size &&
dataInput.zip(table.output).forall { case (inAttr, outAttr) =>
val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
MergeRowsExec.scala
@@ -184,7 +184,12 @@ case class MergeRowsExec(
}

rowIterator
.map(processFunc)
.map(row => {
println(s"input: $row")
val o = processFunc(row)
println(s"output: $o")
o
})
.filter(row => row != null)
}
}
RowLevelCommandScanRelationPushDown.scala
@@ -64,7 +64,6 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with Predic
case r: DataSourceV2Relation if r.table eq table =>
DataSourceV2ScanRelation(r, scan, toOutputAttrs(scan.readSchema(), r))
}

command.withNewRewritePlan(newRewritePlan)
}

TestUpdate.java
@@ -19,6 +19,7 @@

package org.apache.iceberg.spark.extensions;

import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -29,33 +30,32 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.iceberg.AssertHelpers;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.RowLevelOperationMode;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.Table;

import org.apache.commons.io.FileUtils;
import org.apache.iceberg.*;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
import org.apache.iceberg.spark.SparkSQLProperties;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.spark.SparkException;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.ReplaceData;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.types.StructType;
import org.hamcrest.CoreMatchers;
import org.junit.After;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.*;

import static org.apache.iceberg.DataOperations.OVERWRITE;
import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE;
@@ -69,6 +69,9 @@
import static org.apache.iceberg.TableProperties.UPDATE_MODE_DEFAULT;
import static org.apache.spark.sql.functions.lit;

import scala.collection.Iterator;
import scala.collection.JavaConverters.*;

public abstract class TestUpdate extends SparkRowLevelOperationsTestBase {

public TestUpdate(String catalogName, String implementation, Map<String, String> config,
@@ -137,6 +140,30 @@ public void testUpdateWithAlias() {
sql("SELECT * FROM %s", tableName));
}

@Test
public void testHadoopTables() throws Exception {
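// end-to-end check: create a path-based table with HadoopTables, expose it as a temp view, then MERGE into it through that view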
List<Integer> ids = Lists.newArrayListWithCapacity(2);
for (int id = 1; id <= 2; id++) {
ids.add(id);
}
Dataset<Row> df = spark.createDataset(ids, Encoders.INT())
.withColumnRenamed("value", "id");
HadoopTables ht = new HadoopTables(spark.sparkContext().hadoopConfiguration());
Schema tableSchema = SparkSchemaUtil.convert(df.schema());
File dir = java.nio.file.Files.createTempDirectory("TestUpdate").toFile();
FileUtils.forceDeleteOnExit(dir);
String path = dir.getAbsolutePath();
ht.create(tableSchema, path);
df.write().format("iceberg").mode("overwrite").save(path);
Dataset<Row> tableDF = spark.read().format("iceberg").load(path);
tableDF.createOrReplaceTempView("target");
df.createOrReplaceTempView("source");
spark.sql("select * from source").show();
sql("MERGE INTO target using source on target.id = source.id " +
"WHEN MATCHED THEN UPDATE SET target.id = source.id + 1");
spark.read().format("iceberg").load(path).show();
}

@Test
public void testUpdateAlignsAssignments() {
createAndInitTable("id INT, c1 INT, c2 INT");
Binary file added spark/v3.2/spark-extensions/table/._SUCCESS.crc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Empty file.
Binary file not shown.
Binary file not shown.
SparkBatchQueryScan.java
@@ -59,7 +59,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering {
public class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering {

private static final Logger LOG = LoggerFactory.getLogger(SparkBatchQueryScan.class);

@@ -121,6 +121,10 @@ private List<FileScanTask> files() {
return files;
}

public TableScan tableScan() {
return scan;
}

@Override
protected List<CombinedScanTask> tasks() {
if (tasks == null) {
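Making SparkBatchQueryScan public and exposing tableScan() is what lets the parser change above reach the scanned table's location. The call chain it relies on, shown in isolation; scan here is assumed to be the Scan taken from a matched BatchScanExec:

val location = scan.asInstanceOf[SparkBatchQueryScan].tableScan().table().location()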