diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataColumn.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataColumn.java index 6606748e6d6f9..c5b41d7dfbf4f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataColumn.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataColumn.java @@ -36,6 +36,45 @@ */ @Evolving public interface MetadataColumn { + /** + * Indicates whether a row-level operation should preserve the value of the metadata column + * for deleted rows. If set to true, the metadata value will be retained and passed back to + * the writer. If false, the metadata value will be replaced with {@code null}. + *
+ * This flag applies only to row-level operations working with deltas of rows. Group-based + * operations handle deletes by discarding matching records. + * + * @since 4.0.0 + */ + String PRESERVE_ON_DELETE = "__preserve_on_delete"; + boolean PRESERVE_ON_DELETE_DEFAULT = true; + + /** + * Indicates whether a row-level operation should preserve the value of the metadata column + * for updated rows. If set to true, the metadata value will be retained and passed back to + * the writer. If false, the metadata value will be replaced with {@code null}. + *
+ * This flag applies to both group-based and delta-based row-level operations. + * + * @since 4.0.0 + */ + String PRESERVE_ON_UPDATE = "__preserve_on_update"; + boolean PRESERVE_ON_UPDATE_DEFAULT = true; + + /** + * Indicates whether a row-level operation should preserve the value of the metadata column + * for reinserted rows generated by splitting updates into deletes and inserts. If true, + * the metadata value will be retained and passed back to the writer. If false, the metadata + * value will be replaced with {@code null}. + *
+ * This flag applies only to row-level operations working with deltas of rows. Group-based
+ * operations do not represent updates as deletes and inserts.
+ *
+ * @since 4.0.0
+ */
+ String PRESERVE_ON_REINSERT = "__preserve_on_reinsert";
+ boolean PRESERVE_ON_REINSERT_DEFAULT = false;
+
/**
* The name of this metadata column.
*
@@ -74,4 +113,13 @@ default String comment() {
default Transform transform() {
return null;
}
+
+ /**
+ * Returns the column metadata in JSON format.
+ *
+ * @since 4.0.0
+ */
+ default String metadataInJSON() {
+ return null;
+ }
}
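
The preserve flags above travel through the column's metadata JSON. As a hedged sketch (the `RowVersionColumn` name and its semantics are assumptions, not part of this change), a connector could expose a metadata column whose value is dropped for deleted rows but kept for updated rows:

```scala
import org.apache.spark.sql.connector.catalog.MetadataColumn
import org.apache.spark.sql.types.{DataType, LongType, MetadataBuilder}

// Hypothetical connector-side metadata column: not preserved on DELETE,
// preserved on UPDATE (reinsert keeps its default of "not preserved").
object RowVersionColumn extends MetadataColumn {
  override def name: String = "_row_version"
  override def dataType: DataType = LongType
  override def isNullable: Boolean = true
  override def comment: String = "Row version maintained by the data source"

  override def metadataInJSON(): String = {
    new MetadataBuilder()
      .putBoolean(MetadataColumn.PRESERVE_ON_DELETE, value = false)
      .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, value = true)
      .build()
      .json
  }
}
```

Later in this patch, `MetadataAttribute.metadata(col)` merges this JSON into the catalyst attribute metadata so the rewrite rules can read the flags back.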
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
index d6e94fe2ca8b0..a4ec1abc9dd7d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
@@ -64,6 +64,23 @@
*/
@Evolving
public interface DataWriter<T> extends Closeable {
+ /**
+ * This method is used by group-based row-level operations to pass back metadata for records
+ * that are updated or copied. New records added during a MERGE operation are written using
+ * {@link #write(Object)} as there is no metadata associated with those records.
+ *
+ * If this method fails (by throwing an exception), {@link #abort()} will be called and this
+ * data writer is considered to have been failed.
+ *
+ * @throws IOException if failure happens during disk/network IO like writing files.
+ *
+ * @since 4.0.0
+ */
+ default void write(T metadata, T record) throws IOException {
+ write(record);
+ }
/**
* Writes one record.
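
For group-based sources, the new two-argument `write` is what receives the preserved metadata for copied and updated rows. A minimal sketch, assuming an `InternalRow`-based writer (the `MetadataAwareWriter` name and the buffering are illustrative only):

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}

// Buffers each row together with the metadata Spark passed back (null for
// brand-new rows, which still arrive through the single-argument write).
class MetadataAwareWriter extends DataWriter[InternalRow] {
  private val buffered = mutable.ArrayBuffer.empty[(InternalRow, InternalRow)]

  override def write(metadata: InternalRow, record: InternalRow): Unit = {
    buffered += ((metadata.copy(), record.copy()))
  }

  override def write(record: InternalRow): Unit = {
    buffered += ((null, record.copy()))
  }

  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = buffered.clear()
  override def close(): Unit = ()
}
```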
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeltaWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeltaWriter.java
index 0cc6cb48801bf..a7ab0c162ddec 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeltaWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeltaWriter.java
@@ -48,6 +48,21 @@ public interface DeltaWriter<T> extends DataWriter<T> {
+ /**
+ * This method handles the insert portion of updated rows split into deletes and inserts.
+ *
+ * @param metadata values for metadata columns
+ * @param row a row to reinsert
+ * @throws IOException if failure happens during disk/network IO like writing files
+ *
+ * @since 4.0.0
+ */
+ default void reinsert(T metadata, T row) throws IOException {
+ insert(row);
+ }
+
/**
* Inserts a new row.
*
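
For delta-based sources, `reinsert` covers the insert half of an update that was split into a delete and an insert, so the writer still receives the metadata projected for that row. A rough sketch, not the reference implementation (the in-memory test table further below is the one this patch actually exercises):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DeltaWriter, WriterCommitMessage}

// Distinguishes reinserted rows (which carry metadata from the deleted copy)
// from plain inserts produced by MERGE ... WHEN NOT MATCHED.
class AuditingDeltaWriter extends DeltaWriter[InternalRow] {
  override def delete(metadata: InternalRow, id: InternalRow): Unit = ()
  override def update(metadata: InternalRow, id: InternalRow, row: InternalRow): Unit = ()
  override def insert(row: InternalRow): Unit = ()

  override def reinsert(metadata: InternalRow, row: InternalRow): Unit = {
    // e.g. keep the preserved metadata next to the reinserted row; falling
    // back to insert(row) is exactly what the default implementation does
    insert(row)
  }

  override def write(row: InternalRow): Unit = insert(row)
  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = ()
  override def close(): Unit = ()
}
```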
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala
index 32163406ca6d2..9c63e091eaf51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala
@@ -86,7 +86,9 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand {
// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
- ReplaceData(writeRelation, cond, remainingRowsPlan, relation, Some(cond))
+ val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, remainingRowsPlan)
+ val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs)
+ ReplaceData(writeRelation, cond, query, relation, projections, Some(cond))
}
// build a rewrite plan for sources that support row deltas
@@ -106,7 +108,7 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand {
// construct a plan that only contains records to delete
val deletedRowsPlan = Filter(cond, readRelation)
val operationType = Alias(Literal(DELETE_OPERATION), OPERATION_COLUMN)()
- val requiredWriteAttrs = dedupAttrs(rowIdAttrs ++ metadataAttrs)
+ val requiredWriteAttrs = nullifyMetadataOnDelete(dedupAttrs(rowIdAttrs ++ metadataAttrs))
val project = Project(operationType +: requiredWriteAttrs, deletedRowsPlan)
// build a plan to write deletes to the table
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index dacee70cf1286..7e2cf4f29807c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split}
-import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN
+import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE
@@ -180,7 +180,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
- ReplaceData(writeRelation, pushableCond, mergeRowsPlan, relation, groupFilterCond)
+ val projections = buildReplaceDataProjections(mergeRowsPlan, relation.output, metadataAttrs)
+ ReplaceData(writeRelation, pushableCond, mergeRowsPlan, relation, projections, groupFilterCond)
}
private def buildReplaceDataMergeRowsPlan(
@@ -197,7 +198,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
// that's why an extra unconditional instruction that would produce the original row is added
// as the last MATCHED and NOT MATCHED BY SOURCE instruction
// this logic is specific to data sources that replace groups of data
- val keepCarryoverRowsInstruction = Keep(TrueLiteral, targetTable.output)
+ val carryoverRowsOutput = Literal(WRITE_WITH_METADATA_OPERATION) +: targetTable.output
+ val keepCarryoverRowsInstruction = Keep(TrueLiteral, carryoverRowsOutput)
val matchedInstructions = matchedActions.map { action =>
toInstruction(action, metadataAttrs)
@@ -218,7 +220,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
notMatchedInstructions.flatMap(_.outputs) ++
notMatchedBySourceInstructions.flatMap(_.outputs)
- val attrs = targetTable.output
+ val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)()
+ val attrs = operationTypeAttr +: targetTable.output
MergeRows(
isSourceRowPresent = IsNotNull(rowFromSourceAttr),
@@ -430,15 +433,18 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
private def toInstruction(action: MergeAction, metadataAttrs: Seq[Attribute]): Instruction = {
action match {
case UpdateAction(cond, assignments) =>
- val output = assignments.map(_.value) ++ metadataAttrs
+ val rowValues = assignments.map(_.value)
+ val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
+ val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues
Keep(cond.getOrElse(TrueLiteral), output)
case DeleteAction(cond) =>
Discard(cond.getOrElse(TrueLiteral))
case InsertAction(cond, assignments) =>
+ val rowValues = assignments.map(_.value)
val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType))
- val output = assignments.map(_.value) ++ metadataValues
+ val output = Seq(Literal(WRITE_OPERATION)) ++ rowValues ++ metadataValues
Keep(cond.getOrElse(TrueLiteral), output)
case other =>
@@ -460,7 +466,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
action match {
case UpdateAction(cond, assignments) if splitUpdates =>
val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues)
- val otherOutput = deltaInsertOutput(assignments, metadataAttrs, originalRowIdValues)
+ val otherOutput = deltaReinsertOutput(assignments, metadataAttrs, originalRowIdValues)
Split(cond.getOrElse(TrueLiteral), output, otherOutput)
case UpdateAction(cond, assignments) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
index cec44470e3a35..118ed4e99190c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
@@ -19,24 +19,32 @@ package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ProjectingInternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, V2ExpressionUtils}
-import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
-import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable, SupportsDelta}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
+ private final val DELTA_OPERATIONS_WITH_ROW =
+ Set(UPDATE_OPERATION, REINSERT_OPERATION, INSERT_OPERATION)
+ private final val DELTA_OPERATIONS_WITH_METADATA =
+ Set(DELETE_OPERATION, UPDATE_OPERATION, REINSERT_OPERATION)
+ private final val DELTA_OPERATIONS_WITH_ROW_ID =
+ Set(DELETE_OPERATION, UPDATE_OPERATION)
+
protected def buildOperationTable(
table: SupportsRowLevelOperations,
command: Command,
@@ -103,7 +111,31 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
metadataAttrs: Seq[Attribute],
originalRowIdValues: Seq[Expression] = Seq.empty): Seq[Expression] = {
val rowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs)
- Seq(Literal(DELETE_OPERATION)) ++ rowValues ++ metadataAttrs ++ originalRowIdValues
+ val metadataValues = nullifyMetadataOnDelete(metadataAttrs)
+ Seq(Literal(DELETE_OPERATION)) ++ rowValues ++ metadataValues ++ originalRowIdValues
+ }
+
+ protected def nullifyMetadataOnDelete(attrs: Seq[Attribute]): Seq[NamedExpression] = {
+ nullifyMetadata(attrs, MetadataAttribute.isPreservedOnDelete)
+ }
+
+ protected def nullifyMetadataOnUpdate(attrs: Seq[Attribute]): Seq[NamedExpression] = {
+ nullifyMetadata(attrs, MetadataAttribute.isPreservedOnUpdate)
+ }
+
+ private def nullifyMetadataOnReinsert(attrs: Seq[Attribute]): Seq[NamedExpression] = {
+ nullifyMetadata(attrs, MetadataAttribute.isPreservedOnReinsert)
+ }
+
+ private def nullifyMetadata(
+ attrs: Seq[Attribute],
+ shouldPreserve: Attribute => Boolean): Seq[NamedExpression] = {
+ attrs.map {
+ case MetadataAttribute(attr) if !shouldPreserve(attr) =>
+ Alias(Literal(null, attr.dataType), attr.name)(explicitMetadata = Some(attr.metadata))
+ case attr =>
+ attr
+ }
}
private def buildDeltaDeleteRowValues(
@@ -132,7 +164,42 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
metadataAttrs: Seq[Attribute],
originalRowIdValues: Seq[Expression]): Seq[Expression] = {
val rowValues = assignments.map(_.value)
- Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataAttrs ++ originalRowIdValues
+ val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
+ Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataValues ++ originalRowIdValues
+ }
+
+ protected def deltaReinsertOutput(
+ assignments: Seq[Assignment],
+ metadataAttrs: Seq[Attribute],
+ originalRowIdValues: Seq[Expression] = Seq.empty): Seq[Expression] = {
+ val rowValues = assignments.map(_.value)
+ val metadataValues = nullifyMetadataOnReinsert(metadataAttrs)
+ val extraNullValues = originalRowIdValues.map(e => Literal(null, e.dataType))
+ Seq(Literal(REINSERT_OPERATION)) ++ rowValues ++ metadataValues ++ extraNullValues
+ }
+
+ protected def addOperationColumn(operation: Int, plan: LogicalPlan): LogicalPlan = {
+ val operationType = Alias(Literal(operation, IntegerType), OPERATION_COLUMN)()
+ Project(operationType +: plan.output, plan)
+ }
+
+ protected def buildReplaceDataProjections(
+ plan: LogicalPlan,
+ rowAttrs: Seq[Attribute],
+ metadataAttrs: Seq[Attribute]): ReplaceDataProjections = {
+ val outputs = extractOutputs(plan)
+
+ val outputsWithRow = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION))
+ val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs)
+
+ val metadataProjection = if (metadataAttrs.nonEmpty) {
+ val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION))
+ Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
+ } else {
+ None
+ }
+
+ ReplaceDataProjections(rowProjection, metadataProjection)
}
protected def buildWriteDeltaProjections(
@@ -140,17 +207,21 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
rowAttrs: Seq[Attribute],
rowIdAttrs: Seq[Attribute],
metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
+ val outputs = extractOutputs(plan)
val rowProjection = if (rowAttrs.nonEmpty) {
- Some(newLazyProjection(plan, rowAttrs))
+ val outputsWithRow = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW)
+ Some(newLazyProjection(plan, outputsWithRow, rowAttrs))
} else {
None
}
- val rowIdProjection = newLazyRowIdProjection(plan, rowIdAttrs)
+ val outputsWithRowId = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW_ID)
+ val rowIdProjection = newLazyRowIdProjection(plan, outputsWithRowId, rowIdAttrs)
val metadataProjection = if (metadataAttrs.nonEmpty) {
- Some(newLazyProjection(plan, metadataAttrs))
+ val outputsWithMetadata = filterOutputs(outputs, DELTA_OPERATIONS_WITH_METADATA)
+ Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
} else {
None
}
@@ -158,26 +229,54 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection)
}
+ private def extractOutputs(plan: LogicalPlan): Seq[Seq[Expression]] = {
+ plan match {
+ case p: Project => Seq(p.projectList)
+ case e: Expand => e.projections
+ case m: MergeRows => m.outputs
+ case _ => throw SparkException.internalError("Can't extract outputs from plan: " + plan)
+ }
+ }
+
+ private def filterOutputs(
+ outputs: Seq[Seq[Expression]],
+ operations: Set[Int]): Seq[Seq[Expression]] = {
+ outputs.filter {
+ case Literal(operation: Integer, _) +: _ => operations.contains(operation)
+ case Alias(Literal(operation: Integer, _), _) +: _ => operations.contains(operation)
+ case other => throw SparkException.internalError("Can't determine operation: " + other)
+ }
+ }
+
private def newLazyProjection(
plan: LogicalPlan,
+ outputs: Seq[Seq[Expression]],
attrs: Seq[Attribute]): ProjectingInternalRow = {
-
val colOrdinals = attrs.map(attr => findColOrdinal(plan, attr.name))
- val schema = DataTypeUtils.fromAttributes(attrs)
- ProjectingInternalRow(schema, colOrdinals)
+ createProjectingInternalRow(outputs, colOrdinals, attrs)
}
// if there are assignments to row ID attributes, original values are projected as special columns
// this method honors such special columns if present
private def newLazyRowIdProjection(
plan: LogicalPlan,
+ outputs: Seq[Seq[Expression]],
rowIdAttrs: Seq[Attribute]): ProjectingInternalRow = {
-
val colOrdinals = rowIdAttrs.map { attr =>
val originalValueIndex = findColOrdinal(plan, ORIGINAL_ROW_ID_VALUE_PREFIX + attr.name)
if (originalValueIndex != -1) originalValueIndex else findColOrdinal(plan, attr.name)
}
- val schema = DataTypeUtils.fromAttributes(rowIdAttrs)
+ createProjectingInternalRow(outputs, colOrdinals, rowIdAttrs)
+ }
+
+ private def createProjectingInternalRow(
+ outputs: Seq[Seq[Expression]],
+ colOrdinals: Seq[Int],
+ attrs: Seq[Attribute]): ProjectingInternalRow = {
+ val schema = StructType(attrs.zipWithIndex.map { case (attr, index) =>
+ val nullable = outputs.exists(output => output(colOrdinals(index)).nullable)
+ StructField(attr.name, attr.dataType, nullable, attr.metadata)
+ })
ProjectingInternalRow(schema, colOrdinals)
}
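
To make the nullification concrete: for an attribute whose column opted out of preservation, the helpers above swap the attribute for a null literal that keeps the original name and metadata, so downstream projections still line up. A small illustrative snippet (the `_file` column is an assumption, not one of the columns used elsewhere in this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal, MetadataAttribute}
import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY
import org.apache.spark.sql.connector.catalog.MetadataColumn
import org.apache.spark.sql.types.{MetadataBuilder, StringType}

object NullifyOnDeleteSketch {
  // metadata attribute that is explicitly not preserved on DELETE
  private val meta = new MetadataBuilder()
    .putString(METADATA_COL_ATTR_KEY, "_file")
    .putBoolean(MetadataColumn.PRESERVE_ON_DELETE, value = false)
    .build()
  private val fileAttr = AttributeReference("_file", StringType, nullable = true, metadata = meta)()

  assert(!MetadataAttribute.isPreservedOnDelete(fileAttr))

  // what nullifyMetadataOnDelete effectively emits for such an attribute:
  // a null literal aliased to the original name, carrying the same metadata
  val nullified: Alias = Alias(Literal(null, fileAttr.dataType), fileAttr.name)(
    explicitMetadata = Some(fileAttr.metadata))
}
```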
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
index 37644a33c7a54..b2955ca006878 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
@@ -77,7 +77,9 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
- ReplaceData(writeRelation, cond, updatedAndRemainingRowsPlan, relation, Some(cond))
+ val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, updatedAndRemainingRowsPlan)
+ val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs)
+ ReplaceData(writeRelation, cond, query, relation, projections, Some(cond))
}
// build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions)
@@ -109,7 +111,9 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
- ReplaceData(writeRelation, cond, updatedAndRemainingRowsPlan, relation, Some(cond))
+ val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, updatedAndRemainingRowsPlan)
+ val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs)
+ ReplaceData(writeRelation, cond, query, relation, projections, Some(cond))
}
// this method assumes the assignments have been already aligned before
@@ -118,7 +122,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
assignments: Seq[Assignment],
cond: Expression = TrueLiteral): LogicalPlan = {
- // the plan output may include immutable metadata columns at the end
+ // the plan output may include metadata columns at the end
// that's why the number of assignments may not match the number of plan output columns
val assignedValues = assignments.map(_.value)
val updatedValues = plan.output.zipWithIndex.map { case (attr, index) =>
@@ -128,7 +132,12 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
Alias(updatedValue, attr.name)()
} else {
assert(MetadataAttribute.isValid(attr.metadata))
- attr
+ if (MetadataAttribute.isPreservedOnUpdate(attr)) {
+ attr
+ } else {
+ val updatedValue = If(cond, Literal(null, attr.dataType), attr)
+ Alias(updatedValue, attr.name)(explicitMetadata = Some(attr.metadata))
+ }
}
}
@@ -181,7 +190,11 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
Alias(assignedExpr, attr.name)()
} else {
assert(MetadataAttribute.isValid(attr.metadata))
- attr
+ if (MetadataAttribute.isPreservedOnUpdate(attr)) {
+ attr
+ } else {
+ Alias(Literal(null, attr.dataType), attr.name)(explicitMetadata = Some(attr.metadata))
+ }
}
}
@@ -203,7 +216,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
MetadataAttribute.isValid(attr.metadata)
}
val deleteOutput = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs)
- val insertOutput = deltaInsertOutput(assignments, metadataAttrs)
+ val insertOutput = deltaReinsertOutput(assignments, metadataAttrs)
val outputs = Seq(deleteOutput, insertOutput)
val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)()
val attrs = operationTypeAttr +: matchedRowsPlan.output
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index f5f35050401ba..2af6a1ba84ec8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, METADATA_COL_ATTR_KEY}
+import org.apache.spark.sql.connector.catalog.MetadataColumn
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.BitSet
import org.apache.spark.util.collection.ImmutableBitSet
@@ -503,8 +504,41 @@ object MetadataAttribute {
.putString(METADATA_COL_ATTR_KEY, name)
.build()
+ def metadata(col: MetadataColumn): Metadata = {
+ val builder = new MetadataBuilder()
+ if (col.metadataInJSON != null) {
+ builder.withMetadata(Metadata.fromJson(col.metadataInJSON))
+ }
+ builder.putString(METADATA_COL_ATTR_KEY, col.name)
+ builder.build()
+ }
+
def isValid(metadata: Metadata): Boolean =
metadata.contains(METADATA_COL_ATTR_KEY)
+
+ def isPreservedOnDelete(attr: Attribute): Boolean = {
+ if (attr.metadata.contains(MetadataColumn.PRESERVE_ON_DELETE)) {
+ attr.metadata.getBoolean(MetadataColumn.PRESERVE_ON_DELETE)
+ } else {
+ MetadataColumn.PRESERVE_ON_DELETE_DEFAULT
+ }
+ }
+
+ def isPreservedOnUpdate(attr: Attribute): Boolean = {
+ if (attr.metadata.contains(MetadataColumn.PRESERVE_ON_UPDATE)) {
+ attr.metadata.getBoolean(MetadataColumn.PRESERVE_ON_UPDATE)
+ } else {
+ MetadataColumn.PRESERVE_ON_UPDATE_DEFAULT
+ }
+ }
+
+ def isPreservedOnReinsert(attr: Attribute): Boolean = {
+ if (attr.metadata.contains(MetadataColumn.PRESERVE_ON_REINSERT)) {
+ attr.metadata.getBoolean(MetadataColumn.PRESERVE_ON_REINSERT)
+ } else {
+ MetadataColumn.PRESERVE_ON_REINSERT_DEFAULT
+ }
+ }
}
/**
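
A quick, hedged check of how these helpers resolve the flags: an explicit value in the attribute metadata wins, otherwise the `*_DEFAULT` constants from `MetadataColumn` apply (preserved on delete and update, not preserved on reinsert). The `_partition` attribute below only sets the reinsert flag and is made up for illustration:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, MetadataAttribute}
import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY
import org.apache.spark.sql.connector.catalog.MetadataColumn
import org.apache.spark.sql.types.{MetadataBuilder, StringType}

object PreserveFlagDefaults {
  private val meta = new MetadataBuilder()
    .putString(METADATA_COL_ATTR_KEY, "_partition")
    .putBoolean(MetadataColumn.PRESERVE_ON_REINSERT, value = true)
    .build()
  private val attr =
    AttributeReference("_partition", StringType, nullable = false, metadata = meta)()

  assert(MetadataAttribute.isPreservedOnDelete(attr))   // no flag set, default is true
  assert(MetadataAttribute.isPreservedOnUpdate(attr))   // no flag set, default is true
  assert(MetadataAttribute.isPreservedOnReinsert(attr)) // explicitly set to true above
}
```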
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
index 2cda1142299ae..0358c45815944 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
@@ -57,7 +57,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
_.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET), ruleId) {
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
- case rd @ ReplaceData(_, cond, _, _, groupFilterCond, _) =>
+ case rd @ ReplaceData(_, cond, _, _, _, groupFilterCond, _) =>
val newCond = replaceNullWithFalse(cond)
val newGroupFilterCond = groupFilterCond.map(replaceNullWithFalse)
rd.copy(condition = newCond, groupFilterCondition = newGroupFilterCond)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index d0a0fc307756c..54a4e75c90c95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -437,7 +437,7 @@ object GroupBasedRowLevelOperation {
def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _),
- cond, query, _, groupFilterCond, _) =>
+ cond, query, _, _, groupFilterCond, _) =>
// group-based UPDATEs that are rewritten as UNION read the table twice
val allowMultipleReads = rd.operation.command == UPDATE
val readRelation = findReadRelation(table, query, allowMultipleReads)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
index 9b1c8bc733a35..f7f515c29481d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
@@ -49,6 +49,12 @@ case class MergeRows(
AttributeSet.fromAttributeSets(usedExprs.map(_.references)) -- producedAttributes
}
+ def instructions: Seq[Instruction] = {
+ matchedInstructions ++ notMatchedInstructions ++ notMatchedBySourceInstructions
+ }
+
+ def outputs: Seq[Seq[Expression]] = instructions.flatMap(_.outputs)
+
override def simpleString(maxFields: Int): String = {
s"MergeRows${truncatedString(output, "[", ", ", "]", maxFields)}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 58c62a90225aa..b361ccbe2439a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -23,10 +23,11 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils,
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, NamedExpression, UnaryExpression, Unevaluable, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, UnaryExpression, Unevaluable, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, ReplaceDataProjections, RowDeltaUtils, WriteDeltaProjections}
import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber, toSQLExpr}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper}
@@ -34,9 +35,10 @@ import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write}
+import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructType}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.Utils
@@ -66,16 +68,19 @@ trait V2WriteCommand extends UnaryCommand with KeepAnalyzedQuery with CTEInChild
assert(table.resolved && query.resolved,
"`outputResolved` can only be called when `table` and `query` are both resolved.")
// If the table doesn't require schema match, we don't need to resolve the output columns.
- table.skipSchemaResolution || (query.output.size == table.output.size &&
- query.output.zip(table.output).forall {
- case (inAttr, outAttr) =>
- val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType)
- val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
- // names and types must match, nullability must be compatible
- inAttr.name == outAttr.name &&
- DataType.equalsIgnoreCompatibleNullability(inType, outType) &&
- (outAttr.nullable || !inAttr.nullable)
- })
+ table.skipSchemaResolution || areCompatible(query.output, table.output)
+ }
+
+ protected def areCompatible(inAttrs: Seq[Attribute], outAttrs: Seq[Attribute]): Boolean = {
+ inAttrs.size == outAttrs.size && inAttrs.zip(outAttrs).forall {
+ case (inAttr, outAttr) =>
+ val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType)
+ val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
+ // names and types must match, nullability must be compatible
+ inAttr.name == outAttr.name &&
+ DataType.equalsIgnoreCompatibleNullability(inType, outType) &&
+ (outAttr.nullable || !inAttr.nullable)
+ }
}
def withNewQuery(newQuery: LogicalPlan): V2WriteCommand
@@ -209,6 +214,17 @@ trait RowLevelWrite extends V2WriteCommand with SupportsSubquery {
def operation: RowLevelOperation
def condition: Expression
def originalTable: NamedRelation
+
+ protected def operationResolved: Boolean = {
+ val attr = query.output.head
+ attr.name == RowDeltaUtils.OPERATION_COLUMN && attr.dataType == IntegerType && !attr.nullable
+ }
+
+ protected def projectedMetadataAttrs: Seq[Attribute] = {
+ V2ExpressionUtils.resolveRefs[AttributeReference](
+ operation.requiredMetadataAttributes.toImmutableArraySeq,
+ originalTable)
+ }
}
/**
@@ -229,6 +245,7 @@ case class ReplaceData(
condition: Expression,
query: LogicalPlan,
originalTable: NamedRelation,
+ projections: ReplaceDataProjections,
groupFilterCondition: Option[Expression] = None,
write: Option[Write] = None) extends RowLevelWrite {
@@ -248,31 +265,33 @@ case class ReplaceData(
}
}
- // the incoming query may include metadata columns
- lazy val dataInput: Seq[Attribute] = {
- query.output.filter {
- case MetadataAttribute(_) => false
- case _ => true
- }
- }
-
override def outputResolved: Boolean = {
assert(table.resolved && query.resolved,
"`outputResolved` can only be called when `table` and `query` are both resolved.")
+ operationResolved && rowAttrsResolved && metadataAttrsResolved
+ }
- // take into account only incoming data columns and ignore metadata columns in the query
- // they will be discarded after the logical write is built in the optimizer
- // metadata columns may be needed to request a correct distribution or ordering
- // but are not passed back to the data source during writes
+ // validates row projection output is compatible with table attributes
+ private def rowAttrsResolved: Boolean = {
+ val inRowAttrs = DataTypeUtils.toAttributes(projections.rowProjection.schema)
+ table.skipSchemaResolution || areCompatible(inRowAttrs, table.output)
+ }
- table.skipSchemaResolution || (dataInput.size == table.output.size &&
- dataInput.zip(table.output).forall { case (inAttr, outAttr) =>
- val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
- // names and types must match, nullability must be compatible
- inAttr.name == outAttr.name &&
- DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outType) &&
- (outAttr.nullable || !inAttr.nullable)
- })
+ // validates metadata projection output is compatible with metadata attributes
+ private def metadataAttrsResolved: Boolean = {
+ val outMetadataAttrs = projectedMetadataAttrs.map {
+ case attr if isMetadataNullabilityPreserved(attr) => attr
+ case attr => attr.withNullability(true)
+ }
+ val inMetadataAttrs = projections.metadataProjection match {
+ case Some(projection) => DataTypeUtils.toAttributes(projection.schema)
+ case None => Nil
+ }
+ areCompatible(inMetadataAttrs, outMetadataAttrs)
+ }
+
+ private def isMetadataNullabilityPreserved(attr: Attribute): Boolean = {
+ operation.command == DELETE || MetadataAttribute.isPreservedOnUpdate(attr)
}
override def withNewQuery(newQuery: LogicalPlan): ReplaceData = copy(query = newQuery)
@@ -331,65 +350,52 @@ case class WriteDelta(
override def outputResolved: Boolean = {
assert(table.resolved && query.resolved,
"`outputResolved` can only be called when `table` and `query` are both resolved.")
-
operationResolved && rowAttrsResolved && rowIdAttrsResolved && metadataAttrsResolved
}
- private def operationResolved: Boolean = {
- val attr = query.output.head
- attr.name == RowDeltaUtils.OPERATION_COLUMN && attr.dataType == IntegerType && !attr.nullable
- }
-
// validates row projection output is compatible with table attributes
private def rowAttrsResolved: Boolean = {
- table.skipSchemaResolution || (projections.rowProjection match {
- case Some(projection) =>
- table.output.size == projection.schema.size &&
- projection.schema.zip(table.output).forall { case (field, outAttr) =>
- isCompatible(field, outAttr)
- }
- case None =>
- true
- })
+ val outRowAttrs = if (operation.command == DELETE) Nil else table.output
+ val inRowAttrs = projections.rowProjection match {
+ case Some(projection) => DataTypeUtils.toAttributes(projection.schema)
+ case None => Nil
+ }
+ table.skipSchemaResolution || areCompatible(inRowAttrs, outRowAttrs)
}
// validates row ID projection output is compatible with row ID attributes
private def rowIdAttrsResolved: Boolean = {
- val rowIdAttrs = V2ExpressionUtils.resolveRefs[AttributeReference](
+ val outRowIdAttrs = V2ExpressionUtils.resolveRefs[AttributeReference](
operation.rowId.toImmutableArraySeq,
originalTable)
-
- val projectionSchema = projections.rowIdProjection.schema
- rowIdAttrs.size == projectionSchema.size && projectionSchema.forall { field =>
- rowIdAttrs.exists(rowIdAttr => isCompatible(field, rowIdAttr))
- }
+ val inRowIdAttrs = DataTypeUtils.toAttributes(projections.rowIdProjection.schema)
+ areCompatible(inRowIdAttrs, outRowIdAttrs)
}
// validates metadata projection output is compatible with metadata attributes
private def metadataAttrsResolved: Boolean = {
- projections.metadataProjection match {
- case Some(projection) =>
- val metadataAttrs = V2ExpressionUtils.resolveRefs[AttributeReference](
- operation.requiredMetadataAttributes.toImmutableArraySeq,
- originalTable)
-
- val projectionSchema = projection.schema
- metadataAttrs.size == projectionSchema.size && projectionSchema.forall { field =>
- metadataAttrs.exists(metadataAttr => isCompatible(field, metadataAttr))
- }
- case None =>
- true
+ val outMetadataAttrs = projectedMetadataAttrs.map {
+ case attr if isMetadataNullabilityPreserved(attr) => attr
+ case attr => attr.withNullability(true)
}
+ val inMetadataAttrs = projections.metadataProjection match {
+ case Some(projection) => DataTypeUtils.toAttributes(projection.schema)
+ case None => Nil
+ }
+ areCompatible(inMetadataAttrs, outMetadataAttrs)
}
- // checks if a projection field is compatible with a table attribute
- private def isCompatible(inField: StructField, outAttr: NamedExpression): Boolean = {
- val inType = CharVarcharUtils.getRawType(inField.metadata).getOrElse(inField.dataType)
- val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
- // names and types must match, nullability must be compatible
- inField.name == outAttr.name &&
- DataType.equalsIgnoreCompatibleNullability(inType, outType) &&
- (outAttr.nullable || !inField.nullable)
+ private def isMetadataNullabilityPreserved(attr: Attribute): Boolean = {
+ operation.command match {
+ case DELETE =>
+ MetadataAttribute.isPreservedOnDelete(attr)
+ case UPDATE | MERGE if operation.representUpdateAsDeleteAndInsert =>
+ MetadataAttribute.isPreservedOnDelete(attr) && MetadataAttribute.isPreservedOnReinsert(attr)
+ case UPDATE =>
+ MetadataAttribute.isPreservedOnUpdate(attr)
+ case MERGE =>
+ MetadataAttribute.isPreservedOnDelete(attr) && MetadataAttribute.isPreservedOnUpdate(attr)
+ }
}
override def withNewQuery(newQuery: LogicalPlan): V2WriteCommand = copy(query = newQuery)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ReplaceDataProjections.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ReplaceDataProjections.scala
new file mode 100644
index 0000000000000..99744faf2c749
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ReplaceDataProjections.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.sql.catalyst.ProjectingInternalRow
+
+case class ReplaceDataProjections(
+ rowProjection: ProjectingInternalRow,
+ metadataProjection: Option[ProjectingInternalRow])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala
index db39c059c24e3..72baad069b180 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala
@@ -25,5 +25,8 @@ object RowDeltaUtils {
final val DELETE_OPERATION: Int = 1
final val UPDATE_OPERATION: Int = 2
final val INSERT_OPERATION: Int = 3
+ final val REINSERT_OPERATION: Int = 4
+ final val WRITE_OPERATION: Int = 5
+ final val WRITE_WITH_METADATA_OPERATION: Int = 6
final val ORIGINAL_ROW_ID_VALUE_PREFIX: String = "__original_row_id_"
}
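
The two new codes complement the existing delta codes: rows tagged with 5 or 6 flow through the group-based ReplaceData path, and only rows tagged with 6 have metadata to project (this is what `filterOutputs` keys on in `RewriteRowLevelCommand`). The mapping below to writer calls is a reading of the rewrite rules in this patch, sketched for orientation rather than lifted from any single source file:

```scala
import org.apache.spark.sql.catalyst.util.RowDeltaUtils._

object OperationCodes {
  // How each operation code produced by the rewritten plans is consumed.
  def describe(operation: Int): String = operation match {
    case DELETE_OPERATION              => "DeltaWriter.delete(metadata, rowId)"
    case UPDATE_OPERATION              => "DeltaWriter.update(metadata, rowId, row)"
    case INSERT_OPERATION              => "DeltaWriter.insert(row)"
    case REINSERT_OPERATION            => "DeltaWriter.reinsert(metadata, row)"
    case WRITE_OPERATION               => "DataWriter.write(row)"
    case WRITE_WITH_METADATA_OPERATION => "DataWriter.write(metadata, row)"
    case other                         => s"unknown operation: $other"
  }
}
```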
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index 134c4f1bb13a4..841f367896c40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -25,6 +25,7 @@ import com.google.common.io.ByteStreams
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.connector.catalog.MetadataColumn
import org.apache.spark.sql.types.{MetadataBuilder, NumericType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SparkErrorUtils, Utils}
@@ -193,7 +194,10 @@ package object util extends Logging {
QUALIFIED_ACCESS_ONLY,
FileSourceMetadataAttribute.FILE_SOURCE_METADATA_COL_ATTR_KEY,
FileSourceConstantMetadataStructField.FILE_SOURCE_CONSTANT_METADATA_COL_ATTR_KEY,
- FileSourceGeneratedMetadataStructField.FILE_SOURCE_GENERATED_METADATA_COL_ATTR_KEY
+ FileSourceGeneratedMetadataStructField.FILE_SOURCE_GENERATED_METADATA_COL_ATTR_KEY,
+ MetadataColumn.PRESERVE_ON_DELETE,
+ MetadataColumn.PRESERVE_ON_UPDATE,
+ MetadataColumn.PRESERVE_ON_REINSERT
)
def removeInternalMetadata(schema: StructType): StructType = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/LogicalWriteInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/LogicalWriteInfoImpl.scala
index 8c0828d8a278b..1e4e1a5955f3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/LogicalWriteInfoImpl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/LogicalWriteInfoImpl.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.connector.write
import java.util.Optional
+import scala.jdk.OptionConverters._
+
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -29,3 +31,19 @@ private[sql] case class LogicalWriteInfoImpl(
override val rowIdSchema: Optional[StructType] = Optional.empty[StructType],
override val metadataSchema: Optional[StructType] = Optional.empty[StructType])
extends LogicalWriteInfo
+
+object LogicalWriteInfoImpl {
+ def apply(
+ queryId: String,
+ schema: StructType,
+ options: CaseInsensitiveStringMap,
+ rowIdSchema: Option[StructType],
+ metadataSchema: Option[StructType]): LogicalWriteInfoImpl = {
+ LogicalWriteInfoImpl(
+ queryId,
+ schema,
+ options,
+ rowIdSchema.toJava,
+ metadataSchema.toJava)
+ }
+}
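
Since `LogicalWriteInfoImpl` is `private[sql]`, the new Scala-friendly overload is only reachable from Spark's own packages. A hypothetical call site (the object name and schemas are made up for the example) could look like this:

```scala
package org.apache.spark.sql.connector.write

import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object LogicalWriteInfoSketch {
  // Option-based schemas are converted to java.util.Optional via OptionConverters.
  val info: LogicalWriteInfo = LogicalWriteInfoImpl(
    queryId = "query-0",
    schema = new StructType().add("pk", IntegerType).add("value", StringType),
    options = CaseInsensitiveStringMap.empty(),
    rowIdSchema = None,
    metadataSchema = Some(new StructType().add("_partition", StringType)))
}
```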
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
index 6b884713bd5c3..b24885270f52d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
@@ -109,7 +109,7 @@ object DataSourceV2Implicits {
name = metaCol.name,
dataType = metaCol.dataType,
nullable = metaCol.isNullable,
- metadata = MetadataAttribute.metadata(metaCol.name))
+ metadata = MetadataAttribute.metadata(metaCol))
Option(metaCol.comment).map(field.withComment).getOrElse(field)
}
StructType(fields)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index ab17b93ad6146..3ac8c3794b8ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -63,13 +63,29 @@ abstract class InMemoryBaseTable(
protected object PartitionKeyColumn extends MetadataColumn {
override def name: String = "_partition"
override def dataType: DataType = StringType
+ override def isNullable: Boolean = false
override def comment: String = "Partition key used to store the row"
+ override def metadataInJSON(): String = {
+ val metadata = new MetadataBuilder()
+ .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, value = true)
+ .putBoolean(MetadataColumn.PRESERVE_ON_REINSERT, value = true)
+ .build()
+ metadata.json
+ }
}
- private object IndexColumn extends MetadataColumn {
+ protected object IndexColumn extends MetadataColumn {
override def name: String = "index"
override def dataType: DataType = IntegerType
+ override def isNullable: Boolean = false
override def comment: String = "Metadata column used to conflict with a data column"
+ override def metadataInJSON(): String = {
+ val metadata = new MetadataBuilder()
+ .putBoolean(MetadataColumn.PRESERVE_ON_DELETE, value = false)
+ .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, value = false)
+ .build()
+ metadata.json
+ }
}
// purposely exposes a metadata column that conflicts with a data column in some tests
@@ -617,6 +633,7 @@ object InMemoryBaseTable {
class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage
with InputPartition with HasPartitionKey with HasPartitionStatistics with Serializable {
+ val log = new mutable.ArrayBuffer[InternalRow]()
val rows = new mutable.ArrayBuffer[InternalRow]()
val deletes = new mutable.ArrayBuffer[Int]()
@@ -734,9 +751,22 @@ private object BufferedRowsWriterFactory extends DataWriterFactory with Streamin
}
private class BufferWriter extends DataWriter[InternalRow] {
+
+ private final val WRITE = UTF8String.fromString(Write.toString)
+
protected val buffer = new BufferedRows
- override def write(row: InternalRow): Unit = buffer.rows.append(row.copy())
+ override def write(metadata: InternalRow, row: InternalRow): Unit = {
+ buffer.rows.append(row.copy())
+ val logEntry = new GenericInternalRow(Array[Any](WRITE, null, metadata.copy(), row.copy()))
+ buffer.log.append(logEntry)
+ }
+
+ override def write(row: InternalRow): Unit = {
+ buffer.rows.append(row.copy())
+ val logEntry = new GenericInternalRow(Array[Any](WRITE, null, null, row.copy()))
+ buffer.log.append(logEntry)
+ }
override def commit(): WriterCommitMessage = buffer
@@ -771,3 +801,10 @@ class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric {
override def name(): String = "number_of_rows_from_driver"
override def value(): Long = value
}
+
+sealed trait Operation
+case object Write extends Operation
+case object Delete extends Operation
+case object Update extends Operation
+case object Reinsert extends Operation
+case object Insert extends Operation
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
index 3a684dc57c02f..98678289fa259 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connector.catalog
import java.util
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NamedReference, SortDirection, SortOrder, Transform}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
@@ -27,6 +28,8 @@ import org.apache.spark.sql.connector.write.{BatchWrite, DeltaBatchWrite, DeltaW
import org.apache.spark.sql.connector.write.RowLevelOperation.Command
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.ArrayImplicits._
class InMemoryRowLevelOperationTable(
name: String,
@@ -36,11 +39,17 @@ class InMemoryRowLevelOperationTable(
extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations {
private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name)
+ private final val INDEX_COLUMN_REF = FieldReference(IndexColumn.name)
private final val SUPPORTS_DELTAS = "supports-deltas"
private final val SPLIT_UPDATES = "split-updates"
// used in row-level operation tests to verify replaced partitions
var replacedPartitions: Seq[Seq[Any]] = Seq.empty
+ // used in row-level operation tests to verify reported write schema
+ var lastWriteInfo: LogicalWriteInfo = _
+ // used in row-level operation tests to verify passed records
+ // (operation, id, metadata, row)
+ var lastWriteLog: Seq[InternalRow] = Seq.empty
override def newRowLevelOperationBuilder(
info: RowLevelOperationInfo): RowLevelOperationBuilder = {
@@ -55,7 +64,7 @@ class InMemoryRowLevelOperationTable(
var configuredScan: InMemoryBatchScan = _
override def requiredMetadataAttributes(): Array[NamedReference] = {
- Array(PARTITION_COLUMN_REF)
+ Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF)
}
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
@@ -68,25 +77,26 @@ class InMemoryRowLevelOperationTable(
}
}
- override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = new WriteBuilder {
-
- override def build(): Write = new Write with RequiresDistributionAndOrdering {
- override def requiredDistribution(): Distribution = {
- Distributions.clustered(Array(PARTITION_COLUMN_REF))
- }
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ lastWriteInfo = info
+ new WriteBuilder {
+ override def build(): Write = new Write with RequiresDistributionAndOrdering {
+ override def requiredDistribution: Distribution = {
+ Distributions.clustered(Array(PARTITION_COLUMN_REF))
+ }
- override def requiredOrdering(): Array[SortOrder] = {
- Array[SortOrder](
- LogicalExpressions.sort(
- PARTITION_COLUMN_REF,
- SortDirection.ASCENDING,
- SortDirection.ASCENDING.defaultNullOrdering())
- )
- }
+ override def requiredOrdering: Array[SortOrder] = {
+ Array[SortOrder](
+ LogicalExpressions.sort(
+ PARTITION_COLUMN_REF,
+ SortDirection.ASCENDING,
+ SortDirection.ASCENDING.defaultNullOrdering()))
+ }
- override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan)
+ override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan)
- override def description(): String = "InMemoryWrite"
+ override def description: String = "InMemoryWrite"
+ }
}
}
@@ -102,6 +112,7 @@ class InMemoryRowLevelOperationTable(
dataMap --= readPartitions
replacedPartitions = readPartitions
withData(newData, schema)
+ lastWriteLog = newData.flatMap(buffer => buffer.log).toImmutableArraySeq
}
}
@@ -109,7 +120,7 @@ class InMemoryRowLevelOperationTable(
private final val PK_COLUMN_REF = FieldReference("pk")
override def requiredMetadataAttributes(): Array[NamedReference] = {
- Array(PARTITION_COLUMN_REF)
+ Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF)
}
override def rowId(): Array[NamedReference] = Array(PK_COLUMN_REF)
@@ -118,7 +129,8 @@ class InMemoryRowLevelOperationTable(
new InMemoryScanBuilder(schema, options)
}
- override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder =
+ override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = {
+ lastWriteInfo = info
new DeltaWriteBuilder {
override def build(): DeltaWrite = new DeltaWrite with RequiresDistributionAndOrdering {
@@ -138,6 +150,7 @@ class InMemoryRowLevelOperationTable(
override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite
}
}
+ }
override def representUpdateAsDeleteAndInsert(): Boolean = {
properties.getOrDefault(SPLIT_UPDATES, "false").toBoolean
@@ -150,8 +163,10 @@ class InMemoryRowLevelOperationTable(
}
override def commit(messages: Array[WriterCommitMessage]): Unit = {
- withDeletes(messages.map(_.asInstanceOf[BufferedRows]))
- withData(messages.map(_.asInstanceOf[BufferedRows]))
+ val newData = messages.map(_.asInstanceOf[BufferedRows])
+ withDeletes(newData)
+ withData(newData)
+ lastWriteLog = newData.flatMap(buffer => buffer.log).toIndexedSeq
}
override def abort(messages: Array[WriterCommitMessage]): Unit = {}
@@ -166,16 +181,41 @@ private object DeltaBufferedRowsWriterFactory extends DeltaWriterFactory {
private class DeltaBufferWriter extends BufferWriter with DeltaWriter[InternalRow] {
- override def delete(meta: InternalRow, id: InternalRow): Unit = buffer.deletes += id.getInt(0)
+ private final val DELETE = UTF8String.fromString(Delete.toString)
+ private final val UPDATE = UTF8String.fromString(Update.toString)
+ private final val REINSERT = UTF8String.fromString(Reinsert.toString)
+ private final val INSERT = UTF8String.fromString(Insert.toString)
+
+ override def delete(meta: InternalRow, id: InternalRow): Unit = {
+ val pk = id.getInt(0)
+ buffer.deletes += pk
+ val logEntry = new GenericInternalRow(Array[Any](DELETE, pk, meta.copy(), null))
+ buffer.log += logEntry
+ }
override def update(meta: InternalRow, id: InternalRow, row: InternalRow): Unit = {
- buffer.deletes += id.getInt(0)
- write(row)
+ val pk = id.getInt(0)
+ buffer.deletes += pk
+ buffer.rows.append(row.copy())
+ val logEntry = new GenericInternalRow(Array[Any](UPDATE, pk, meta.copy(), row.copy()))
+ buffer.log += logEntry
}
- override def insert(row: InternalRow): Unit = write(row)
+ override def reinsert(meta: InternalRow, row: InternalRow): Unit = {
+ buffer.rows.append(row.copy())
+ val logEntry = new GenericInternalRow(Array[Any](REINSERT, null, meta.copy(), row.copy()))
+ buffer.log += logEntry
+ }
+
+ override def insert(row: InternalRow): Unit = {
+ buffer.rows.append(row.copy())
+ val logEntry = new GenericInternalRow(Array[Any](INSERT, null, null, row.copy()))
+ buffer.log += logEntry
+ }
- override def write(row: InternalRow): Unit = super[BufferWriter].write(row)
+ override def write(row: InternalRow): Unit = {
+ throw new UnsupportedOperationException()
+ }
override def commit(): WriterCommitMessage = buffer
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 499721fbae4e8..ce863791659bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -323,9 +323,10 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
throw SparkException.internalError("Unexpected table relation: " + other)
}
- case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, _, Some(write)) =>
+ case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, _,
+ Some(write)) =>
// use the original relation to refresh the cache
- ReplaceDataExec(planLater(query), refreshCache(r), write) :: Nil
+ ReplaceDataExec(planLater(query), refreshCache(r), projections, write) :: Nil
case WriteDelta(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections,
Some(write)) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala
index 6eede88c55bd0..05fb37674a836 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala
@@ -74,7 +74,7 @@ object OptimizeMetadataOnlyDeleteFromTable extends Rule[LogicalPlan] with Predic
type ReturnType = (RowLevelWrite, RowLevelOperation.Command, Expression, LogicalPlan)
def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
- case rd @ ReplaceData(_, cond, _, originalTable, _, _) =>
+ case rd @ ReplaceData(_, cond, _, originalTable, _, _, _) =>
val command = rd.operation.command
Some(rd, command, cond, originalTable)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 17b2579ca873a..9d059416766a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -17,15 +17,14 @@
package org.apache.spark.sql.execution.datasources.v2
-import java.util.{Optional, UUID}
+import java.util.UUID
import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceData, WriteDelta}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceData, WriteDelta}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
import org.apache.spark.sql.connector.expressions.filter.Predicate
@@ -95,7 +94,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
relationOpt, table, query, queryId, options, outputMode, Some(batchId)) =>
val writeOptions = mergeOptions(
options, relationOpt.map(r => r.options.asScala.toMap).getOrElse(Map.empty))
- val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId)
+ val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId = queryId)
val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq
@@ -103,14 +102,14 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
WriteToDataSourceV2(relationOpt, microBatchWrite, newQuery, customMetrics)
- case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, _, None) =>
- val rowSchema = DataTypeUtils.fromAttributes(rd.dataInput)
+ case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, projections, _, None) =>
+ val rowSchema = projections.rowProjection.schema
+ val metadataSchema = projections.metadataProjection.map(_.schema)
val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap)
- val writeBuilder = newWriteBuilder(r.table, writeOptions, rowSchema)
+ val writeBuilder = newWriteBuilder(r.table, writeOptions, rowSchema, metadataSchema)
val write = writeBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
- // project away any metadata columns that could be used for distribution and ordering
- rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))
+ rd.copy(write = Some(write), query = newQuery)
case wd @ WriteDelta(r: DataSourceV2Relation, _, query, _, projections, None) =>
val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap)
@@ -158,9 +157,15 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
table: Table,
writeOptions: Map[String, String],
rowSchema: StructType,
+ metadataSchema: Option[StructType] = None,
queryId: String = UUID.randomUUID().toString): WriteBuilder = {
- val info = LogicalWriteInfoImpl(queryId, rowSchema, writeOptions.asOptions)
+ val info = LogicalWriteInfoImpl(
+ queryId,
+ rowSchema,
+ writeOptions.asOptions,
+ rowIdSchema = None,
+ metadataSchema)
table.asWritable.newWriteBuilder(info)
}
@@ -171,15 +176,15 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
queryId: String = UUID.randomUUID().toString): DeltaWriteBuilder = {
val rowSchema = projections.rowProjection.map(_.schema).getOrElse(StructType(Nil))
- val rowIdSchema = projections.rowIdProjection.schema
+ val rowIdSchema = Some(projections.rowIdProjection.schema)
val metadataSchema = projections.metadataProjection.map(_.schema)
val info = LogicalWriteInfoImpl(
queryId,
rowSchema,
writeOptions.asOptions,
- Optional.of(rowIdSchema),
- Optional.ofNullable(metadataSchema.orNull))
+ rowIdSchema,
+ metadataSchema)
val writeBuilder = table.asWritable.newWriteBuilder(info)
assert(writeBuilder.isInstanceOf[DeltaWriteBuilder], s"$writeBuilder must be DeltaWriteBuilder")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 308b1bceca12a..016d6b5411acb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -22,12 +22,12 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, ProjectingInternalRow}
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, TableSpec, UnaryNode}
-import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, WriteDeltaProjections}
-import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, UPDATE_OPERATION}
+import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections}
+import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.metric.CustomMetric
@@ -283,10 +283,20 @@ case class OverwritePartitionsDynamicExec(
case class ReplaceDataExec(
query: SparkPlan,
refreshCache: () => Unit,
+ projections: ReplaceDataProjections,
write: Write) extends V2ExistingTableWriteExec {
override val stringArgs: Iterator[Any] = Iterator(query, write)
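+ // pass metadata back to the writer only when a metadata projection was generated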
+ override def writingTask: WritingSparkTask[_] = {
+ projections match {
+ case ReplaceDataProjections(dataProj, Some(metadataProj)) =>
+ DataAndMetadataWritingSparkTask(dataProj, metadataProj)
+ case _ =>
+ DataWritingSparkTask
+ }
+ }
+
override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = {
copy(query = newChild)
}
@@ -542,6 +552,32 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
}
}
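+// Writing task for group-based row-level operations that pass metadata back to the writer:
+// the first column of each incoming row encodes the operation, and the projections extract
+// the data and metadata columns before calling the corresponding write method.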
+case class DataAndMetadataWritingSparkTask(
+ dataProj: ProjectingInternalRow,
+ metadataProj: ProjectingInternalRow) extends WritingSparkTask[DataWriter[InternalRow]] {
+ override protected def write(
+ writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
+ while (iter.hasNext) {
+ val row = iter.next()
+ val operation = row.getInt(0)
+
+ operation match {
+ case WRITE_WITH_METADATA_OPERATION =>
+ dataProj.project(row)
+ metadataProj.project(row)
+ writer.write(metadataProj, dataProj)
+
+ case WRITE_OPERATION =>
+ dataProj.project(row)
+ writer.write(dataProj)
+
+ case other =>
+ throw new SparkException(s"Unexpected operation ID: $other")
+ }
+ }
+ }
+}
+
object DataWritingSparkTask extends WritingSparkTask[DataWriter[InternalRow]] {
override protected def write(
writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
@@ -571,6 +607,10 @@ case class DeltaWritingSparkTask(
rowIdProjection.project(row)
writer.update(null, rowIdProjection, rowProjection)
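+ // reinserted rows carry no row ID; only the new row values are projected in this task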
+ case REINSERT_OPERATION =>
+ rowProjection.project(row)
+ writer.reinsert(null, rowProjection)
+
case INSERT_OPERATION =>
rowProjection.project(row)
writer.insert(rowProjection)
@@ -607,6 +647,11 @@ case class DeltaWithMetadataWritingSparkTask(
metadataProjection.project(row)
writer.update(metadataProjection, rowIdProjection, rowProjection)
+ case REINSERT_OPERATION =>
+ rowProjection.project(row)
+ metadataProjection.project(row)
+ writer.reinsert(metadataProjection, rowProjection)
+
case INSERT_OPERATION =>
rowProjection.project(row)
writer.insert(rowProjection)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala
index be180eb89ce20..8ad713424cec5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala
@@ -18,15 +18,71 @@
package org.apache.spark.sql.connector
import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.types.StructType
class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase {
+ import testImplicits._
+
override protected lazy val extraTableProps: java.util.Map[String, String] = {
val props = new java.util.HashMap[String, String]()
props.put("supports-deltas", "true")
props
}
+ test("delete handles metadata columns correctly") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |{ "pk": 4, "id": 4, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id IN (1, 100)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, 2, "software") :: Row(3, 3, "hr") :: Row(4, 4, "hr") :: Nil)
+
+ checkLastWriteInfo(
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(deleteWriteLogEntry(id = 1, metadata = Row("hr", null)))
+ }
+
+ test("delete with subquery handles metadata columns correctly") {
+ withTempView("updated_dep") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |{ "pk": 4, "id": 4, "dep": "hr" }
+ |""".stripMargin)
+
+ val updatedDepDF = Seq(Some("hr"), Some("it")).toDF()
+ updatedDepDF.createOrReplaceTempView("updated_dep")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id IN (1, 100)
+ | AND
+ | dep IN (SELECT * FROM updated_dep)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, 2, "software") :: Row(3, 3, "hr") :: Row(4, 4, "hr") :: Nil)
+
+ checkLastWriteInfo(
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(deleteWriteLogEntry(id = 1, metadata = Row("hr", null)))
+ }
+ }
+
test("delete with nondeterministic conditions") {
createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
"""{ "pk": 1, "id": 1, "dep": "hr" }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala
index ff0f8b82bc9ce..32d602b154952 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala
@@ -17,11 +17,64 @@
package org.apache.spark.sql.connector
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StructType
+
class DeltaBasedMergeIntoTableSuite extends DeltaBasedMergeIntoTableSuiteBase {
+ import testImplicits._
+
override protected lazy val extraTableProps: java.util.Map[String, String] = {
val props = new java.util.HashMap[String, String]()
props.put("supports-deltas", "true")
props
}
+
+ test("merge handles metadata columns correctly") {
+ withTempView("source") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |{ "pk": 4, "salary": 400, "dep": "hr" }
+ |{ "pk": 5, "salary": 500, "dep": "hr" }
+ |""".stripMargin)
+
+ val sourceDF = Seq(3, 4, 5, 6).toDF("pk")
+ sourceDF.createOrReplaceTempView("source")
+
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source s
+ |ON t.pk = s.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET t.salary = t.salary + 1
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'new')
+ |WHEN NOT MATCHED BY SOURCE AND pk = 1 THEN
+ | DELETE
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(3, 301, "hr"), // update
+ Row(4, 401, "hr"), // update
+ Row(5, 501, "hr"), // update
+ Row(6, 0, "new"))) // insert
+
+ checkLastWriteInfo(
+ expectedRowSchema = table.schema,
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(
+ deleteWriteLogEntry(id = 1, metadata = Row("hr", null)),
+ updateWriteLogEntry(id = 3, metadata = Row("hr", null), data = Row(3, 301, "hr")),
+ updateWriteLogEntry(id = 4, metadata = Row("hr", null), data = Row(4, 401, "hr")),
+ updateWriteLogEntry(id = 5, metadata = Row("hr", null), data = Row(5, 501, "hr")),
+ insertWriteLogEntry(data = Row(6, 0, "new")))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala
index 405ee99bb6dc5..e93d4165be332 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala
@@ -17,13 +17,69 @@
package org.apache.spark.sql.connector
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StructType
+
class DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite
extends DeltaBasedMergeIntoTableSuiteBase {
+ import testImplicits._
+
override protected lazy val extraTableProps: java.util.Map[String, String] = {
val props = new java.util.HashMap[String, String]()
props.put("supports-deltas", "true")
props.put("split-updates", "true")
props
}
+
+ test("merge handles metadata columns correctly") {
+ withTempView("source") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |{ "pk": 4, "salary": 400, "dep": "hr" }
+ |{ "pk": 5, "salary": 500, "dep": "hr" }
+ |""".stripMargin)
+
+ val sourceDF = Seq(3, 4, 5, 6).toDF("pk")
+ sourceDF.createOrReplaceTempView("source")
+
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source s
+ |ON t.pk = s.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET t.salary = t.salary + 1
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'new')
+ |WHEN NOT MATCHED BY SOURCE AND pk = 1 THEN
+ | DELETE
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(3, 301, "hr"), // update
+ Row(4, 401, "hr"), // update
+ Row(5, 501, "hr"), // update
+ Row(6, 0, "new"))) // insert
+
+ checkLastWriteInfo(
+ expectedRowSchema = table.schema,
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(
+ deleteWriteLogEntry(id = 1, metadata = Row("hr", null)),
+ deleteWriteLogEntry(id = 3, metadata = Row("hr", null)),
+ reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(3, 301, "hr")),
+ deleteWriteLogEntry(id = 4, metadata = Row("hr", null)),
+ reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(4, 401, "hr")),
+ deleteWriteLogEntry(id = 5, metadata = Row("hr", null)),
+ reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(5, 501, "hr")),
+ insertWriteLogEntry(data = Row(6, 0, "new")))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala
index 363b90b45b87e..612a26e756abd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala
@@ -17,12 +17,81 @@
package org.apache.spark.sql.connector
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StructType
+
class DeltaBasedUpdateAsDeleteAndInsertTableSuite extends DeltaBasedUpdateTableSuiteBase {
+ import testImplicits._
+
override protected lazy val extraTableProps: java.util.Map[String, String] = {
val props = new java.util.HashMap[String, String]()
props.put("supports-deltas", "true")
props.put("split-updates", "true")
props
}
+
+ test("update handles metadata columns correctly") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"UPDATE $tableNameAsString SET id = -1 WHERE id IN (1, 100)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
+
+ checkLastWriteInfo(
+ expectedRowSchema = StructType(table.schema.map {
+ case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
+ case attr => attr
+ }),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(
+ deleteWriteLogEntry(id = 1, metadata = Row("hr", null)),
+ reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(1, -1, "hr")))
+ }
+
+ test("update with subquery handles metadata columns correctly") {
+ withTempView("updated_dep") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ val updatedDepDF = Seq(Some("hr"), Some("it")).toDF()
+ updatedDepDF.createOrReplaceTempView("updated_dep")
+
+ sql(
+ s"""UPDATE $tableNameAsString
+ |SET id = -1
+ |WHERE
+ | id IN (1, 100)
+ | AND
+ | dep IN (SELECT * FROM updated_dep)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
+
+ checkLastWriteInfo(
+ expectedRowSchema = StructType(table.schema.map {
+ case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
+ case attr => attr
+ }),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(
+ deleteWriteLogEntry(id = 1, metadata = Row("hr", null)),
+ reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(1, -1, "hr")))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
index bf6dcfac04cb1..c9fd5d6e3ff0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
@@ -17,11 +17,78 @@
package org.apache.spark.sql.connector
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StructType
+
class DeltaBasedUpdateTableSuite extends DeltaBasedUpdateTableSuiteBase {
+ import testImplicits._
+
override protected lazy val extraTableProps: java.util.Map[String, String] = {
val props = new java.util.HashMap[String, String]()
props.put("supports-deltas", "true")
props
}
+
+ test("update handles metadata columns correctly") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"UPDATE $tableNameAsString SET id = -1 WHERE id IN (1, 100)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
+
+ checkLastWriteInfo(
+ expectedRowSchema = StructType(table.schema.map {
+ case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
+ case attr => attr
+ }),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(
+ updateWriteLogEntry(id = 1, metadata = Row("hr", null), data = Row(1, -1, "hr")))
+ }
+
+ test("update with subquery handles metadata columns correctly") {
+ withTempView("updated_dep") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ val updatedDepDF = Seq(Some("hr"), Some("it")).toDF()
+ updatedDepDF.createOrReplaceTempView("updated_dep")
+
+ sql(
+ s"""UPDATE $tableNameAsString
+ |SET id = -1
+ |WHERE
+ | id IN (1, 20)
+ | AND
+ | dep IN (SELECT * FROM updated_dep)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
+
+ checkLastWriteInfo(
+ expectedRowSchema = StructType(table.schema.map {
+ case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
+ case attr => attr
+ }),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(
+ updateWriteLogEntry(id = 1, metadata = Row("hr", null), data = Row(1, -1, "hr")))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
index 1be318f948fd9..4dd09a2f1c831 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
@@ -19,11 +19,35 @@ package org.apache.spark.sql.connector
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase {
import testImplicits._
+ test("delete preserves metadata columns for carried-over records") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |{ "pk": 4, "id": 4, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id IN (1, 100)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, 2, "software") :: Row(3, 3, "hr") :: Row(4, 4, "hr") :: Nil)
+
+ checkLastWriteInfo(
+ expectedRowSchema = table.schema,
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD))))
+
+ checkLastWriteLog(
+ writeWithMetadataLogEntry(metadata = Row("hr", 1), data = Row(3, 3, "hr")),
+ writeWithMetadataLogEntry(metadata = Row("hr", 2), data = Row(4, 4, "hr")))
+ }
+
test("delete with nondeterministic conditions") {
createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
"""{ "pk": 1, "id": 1, "dep": "hr" }
@@ -53,7 +77,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase {
executeAndCheckScans(
s"DELETE FROM $tableNameAsString WHERE salary IN (300, 400, 500)",
- primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING",
+ primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING, index INT",
groupFilterScanSchema = Some("salary INT, dep STRING"))
checkAnswer(
@@ -85,7 +109,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase {
| AND
| dep IN (SELECT * FROM deleted_dep)
|""".stripMargin,
- primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING",
+ primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING, index INT",
groupFilterScanSchema = Some("id INT, dep STRING"))
checkAnswer(
@@ -133,7 +157,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase {
executeAndCheckScans(
s"DELETE FROM $tableNameAsString WHERE id IN (SELECT * FROM deleted_id)",
- primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING",
+ primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING, index INT",
groupFilterScanSchema = Some("id INT, dep STRING"))
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala
index ebc34ae006e6e..63eba256d8f27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala
@@ -19,11 +19,62 @@ package org.apache.spark.sql.connector
import org.apache.spark.sql.Row
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
class GroupBasedMergeIntoTableSuite extends MergeIntoTableSuiteBase {
import testImplicits._
+ test("merge handles metadata columns correctly") {
+ withTempView("source") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |{ "pk": 4, "salary": 400, "dep": "hr" }
+ |{ "pk": 5, "salary": 500, "dep": "hr" }
+ |{ "pk": 7, "salary": 700, "dep": "hr" }
+ |""".stripMargin)
+
+ val sourceDF = Seq(3, 4, 5, 6, 7).toDF("pk")
+ sourceDF.createOrReplaceTempView("source")
+
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source s
+ |ON t.pk = s.pk
+ |WHEN MATCHED AND t.pk != 7 THEN
+ | UPDATE SET t.salary = t.salary + 1
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'new')
+ |WHEN NOT MATCHED BY SOURCE AND pk = 1 THEN
+ | DELETE
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(3, 301, "hr"), // update
+ Row(4, 401, "hr"), // update
+ Row(5, 501, "hr"), // update
+ Row(6, 0, "new"), // insert
+ Row(7, 700, "hr"))) // unchanged
+
+ checkLastWriteInfo(
+ expectedRowSchema = table.schema,
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(
+ writeWithMetadataLogEntry(metadata = Row("software", 0), data = Row(2, 200, "software")),
+ writeWithMetadataLogEntry(metadata = Row("hr", null), data = Row(3, 301, "hr")),
+ writeWithMetadataLogEntry(metadata = Row("hr", null), data = Row(4, 401, "hr")),
+ writeWithMetadataLogEntry(metadata = Row("hr", null), data = Row(5, 501, "hr")),
+ writeLogEntry(data = Row(6, 0, "new")),
+ writeWithMetadataLogEntry(metadata = Row("hr", 4), data = Row(7, 700, "hr")))
+ }
+ }
+
test("merge runtime filtering is disabled with NOT MATCHED BY SOURCE clauses") {
withTempView("source") {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
@@ -48,7 +99,7 @@ class GroupBasedMergeIntoTableSuite extends MergeIntoTableSuiteBase {
|WHEN NOT MATCHED BY SOURCE THEN
| DELETE
|""".stripMargin,
- primaryScanSchema = "pk INT, salary INT, dep STRING, _partition STRING",
+ primaryScanSchema = "pk INT, salary INT, dep STRING, _partition STRING, index INT",
groupFilterScanSchema = None)
checkAnswer(
@@ -109,7 +160,7 @@ class GroupBasedMergeIntoTableSuite extends MergeIntoTableSuiteBase {
|WHEN NOT MATCHED THEN
| INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr')
|""".stripMargin,
- primaryScanSchema = "pk INT, salary INT, dep STRING, _partition STRING",
+ primaryScanSchema = "pk INT, salary INT, dep STRING, _partition STRING, index INT",
groupFilterScanSchema = Some("pk INT, dep STRING"))
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala
index 774ae97734d25..30545f5aa01aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala
@@ -22,11 +22,68 @@ import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
import org.apache.spark.sql.execution.{InSubqueryExec, ReusedSubqueryExec}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
class GroupBasedUpdateTableSuite extends UpdateTableSuiteBase {
import testImplicits._
+ test("update handles metadata columns correctly") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"UPDATE $tableNameAsString SET id = -1 WHERE id IN (1, 100)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
+
+ checkLastWriteInfo(
+ expectedRowSchema = table.schema,
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(
+ writeWithMetadataLogEntry(metadata = Row("hr", null), data = Row(1, -1, "hr")),
+ writeWithMetadataLogEntry(metadata = Row("hr", 1), data = Row(3, 3, "hr")))
+ }
+
+ test("update with subquery handles metadata columns correctly") {
+ withTempView("updated_dep") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ val updatedDepDF = Seq(Some("hr"), Some("it")).toDF()
+ updatedDepDF.createOrReplaceTempView("updated_dep")
+
+ sql(
+ s"""UPDATE $tableNameAsString
+ |SET id = -1
+ |WHERE
+ | id IN (1, 20)
+ | AND
+ | dep IN (SELECT * FROM updated_dep)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
+
+ checkLastWriteInfo(
+ expectedRowSchema = table.schema,
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+ checkLastWriteLog(
+ writeWithMetadataLogEntry(metadata = Row("hr", null), data = Row(1, -1, "hr")),
+ writeWithMetadataLogEntry(metadata = Row("hr", 1), data = Row(3, 3, "hr")))
+ }
+ }
+
test("update runtime group filtering") {
Seq(true, false).foreach { ddpEnabled =>
Seq(true, false).foreach { aqeEnabled =>
@@ -53,7 +110,7 @@ class GroupBasedUpdateTableSuite extends UpdateTableSuiteBase {
executeAndCheckScans(
s"UPDATE $tableNameAsString SET salary = -1 WHERE id IN (SELECT * FROM deleted_id)",
- primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING",
+ primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING, index INT",
groupFilterScanSchema = Some("id INT, dep STRING"))
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
index 68f996ba31367..580638230218b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
@@ -21,10 +21,13 @@ import java.util.Collections
import org.scalatest.BeforeAndAfter
-import org.apache.spark.sql.{DataFrame, Encoders, QueryTest}
-import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
+import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row}
+import org.apache.spark.sql.QueryTest.sameRows
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, GenericRowWithSchema}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog}
+import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, Update, Write}
import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan}
@@ -32,7 +35,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType}
import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
@@ -51,6 +54,30 @@ abstract class RowLevelOperationSuiteBase
spark.sessionState.conf.unsetConf("spark.sql.catalog.cat")
}
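+ // expected row ID and metadata column schemas of the in-memory test table; the PRESERVE_*
+ // flags control whether the metadata values are kept for deletes, updates, and reinserts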
+ protected final val PK_FIELD = StructField("pk", IntegerType, nullable = false)
+ protected final val PARTITION_FIELD = StructField(
+ "_partition",
+ StringType,
+ nullable = false,
+ metadata = new MetadataBuilder()
+ .putString(METADATA_COL_ATTR_KEY, "_partition")
+ .putString("comment", "Partition key used to store the row")
+ .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, value = true)
+ .putBoolean(MetadataColumn.PRESERVE_ON_REINSERT, value = true)
+ .build())
+ protected final val PARTITION_FIELD_NULLABLE = PARTITION_FIELD.copy(nullable = true)
+ protected final val INDEX_FIELD = StructField(
+ "index",
+ IntegerType,
+ nullable = false,
+ metadata = new MetadataBuilder()
+ .putString(METADATA_COL_ATTR_KEY, "index")
+ .putString("comment", "Metadata column used to conflict with a data column")
+ .putBoolean(MetadataColumn.PRESERVE_ON_DELETE, value = false)
+ .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, value = false)
+ .build())
+ protected final val INDEX_FIELD_NULLABLE = INDEX_FIELD.copy(nullable = true)
+
protected val namespace: Array[String] = Array("ns1")
protected val ident: Identifier = Identifier.of(namespace, "test_table")
protected val tableNameAsString: String = "cat." + ident.toString
@@ -176,4 +203,77 @@ abstract class RowLevelOperationSuiteBase
}
assert(actualPartitions == expectedPartitions, "replaced partitions must match")
}
+
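+ // asserts the row, row ID, and metadata schemas observed in the last write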
+ protected def checkLastWriteInfo(
+ expectedRowSchema: StructType = new StructType(),
+ expectedRowIdSchema: Option[StructType] = None,
+ expectedMetadataSchema: Option[StructType] = None): Unit = {
+ val info = table.lastWriteInfo
+ assert(info.schema == expectedRowSchema, "row schema must match")
+ val actualRowIdSchema = Option(info.rowIdSchema.orElse(null))
+ assert(actualRowIdSchema == expectedRowIdSchema, "row ID schema must match")
+ val actualMetadataSchema = Option(info.metadataSchema.orElse(null))
+ assert(actualMetadataSchema == expectedMetadataSchema, "metadata schema must match")
+ }
+
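+ // compares the (operation, id, metadata, data) entries recorded by the last write against the expected entries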
+ protected def checkLastWriteLog(expectedEntries: WriteLogEntry*): Unit = {
+ val entryType = new StructType()
+ .add(StructField("operation", StringType))
+ .add(StructField("id", IntegerType))
+ .add(StructField(
+ "metadata",
+ new StructType(Array(
+ StructField("_partition", StringType),
+ StructField("_index", IntegerType)))))
+ .add(StructField("data", table.schema))
+
+ val expectedEntriesAsRows = expectedEntries.map { entry =>
+ new GenericRowWithSchema(
+ values = Array(
+ entry.operation.toString,
+ entry.id.orNull,
+ entry.metadata.orNull,
+ entry.data.orNull),
+ schema = entryType)
+ }
+
+ val encoder = ExpressionEncoder(entryType)
+ val deserializer = encoder.resolveAndBind().createDeserializer()
+ val actualEntriesAsRows = table.lastWriteLog.map(deserializer)
+
+ sameRows(expectedEntriesAsRows, actualEntriesAsRows) match {
+ case Some(errMsg) => fail(s"Write log contains unexpected entries: $errMsg")
+ case None => // OK
+ }
+ }
+
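+ // convenience constructors for expected write log entries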
+ protected def writeLogEntry(data: Row): WriteLogEntry = {
+ WriteLogEntry(operation = Write, data = Some(data))
+ }
+
+ protected def writeWithMetadataLogEntry(metadata: Row, data: Row): WriteLogEntry = {
+ WriteLogEntry(operation = Write, metadata = Some(metadata), data = Some(data))
+ }
+
+ protected def deleteWriteLogEntry(id: Int, metadata: Row): WriteLogEntry = {
+ WriteLogEntry(operation = Delete, id = Some(id), metadata = Some(metadata))
+ }
+
+ protected def updateWriteLogEntry(id: Int, metadata: Row, data: Row): WriteLogEntry = {
+ WriteLogEntry(operation = Update, id = Some(id), metadata = Some(metadata), data = Some(data))
+ }
+
+ protected def reinsertWriteLogEntry(metadata: Row, data: Row): WriteLogEntry = {
+ WriteLogEntry(operation = Reinsert, metadata = Some(metadata), data = Some(data))
+ }
+
+ protected def insertWriteLogEntry(data: Row): WriteLogEntry = {
+ WriteLogEntry(operation = Insert, data = Some(data))
+ }
+
+ case class WriteLogEntry(
+ operation: Operation,
+ id: Option[Int] = None,
+ metadata: Option[Row] = None,
+ data: Option[Row] = None)
}