diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataColumn.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataColumn.java
index 6606748e6d6f9..c5b41d7dfbf4f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataColumn.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataColumn.java
@@ -36,6 +36,45 @@
  */
 @Evolving
 public interface MetadataColumn {
+  /**
+   * Indicates whether a row-level operation should preserve the value of the metadata column
+   * for deleted rows. If set to true, the metadata value will be retained and passed back to
+   * the writer. If false, the metadata value will be replaced with {@code null}.
+   * <p>
+   * This flag applies only to row-level operations working with deltas of rows. Group-based
+   * operations handle deletes by discarding matching records.
+   *
+   * @since 4.0.0
+   */
+  String PRESERVE_ON_DELETE = "__preserve_on_delete";
+  boolean PRESERVE_ON_DELETE_DEFAULT = true;
+
+  /**
+   * Indicates whether a row-level operation should preserve the value of the metadata column
+   * for updated rows. If set to true, the metadata value will be retained and passed back to
+   * the writer. If false, the metadata value will be replaced with {@code null}.
+   * <p>
+   * This flag applies to both group-based and delta-based row-level operations.
+   *
+   * @since 4.0.0
+   */
+  String PRESERVE_ON_UPDATE = "__preserve_on_update";
+  boolean PRESERVE_ON_UPDATE_DEFAULT = true;
+
+  /**
+   * Indicates whether a row-level operation should preserve the value of the metadata column
+   * for reinserted rows generated by splitting updates into deletes and inserts. If true,
+   * the metadata value will be retained and passed back to the writer. If false, the metadata
+   * value will be replaced with {@code null}.
+   * <p>
+   * This flag applies only to row-level operations working with deltas of rows. Group-based
+   * operations do not represent updates as deletes and inserts.
+   *
+   * @since 4.0.0
+   */
+  String PRESERVE_ON_REINSERT = "__preserve_on_reinsert";
+  boolean PRESERVE_ON_REINSERT_DEFAULT = false;
+
   /**
    * The name of this metadata column.
    *
@@ -74,4 +113,13 @@ default String comment() {
   default Transform transform() {
     return null;
   }
+
+  /**
+   * Returns the column metadata in JSON format.
+   *
+   * @since 4.0.0
+   */
+  default String metadataInJSON() {
+    return null;
+  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
index d6e94fe2ca8b0..a4ec1abc9dd7d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
@@ -64,6 +64,23 @@
  */
 @Evolving
 public interface DataWriter<T> extends Closeable {
+  /**
+   * Writes one record with metadata.
+   * <p>
+   * This method is used by group-based row-level operations to pass back metadata for records
+   * that are updated or copied. New records added during a MERGE operation are written using
+   * {@link #write(Object)} as there is no metadata associated with those records.
+   * <p>
+   * If this method fails (by throwing an exception), {@link #abort()} will be called and this
+   * data writer is considered to have been failed.
+   *
+   * @throws IOException if failure happens during disk/network IO like writing files.
+   *
+   * @since 4.0.0
+   */
+  default void write(T metadata, T record) throws IOException {
+    write(record);
+  }
 
   /**
    * Writes one record.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeltaWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeltaWriter.java
index 0cc6cb48801bf..a7ab0c162ddec 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeltaWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeltaWriter.java
@@ -48,6 +48,21 @@ public interface DeltaWriter<T> extends DataWriter<T> {
    */
   void update(T metadata, T id, T row) throws IOException;
 
+  /**
+   * Reinserts a row with metadata.
+   * <p>
+ * This method handles the insert portion of updated rows split into deletes and inserts. + * + * @param metadata values for metadata columns + * @param row a row to reinsert + * @throws IOException if failure happens during disk/network IO like writing files + * + * @since 4.0.0 + */ + default void reinsert(T metadata, T row) throws IOException { + insert(row); + } + /** * Inserts a new row. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala index 32163406ca6d2..9c63e091eaf51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala @@ -86,7 +86,9 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - ReplaceData(writeRelation, cond, remainingRowsPlan, relation, Some(cond)) + val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, remainingRowsPlan) + val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) + ReplaceData(writeRelation, cond, query, relation, projections, Some(cond)) } // build a rewrite plan for sources that support row deltas @@ -106,7 +108,7 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand { // construct a plan that only contains records to delete val deletedRowsPlan = Filter(cond, readRelation) val operationType = Alias(Literal(DELETE_OPERATION), OPERATION_COLUMN)() - val requiredWriteAttrs = dedupAttrs(rowIdAttrs ++ metadataAttrs) + val requiredWriteAttrs = nullifyMetadataOnDelete(dedupAttrs(rowIdAttrs ++ metadataAttrs)) val project = Project(operationType +: requiredWriteAttrs, deletedRowsPlan) // build a plan to write deletes to the table diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index dacee70cf1286..7e2cf4f29807c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta} import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split} -import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION} import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta} import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE @@ -180,7 +180,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper // build a plan to replace read groups 
in the table val writeRelation = relation.copy(table = operationTable) - ReplaceData(writeRelation, pushableCond, mergeRowsPlan, relation, groupFilterCond) + val projections = buildReplaceDataProjections(mergeRowsPlan, relation.output, metadataAttrs) + ReplaceData(writeRelation, pushableCond, mergeRowsPlan, relation, projections, groupFilterCond) } private def buildReplaceDataMergeRowsPlan( @@ -197,7 +198,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper // that's why an extra unconditional instruction that would produce the original row is added // as the last MATCHED and NOT MATCHED BY SOURCE instruction // this logic is specific to data sources that replace groups of data - val keepCarryoverRowsInstruction = Keep(TrueLiteral, targetTable.output) + val carryoverRowsOutput = Literal(WRITE_WITH_METADATA_OPERATION) +: targetTable.output + val keepCarryoverRowsInstruction = Keep(TrueLiteral, carryoverRowsOutput) val matchedInstructions = matchedActions.map { action => toInstruction(action, metadataAttrs) @@ -218,7 +220,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper notMatchedInstructions.flatMap(_.outputs) ++ notMatchedBySourceInstructions.flatMap(_.outputs) - val attrs = targetTable.output + val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)() + val attrs = operationTypeAttr +: targetTable.output MergeRows( isSourceRowPresent = IsNotNull(rowFromSourceAttr), @@ -430,15 +433,18 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper private def toInstruction(action: MergeAction, metadataAttrs: Seq[Attribute]): Instruction = { action match { case UpdateAction(cond, assignments) => - val output = assignments.map(_.value) ++ metadataAttrs + val rowValues = assignments.map(_.value) + val metadataValues = nullifyMetadataOnUpdate(metadataAttrs) + val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues Keep(cond.getOrElse(TrueLiteral), output) case DeleteAction(cond) => Discard(cond.getOrElse(TrueLiteral)) case InsertAction(cond, assignments) => + val rowValues = assignments.map(_.value) val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType)) - val output = assignments.map(_.value) ++ metadataValues + val output = Seq(Literal(WRITE_OPERATION)) ++ rowValues ++ metadataValues Keep(cond.getOrElse(TrueLiteral), output) case other => @@ -460,7 +466,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper action match { case UpdateAction(cond, assignments) if splitUpdates => val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues) - val otherOutput = deltaInsertOutput(assignments, metadataAttrs, originalRowIdValues) + val otherOutput = deltaReinsertOutput(assignments, metadataAttrs, originalRowIdValues) Split(cond.getOrElse(TrueLiteral), output, otherOutput) case UpdateAction(cond, assignments) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index cec44470e3a35..118ed4e99190c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -19,24 +19,32 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable +import 
org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ProjectingInternalRow -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, V2ExpressionUtils} -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ -import org.apache.spark.sql.catalyst.util.WriteDeltaProjections import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable, SupportsDelta} import org.apache.spark.sql.connector.write.RowLevelOperation.Command import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { + private final val DELTA_OPERATIONS_WITH_ROW = + Set(UPDATE_OPERATION, REINSERT_OPERATION, INSERT_OPERATION) + private final val DELTA_OPERATIONS_WITH_METADATA = + Set(DELETE_OPERATION, UPDATE_OPERATION, REINSERT_OPERATION) + private final val DELTA_OPERATIONS_WITH_ROW_ID = + Set(DELETE_OPERATION, UPDATE_OPERATION) + protected def buildOperationTable( table: SupportsRowLevelOperations, command: Command, @@ -103,7 +111,31 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { metadataAttrs: Seq[Attribute], originalRowIdValues: Seq[Expression] = Seq.empty): Seq[Expression] = { val rowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs) - Seq(Literal(DELETE_OPERATION)) ++ rowValues ++ metadataAttrs ++ originalRowIdValues + val metadataValues = nullifyMetadataOnDelete(metadataAttrs) + Seq(Literal(DELETE_OPERATION)) ++ rowValues ++ metadataValues ++ originalRowIdValues + } + + protected def nullifyMetadataOnDelete(attrs: Seq[Attribute]): Seq[NamedExpression] = { + nullifyMetadata(attrs, MetadataAttribute.isPreservedOnDelete) + } + + protected def nullifyMetadataOnUpdate(attrs: Seq[Attribute]): Seq[NamedExpression] = { + nullifyMetadata(attrs, MetadataAttribute.isPreservedOnUpdate) + } + + private def nullifyMetadataOnReinsert(attrs: Seq[Attribute]): Seq[NamedExpression] = { + nullifyMetadata(attrs, MetadataAttribute.isPreservedOnReinsert) + } + + private def nullifyMetadata( + attrs: Seq[Attribute], + shouldPreserve: Attribute => Boolean): Seq[NamedExpression] = { + attrs.map { + case MetadataAttribute(attr) if !shouldPreserve(attr) => + Alias(Literal(null, attr.dataType), attr.name)(explicitMetadata = Some(attr.metadata)) + case attr => + attr + } } private def buildDeltaDeleteRowValues( @@ -132,7 +164,42 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { metadataAttrs: Seq[Attribute], originalRowIdValues: Seq[Expression]): Seq[Expression] = { val rowValues = 
assignments.map(_.value) - Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataAttrs ++ originalRowIdValues + val metadataValues = nullifyMetadataOnUpdate(metadataAttrs) + Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataValues ++ originalRowIdValues + } + + protected def deltaReinsertOutput( + assignments: Seq[Assignment], + metadataAttrs: Seq[Attribute], + originalRowIdValues: Seq[Expression] = Seq.empty): Seq[Expression] = { + val rowValues = assignments.map(_.value) + val metadataValues = nullifyMetadataOnReinsert(metadataAttrs) + val extraNullValues = originalRowIdValues.map(e => Literal(null, e.dataType)) + Seq(Literal(REINSERT_OPERATION)) ++ rowValues ++ metadataValues ++ extraNullValues + } + + protected def addOperationColumn(operation: Int, plan: LogicalPlan): LogicalPlan = { + val operationType = Alias(Literal(operation, IntegerType), OPERATION_COLUMN)() + Project(operationType +: plan.output, plan) + } + + protected def buildReplaceDataProjections( + plan: LogicalPlan, + rowAttrs: Seq[Attribute], + metadataAttrs: Seq[Attribute]): ReplaceDataProjections = { + val outputs = extractOutputs(plan) + + val outputsWithRow = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION)) + val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs) + + val metadataProjection = if (metadataAttrs.nonEmpty) { + val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION)) + Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs)) + } else { + None + } + + ReplaceDataProjections(rowProjection, metadataProjection) } protected def buildWriteDeltaProjections( @@ -140,17 +207,21 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { rowAttrs: Seq[Attribute], rowIdAttrs: Seq[Attribute], metadataAttrs: Seq[Attribute]): WriteDeltaProjections = { + val outputs = extractOutputs(plan) val rowProjection = if (rowAttrs.nonEmpty) { - Some(newLazyProjection(plan, rowAttrs)) + val outputsWithRow = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW) + Some(newLazyProjection(plan, outputsWithRow, rowAttrs)) } else { None } - val rowIdProjection = newLazyRowIdProjection(plan, rowIdAttrs) + val outputsWithRowId = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW_ID) + val rowIdProjection = newLazyRowIdProjection(plan, outputsWithRowId, rowIdAttrs) val metadataProjection = if (metadataAttrs.nonEmpty) { - Some(newLazyProjection(plan, metadataAttrs)) + val outputsWithMetadata = filterOutputs(outputs, DELTA_OPERATIONS_WITH_METADATA) + Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs)) } else { None } @@ -158,26 +229,54 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection) } + private def extractOutputs(plan: LogicalPlan): Seq[Seq[Expression]] = { + plan match { + case p: Project => Seq(p.projectList) + case e: Expand => e.projections + case m: MergeRows => m.outputs + case _ => throw SparkException.internalError("Can't extract outputs from plan: " + plan) + } + } + + private def filterOutputs( + outputs: Seq[Seq[Expression]], + operations: Set[Int]): Seq[Seq[Expression]] = { + outputs.filter { + case Literal(operation: Integer, _) +: _ => operations.contains(operation) + case Alias(Literal(operation: Integer, _), _) +: _ => operations.contains(operation) + case other => throw SparkException.internalError("Can't determine operation: " + other) + } + } + private def newLazyProjection( plan: LogicalPlan, + outputs: Seq[Seq[Expression]], attrs: 
Seq[Attribute]): ProjectingInternalRow = { - val colOrdinals = attrs.map(attr => findColOrdinal(plan, attr.name)) - val schema = DataTypeUtils.fromAttributes(attrs) - ProjectingInternalRow(schema, colOrdinals) + createProjectingInternalRow(outputs, colOrdinals, attrs) } // if there are assignment to row ID attributes, original values are projected as special columns // this method honors such special columns if present private def newLazyRowIdProjection( plan: LogicalPlan, + outputs: Seq[Seq[Expression]], rowIdAttrs: Seq[Attribute]): ProjectingInternalRow = { - val colOrdinals = rowIdAttrs.map { attr => val originalValueIndex = findColOrdinal(plan, ORIGINAL_ROW_ID_VALUE_PREFIX + attr.name) if (originalValueIndex != -1) originalValueIndex else findColOrdinal(plan, attr.name) } - val schema = DataTypeUtils.fromAttributes(rowIdAttrs) + createProjectingInternalRow(outputs, colOrdinals, rowIdAttrs) + } + + private def createProjectingInternalRow( + outputs: Seq[Seq[Expression]], + colOrdinals: Seq[Int], + attrs: Seq[Attribute]): ProjectingInternalRow = { + val schema = StructType(attrs.zipWithIndex.map { case (attr, index) => + val nullable = outputs.exists(output => output(colOrdinals(index)).nullable) + StructField(attr.name, attr.dataType, nullable, attr.metadata) + }) ProjectingInternalRow(schema, colOrdinals) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index 37644a33c7a54..b2955ca006878 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -77,7 +77,9 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - ReplaceData(writeRelation, cond, updatedAndRemainingRowsPlan, relation, Some(cond)) + val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, updatedAndRemainingRowsPlan) + val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) + ReplaceData(writeRelation, cond, query, relation, projections, Some(cond)) } // build a rewrite plan for sources that support replacing groups of data (e.g. 
files, partitions) @@ -109,7 +111,9 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - ReplaceData(writeRelation, cond, updatedAndRemainingRowsPlan, relation, Some(cond)) + val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, updatedAndRemainingRowsPlan) + val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) + ReplaceData(writeRelation, cond, query, relation, projections, Some(cond)) } // this method assumes the assignments have been already aligned before @@ -118,7 +122,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { assignments: Seq[Assignment], cond: Expression = TrueLiteral): LogicalPlan = { - // the plan output may include immutable metadata columns at the end + // the plan output may include metadata columns at the end // that's why the number of assignments may not match the number of plan output columns val assignedValues = assignments.map(_.value) val updatedValues = plan.output.zipWithIndex.map { case (attr, index) => @@ -128,7 +132,12 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { Alias(updatedValue, attr.name)() } else { assert(MetadataAttribute.isValid(attr.metadata)) - attr + if (MetadataAttribute.isPreservedOnUpdate(attr)) { + attr + } else { + val updatedValue = If(cond, Literal(null, attr.dataType), attr) + Alias(updatedValue, attr.name)(explicitMetadata = Some(attr.metadata)) + } } } @@ -181,7 +190,11 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { Alias(assignedExpr, attr.name)() } else { assert(MetadataAttribute.isValid(attr.metadata)) - attr + if (MetadataAttribute.isPreservedOnUpdate(attr)) { + attr + } else { + Alias(Literal(null, attr.dataType), attr.name)(explicitMetadata = Some(attr.metadata)) + } } } @@ -203,7 +216,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { MetadataAttribute.isValid(attr.metadata) } val deleteOutput = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs) - val insertOutput = deltaInsertOutput(assignments, metadataAttrs) + val insertOutput = deltaReinsertOutput(assignments, metadataAttrs) val outputs = Seq(deleteOutput, insertOutput) val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)() val attrs = operationTypeAttr +: matchedRowsPlan.output diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index f5f35050401ba..2af6a1ba84ec8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, METADATA_COL_ATTR_KEY} +import org.apache.spark.sql.connector.catalog.MetadataColumn import org.apache.spark.sql.types._ import org.apache.spark.util.collection.BitSet import org.apache.spark.util.collection.ImmutableBitSet @@ -503,8 +504,41 @@ object MetadataAttribute { .putString(METADATA_COL_ATTR_KEY, name) .build() + def metadata(col: MetadataColumn): Metadata = { + val builder = new MetadataBuilder() + if (col.metadataInJSON != null) { + 
builder.withMetadata(Metadata.fromJson(col.metadataInJSON)) + } + builder.putString(METADATA_COL_ATTR_KEY, col.name) + builder.build() + } + def isValid(metadata: Metadata): Boolean = metadata.contains(METADATA_COL_ATTR_KEY) + + def isPreservedOnDelete(attr: Attribute): Boolean = { + if (attr.metadata.contains(MetadataColumn.PRESERVE_ON_DELETE)) { + attr.metadata.getBoolean(MetadataColumn.PRESERVE_ON_DELETE) + } else { + MetadataColumn.PRESERVE_ON_DELETE_DEFAULT + } + } + + def isPreservedOnUpdate(attr: Attribute): Boolean = { + if (attr.metadata.contains(MetadataColumn.PRESERVE_ON_UPDATE)) { + attr.metadata.getBoolean(MetadataColumn.PRESERVE_ON_UPDATE) + } else { + MetadataColumn.PRESERVE_ON_UPDATE_DEFAULT + } + } + + def isPreservedOnReinsert(attr: Attribute): Boolean = { + if (attr.metadata.contains(MetadataColumn.PRESERVE_ON_REINSERT)) { + attr.metadata.getBoolean(MetadataColumn.PRESERVE_ON_REINSERT) + } else { + MetadataColumn.PRESERVE_ON_REINSERT_DEFAULT + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 2cda1142299ae..0358c45815944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -57,7 +57,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET), ruleId) { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) - case rd @ ReplaceData(_, cond, _, _, groupFilterCond, _) => + case rd @ ReplaceData(_, cond, _, _, _, groupFilterCond, _) => val newCond = replaceNullWithFalse(cond) val newGroupFilterCond = groupFilterCond.map(replaceNullWithFalse) rd.copy(condition = newCond, groupFilterCondition = newGroupFilterCond) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index d0a0fc307756c..54a4e75c90c95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -437,7 +437,7 @@ object GroupBasedRowLevelOperation { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _), - cond, query, _, groupFilterCond, _) => + cond, query, _, _, groupFilterCond, _) => // group-based UPDATEs that are rewritten as UNION read the table twice val allowMultipleReads = rd.operation.command == UPDATE val readRelation = findReadRelation(table, query, allowMultipleReads) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala index 9b1c8bc733a35..f7f515c29481d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala @@ -49,6 +49,12 @@ case class MergeRows( AttributeSet.fromAttributeSets(usedExprs.map(_.references)) -- producedAttributes } + def instructions: 
Seq[Instruction] = { + matchedInstructions ++ notMatchedInstructions ++ notMatchedBySourceInstructions + } + + def outputs: Seq[Seq[Expression]] = instructions.flatMap(_.outputs) + override def simpleString(maxFields: Int): String = { s"MergeRows${truncatedString(output, "[", ", ", "]", maxFields)}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 58c62a90225aa..b361ccbe2439a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -23,10 +23,11 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, NamedExpression, UnaryExpression, Unevaluable, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, UnaryExpression, Unevaluable, V2ExpressionUtils} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections} +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, ReplaceDataProjections, RowDeltaUtils, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber, toSQLExpr} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} @@ -34,9 +35,10 @@ import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write} +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE} import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructType} import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -66,16 +68,19 @@ trait V2WriteCommand extends UnaryCommand with KeepAnalyzedQuery with CTEInChild assert(table.resolved && query.resolved, "`outputResolved` can only be called when `table` and `query` are both resolved.") // If the table doesn't require schema match, we don't need to resolve the output columns. 
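// Illustrative sketch only (not part of this patch; the object name is hypothetical):
// a connector-defined metadata column that opts out of preservation on DELETE while
// keeping its value for UPDATE and reinsert. The flags are plain boolean entries in the
// column metadata JSON; Spark picks them up through MetadataAttribute.metadata(col) and
// the isPreservedOnDelete/isPreservedOnUpdate/isPreservedOnReinsert helpers above.
import org.apache.spark.sql.connector.catalog.MetadataColumn
import org.apache.spark.sql.types.{DataType, LongType, MetadataBuilder}

object RowPositionColumn extends MetadataColumn {
  override def name: String = "_row_pos"
  override def dataType: DataType = LongType
  override def isNullable: Boolean = false
  override def comment: String = "Physical position of the row"
  override def metadataInJSON(): String = {
    new MetadataBuilder()
      .putBoolean(MetadataColumn.PRESERVE_ON_DELETE, value = false)
      .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, value = true)
      .putBoolean(MetadataColumn.PRESERVE_ON_REINSERT, value = true)
      .build()
      .json
  }
}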
- table.skipSchemaResolution || (query.output.size == table.output.size && - query.output.zip(table.output).forall { - case (inAttr, outAttr) => - val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType) - val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) - // names and types must match, nullability must be compatible - inAttr.name == outAttr.name && - DataType.equalsIgnoreCompatibleNullability(inType, outType) && - (outAttr.nullable || !inAttr.nullable) - }) + table.skipSchemaResolution || areCompatible(query.output, table.output) + } + + protected def areCompatible(inAttrs: Seq[Attribute], outAttrs: Seq[Attribute]): Boolean = { + inAttrs.size == outAttrs.size && inAttrs.zip(outAttrs).forall { + case (inAttr, outAttr) => + val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType) + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) + // names and types must match, nullability must be compatible + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(inType, outType) && + (outAttr.nullable || !inAttr.nullable) + } } def withNewQuery(newQuery: LogicalPlan): V2WriteCommand @@ -209,6 +214,17 @@ trait RowLevelWrite extends V2WriteCommand with SupportsSubquery { def operation: RowLevelOperation def condition: Expression def originalTable: NamedRelation + + protected def operationResolved: Boolean = { + val attr = query.output.head + attr.name == RowDeltaUtils.OPERATION_COLUMN && attr.dataType == IntegerType && !attr.nullable + } + + protected def projectedMetadataAttrs: Seq[Attribute] = { + V2ExpressionUtils.resolveRefs[AttributeReference]( + operation.requiredMetadataAttributes.toImmutableArraySeq, + originalTable) + } } /** @@ -229,6 +245,7 @@ case class ReplaceData( condition: Expression, query: LogicalPlan, originalTable: NamedRelation, + projections: ReplaceDataProjections, groupFilterCondition: Option[Expression] = None, write: Option[Write] = None) extends RowLevelWrite { @@ -248,31 +265,33 @@ case class ReplaceData( } } - // the incoming query may include metadata columns - lazy val dataInput: Seq[Attribute] = { - query.output.filter { - case MetadataAttribute(_) => false - case _ => true - } - } - override def outputResolved: Boolean = { assert(table.resolved && query.resolved, "`outputResolved` can only be called when `table` and `query` are both resolved.") + operationResolved && rowAttrsResolved && metadataAttrsResolved + } - // take into account only incoming data columns and ignore metadata columns in the query - // they will be discarded after the logical write is built in the optimizer - // metadata columns may be needed to request a correct distribution or ordering - // but are not passed back to the data source during writes + // validates row projection output is compatible with table attributes + private def rowAttrsResolved: Boolean = { + val inRowAttrs = DataTypeUtils.toAttributes(projections.rowProjection.schema) + table.skipSchemaResolution || areCompatible(inRowAttrs, table.output) + } - table.skipSchemaResolution || (dataInput.size == table.output.size && - dataInput.zip(table.output).forall { case (inAttr, outAttr) => - val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) - // names and types must match, nullability must be compatible - inAttr.name == outAttr.name && - DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outType) && - (outAttr.nullable || !inAttr.nullable) - }) 
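// Illustrative sketch only (not part of this patch; names are hypothetical): a group-based
// connector writer on the receiving end of ReplaceDataProjections. Carried-over and updated
// rows arrive through write(metadata, record) together with the projected metadata columns,
// while rows newly inserted by a MERGE arrive through the single-argument write(record).
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}

class GroupReplacementWriter extends DataWriter[InternalRow] {
  private val rows = mutable.ArrayBuffer.empty[InternalRow]
  private val metadataLog = mutable.ArrayBuffer.empty[(InternalRow, InternalRow)]

  override def write(metadata: InternalRow, record: InternalRow): Unit = {
    // keep the metadata columns (e.g. _partition) alongside the copied or updated row
    metadataLog += ((metadata.copy(), record.copy()))
    rows += record.copy()
  }

  override def write(record: InternalRow): Unit = rows += record.copy()

  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = rows.clear()
  override def close(): Unit = ()
}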
+ // validates metadata projection output is compatible with metadata attributes + private def metadataAttrsResolved: Boolean = { + val outMetadataAttrs = projectedMetadataAttrs.map { + case attr if isMetadataNullabilityPreserved(attr) => attr + case attr => attr.withNullability(true) + } + val inMetadataAttrs = projections.metadataProjection match { + case Some(projection) => DataTypeUtils.toAttributes(projection.schema) + case None => Nil + } + areCompatible(inMetadataAttrs, outMetadataAttrs) + } + + private def isMetadataNullabilityPreserved(attr: Attribute): Boolean = { + operation.command == DELETE || MetadataAttribute.isPreservedOnUpdate(attr) } override def withNewQuery(newQuery: LogicalPlan): ReplaceData = copy(query = newQuery) @@ -331,65 +350,52 @@ case class WriteDelta( override def outputResolved: Boolean = { assert(table.resolved && query.resolved, "`outputResolved` can only be called when `table` and `query` are both resolved.") - operationResolved && rowAttrsResolved && rowIdAttrsResolved && metadataAttrsResolved } - private def operationResolved: Boolean = { - val attr = query.output.head - attr.name == RowDeltaUtils.OPERATION_COLUMN && attr.dataType == IntegerType && !attr.nullable - } - // validates row projection output is compatible with table attributes private def rowAttrsResolved: Boolean = { - table.skipSchemaResolution || (projections.rowProjection match { - case Some(projection) => - table.output.size == projection.schema.size && - projection.schema.zip(table.output).forall { case (field, outAttr) => - isCompatible(field, outAttr) - } - case None => - true - }) + val outRowAttrs = if (operation.command == DELETE) Nil else table.output + val inRowAttrs = projections.rowProjection match { + case Some(projection) => DataTypeUtils.toAttributes(projection.schema) + case None => Nil + } + table.skipSchemaResolution || areCompatible(inRowAttrs, outRowAttrs) } // validates row ID projection output is compatible with row ID attributes private def rowIdAttrsResolved: Boolean = { - val rowIdAttrs = V2ExpressionUtils.resolveRefs[AttributeReference]( + val outRowIdAttrs = V2ExpressionUtils.resolveRefs[AttributeReference]( operation.rowId.toImmutableArraySeq, originalTable) - - val projectionSchema = projections.rowIdProjection.schema - rowIdAttrs.size == projectionSchema.size && projectionSchema.forall { field => - rowIdAttrs.exists(rowIdAttr => isCompatible(field, rowIdAttr)) - } + val inRowIdAttrs = DataTypeUtils.toAttributes(projections.rowIdProjection.schema) + areCompatible(inRowIdAttrs, outRowIdAttrs) } // validates metadata projection output is compatible with metadata attributes private def metadataAttrsResolved: Boolean = { - projections.metadataProjection match { - case Some(projection) => - val metadataAttrs = V2ExpressionUtils.resolveRefs[AttributeReference]( - operation.requiredMetadataAttributes.toImmutableArraySeq, - originalTable) - - val projectionSchema = projection.schema - metadataAttrs.size == projectionSchema.size && projectionSchema.forall { field => - metadataAttrs.exists(metadataAttr => isCompatible(field, metadataAttr)) - } - case None => - true + val outMetadataAttrs = projectedMetadataAttrs.map { + case attr if isMetadataNullabilityPreserved(attr) => attr + case attr => attr.withNullability(true) } + val inMetadataAttrs = projections.metadataProjection match { + case Some(projection) => DataTypeUtils.toAttributes(projection.schema) + case None => Nil + } + areCompatible(inMetadataAttrs, outMetadataAttrs) } - // checks if a projection field is 
compatible with a table attribute - private def isCompatible(inField: StructField, outAttr: NamedExpression): Boolean = { - val inType = CharVarcharUtils.getRawType(inField.metadata).getOrElse(inField.dataType) - val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) - // names and types must match, nullability must be compatible - inField.name == outAttr.name && - DataType.equalsIgnoreCompatibleNullability(inType, outType) && - (outAttr.nullable || !inField.nullable) + private def isMetadataNullabilityPreserved(attr: Attribute): Boolean = { + operation.command match { + case DELETE => + MetadataAttribute.isPreservedOnDelete(attr) + case UPDATE | MERGE if operation.representUpdateAsDeleteAndInsert => + MetadataAttribute.isPreservedOnDelete(attr) && MetadataAttribute.isPreservedOnReinsert(attr) + case UPDATE => + MetadataAttribute.isPreservedOnUpdate(attr) + case MERGE => + MetadataAttribute.isPreservedOnDelete(attr) && MetadataAttribute.isPreservedOnUpdate(attr) + } } override def withNewQuery(newQuery: LogicalPlan): V2WriteCommand = copy(query = newQuery) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ReplaceDataProjections.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ReplaceDataProjections.scala new file mode 100644 index 0000000000000..99744faf2c749 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ReplaceDataProjections.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.ProjectingInternalRow + +case class ReplaceDataProjections( + rowProjection: ProjectingInternalRow, + metadataProjection: Option[ProjectingInternalRow]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala index db39c059c24e3..72baad069b180 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala @@ -25,5 +25,8 @@ object RowDeltaUtils { final val DELETE_OPERATION: Int = 1 final val UPDATE_OPERATION: Int = 2 final val INSERT_OPERATION: Int = 3 + final val REINSERT_OPERATION: Int = 4 + final val WRITE_OPERATION: Int = 5 + final val WRITE_WITH_METADATA_OPERATION: Int = 6 final val ORIGINAL_ROW_ID_VALUE_PREFIX: String = "__original_row_id_" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 134c4f1bb13a4..841f367896c40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -25,6 +25,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.connector.catalog.MetadataColumn import org.apache.spark.sql.types.{MetadataBuilder, NumericType, StringType, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SparkErrorUtils, Utils} @@ -193,7 +194,10 @@ package object util extends Logging { QUALIFIED_ACCESS_ONLY, FileSourceMetadataAttribute.FILE_SOURCE_METADATA_COL_ATTR_KEY, FileSourceConstantMetadataStructField.FILE_SOURCE_CONSTANT_METADATA_COL_ATTR_KEY, - FileSourceGeneratedMetadataStructField.FILE_SOURCE_GENERATED_METADATA_COL_ATTR_KEY + FileSourceGeneratedMetadataStructField.FILE_SOURCE_GENERATED_METADATA_COL_ATTR_KEY, + MetadataColumn.PRESERVE_ON_DELETE, + MetadataColumn.PRESERVE_ON_UPDATE, + MetadataColumn.PRESERVE_ON_REINSERT ) def removeInternalMetadata(schema: StructType): StructType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/LogicalWriteInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/LogicalWriteInfoImpl.scala index 8c0828d8a278b..1e4e1a5955f3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/LogicalWriteInfoImpl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/LogicalWriteInfoImpl.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.connector.write import java.util.Optional +import scala.jdk.OptionConverters._ + import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -29,3 +31,19 @@ private[sql] case class LogicalWriteInfoImpl( override val rowIdSchema: Optional[StructType] = Optional.empty[StructType], override val metadataSchema: Optional[StructType] = Optional.empty[StructType]) extends LogicalWriteInfo + +object LogicalWriteInfoImpl { + def apply( + queryId: String, + schema: StructType, + options: CaseInsensitiveStringMap, + rowIdSchema: Option[StructType], + metadataSchema: Option[StructType]): LogicalWriteInfoImpl = { + LogicalWriteInfoImpl( + queryId, + schema, + options, + 
rowIdSchema.toJava, + metadataSchema.toJava) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala index 6b884713bd5c3..b24885270f52d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -109,7 +109,7 @@ object DataSourceV2Implicits { name = metaCol.name, dataType = metaCol.dataType, nullable = metaCol.isNullable, - metadata = MetadataAttribute.metadata(metaCol.name)) + metadata = MetadataAttribute.metadata(metaCol)) Option(metaCol.comment).map(field.withComment).getOrElse(field) } StructType(fields) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index ab17b93ad6146..3ac8c3794b8ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -63,13 +63,29 @@ abstract class InMemoryBaseTable( protected object PartitionKeyColumn extends MetadataColumn { override def name: String = "_partition" override def dataType: DataType = StringType + override def isNullable: Boolean = false override def comment: String = "Partition key used to store the row" + override def metadataInJSON(): String = { + val metadata = new MetadataBuilder() + .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, value = true) + .putBoolean(MetadataColumn.PRESERVE_ON_REINSERT, value = true) + .build() + metadata.json + } } - private object IndexColumn extends MetadataColumn { + protected object IndexColumn extends MetadataColumn { override def name: String = "index" override def dataType: DataType = IntegerType + override def isNullable: Boolean = false override def comment: String = "Metadata column used to conflict with a data column" + override def metadataInJSON(): String = { + val metadata = new MetadataBuilder() + .putBoolean(MetadataColumn.PRESERVE_ON_DELETE, value = false) + .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, value = false) + .build() + metadata.json + } } // purposely exposes a metadata column that conflicts with a data column in some tests @@ -617,6 +633,7 @@ object InMemoryBaseTable { class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage with InputPartition with HasPartitionKey with HasPartitionStatistics with Serializable { + val log = new mutable.ArrayBuffer[InternalRow]() val rows = new mutable.ArrayBuffer[InternalRow]() val deletes = new mutable.ArrayBuffer[Int]() @@ -734,9 +751,22 @@ private object BufferedRowsWriterFactory extends DataWriterFactory with Streamin } private class BufferWriter extends DataWriter[InternalRow] { + + private final val WRITE = UTF8String.fromString(Write.toString) + protected val buffer = new BufferedRows - override def write(row: InternalRow): Unit = buffer.rows.append(row.copy()) + override def write(metadata: InternalRow, row: InternalRow): Unit = { + buffer.rows.append(row.copy()) + val logEntry = new GenericInternalRow(Array[Any](WRITE, null, metadata.copy(), row.copy())) + buffer.log.append(logEntry) + } + + override def write(row: InternalRow): Unit = { + buffer.rows.append(row.copy()) + val logEntry = new 
GenericInternalRow(Array[Any](WRITE, null, null, row.copy())) + buffer.log.append(logEntry) + } override def commit(): WriterCommitMessage = buffer @@ -771,3 +801,10 @@ class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric { override def name(): String = "number_of_rows_from_driver" override def value(): Long = value } + +sealed trait Operation +case object Write extends Operation +case object Delete extends Operation +case object Update extends Operation +case object Reinsert extends Operation +case object Insert extends Operation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index 3a684dc57c02f..98678289fa259 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connector.catalog import java.util import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NamedReference, SortDirection, SortOrder, Transform} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} @@ -27,6 +28,8 @@ import org.apache.spark.sql.connector.write.{BatchWrite, DeltaBatchWrite, DeltaW import org.apache.spark.sql.connector.write.RowLevelOperation.Command import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ class InMemoryRowLevelOperationTable( name: String, @@ -36,11 +39,17 @@ class InMemoryRowLevelOperationTable( extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations { private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name) + private final val INDEX_COLUMN_REF = FieldReference(IndexColumn.name) private final val SUPPORTS_DELTAS = "supports-deltas" private final val SPLIT_UPDATES = "split-updates" // used in row-level operation tests to verify replaced partitions var replacedPartitions: Seq[Seq[Any]] = Seq.empty + // used in row-level operation tests to verify reported write schema + var lastWriteInfo: LogicalWriteInfo = _ + // used in row-level operation tests to verify passed records + // (operation, id, metadata, row) + var lastWriteLog: Seq[InternalRow] = Seq.empty override def newRowLevelOperationBuilder( info: RowLevelOperationInfo): RowLevelOperationBuilder = { @@ -55,7 +64,7 @@ class InMemoryRowLevelOperationTable( var configuredScan: InMemoryBatchScan = _ override def requiredMetadataAttributes(): Array[NamedReference] = { - Array(PARTITION_COLUMN_REF) + Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF) } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { @@ -68,25 +77,26 @@ class InMemoryRowLevelOperationTable( } } - override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = new WriteBuilder { - - override def build(): Write = new Write with RequiresDistributionAndOrdering { - override def requiredDistribution(): Distribution = { - Distributions.clustered(Array(PARTITION_COLUMN_REF)) - } + override def newWriteBuilder(info: 
LogicalWriteInfo): WriteBuilder = { + lastWriteInfo = info + new WriteBuilder { + override def build(): Write = new Write with RequiresDistributionAndOrdering { + override def requiredDistribution: Distribution = { + Distributions.clustered(Array(PARTITION_COLUMN_REF)) + } - override def requiredOrdering(): Array[SortOrder] = { - Array[SortOrder]( - LogicalExpressions.sort( - PARTITION_COLUMN_REF, - SortDirection.ASCENDING, - SortDirection.ASCENDING.defaultNullOrdering()) - ) - } + override def requiredOrdering: Array[SortOrder] = { + Array[SortOrder]( + LogicalExpressions.sort( + PARTITION_COLUMN_REF, + SortDirection.ASCENDING, + SortDirection.ASCENDING.defaultNullOrdering())) + } - override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan) + override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan) - override def description(): String = "InMemoryWrite" + override def description: String = "InMemoryWrite" + } } } @@ -102,6 +112,7 @@ class InMemoryRowLevelOperationTable( dataMap --= readPartitions replacedPartitions = readPartitions withData(newData, schema) + lastWriteLog = newData.flatMap(buffer => buffer.log).toImmutableArraySeq } } @@ -109,7 +120,7 @@ class InMemoryRowLevelOperationTable( private final val PK_COLUMN_REF = FieldReference("pk") override def requiredMetadataAttributes(): Array[NamedReference] = { - Array(PARTITION_COLUMN_REF) + Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF) } override def rowId(): Array[NamedReference] = Array(PK_COLUMN_REF) @@ -118,7 +129,8 @@ class InMemoryRowLevelOperationTable( new InMemoryScanBuilder(schema, options) } - override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = + override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = { + lastWriteInfo = info new DeltaWriteBuilder { override def build(): DeltaWrite = new DeltaWrite with RequiresDistributionAndOrdering { @@ -138,6 +150,7 @@ class InMemoryRowLevelOperationTable( override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite } } + } override def representUpdateAsDeleteAndInsert(): Boolean = { properties.getOrDefault(SPLIT_UPDATES, "false").toBoolean @@ -150,8 +163,10 @@ class InMemoryRowLevelOperationTable( } override def commit(messages: Array[WriterCommitMessage]): Unit = { - withDeletes(messages.map(_.asInstanceOf[BufferedRows])) - withData(messages.map(_.asInstanceOf[BufferedRows])) + val newData = messages.map(_.asInstanceOf[BufferedRows]) + withDeletes(newData) + withData(newData) + lastWriteLog = newData.flatMap(buffer => buffer.log).toIndexedSeq } override def abort(messages: Array[WriterCommitMessage]): Unit = {} @@ -166,16 +181,41 @@ private object DeltaBufferedRowsWriterFactory extends DeltaWriterFactory { private class DeltaBufferWriter extends BufferWriter with DeltaWriter[InternalRow] { - override def delete(meta: InternalRow, id: InternalRow): Unit = buffer.deletes += id.getInt(0) + private final val DELETE = UTF8String.fromString(Delete.toString) + private final val UPDATE = UTF8String.fromString(Update.toString) + private final val REINSERT = UTF8String.fromString(Reinsert.toString) + private final val INSERT = UTF8String.fromString(Insert.toString) + + override def delete(meta: InternalRow, id: InternalRow): Unit = { + val pk = id.getInt(0) + buffer.deletes += pk + val logEntry = new GenericInternalRow(Array[Any](DELETE, pk, meta.copy(), null)) + buffer.log += logEntry + } override def update(meta: InternalRow, id: InternalRow, row: InternalRow): Unit = { - buffer.deletes += id.getInt(0) - write(row) + 
val pk = id.getInt(0) + buffer.deletes += pk + buffer.rows.append(row.copy()) + val logEntry = new GenericInternalRow(Array[Any](UPDATE, pk, meta.copy(), row.copy())) + buffer.log += logEntry } - override def insert(row: InternalRow): Unit = write(row) + override def reinsert(meta: InternalRow, row: InternalRow): Unit = { + buffer.rows.append(row.copy()) + val logEntry = new GenericInternalRow(Array[Any](REINSERT, null, meta.copy(), row.copy())) + buffer.log += logEntry + } + + override def insert(row: InternalRow): Unit = { + buffer.rows.append(row.copy()) + val logEntry = new GenericInternalRow(Array[Any](INSERT, null, null, row.copy())) + buffer.log += logEntry + } - override def write(row: InternalRow): Unit = super[BufferWriter].write(row) + override def write(row: InternalRow): Unit = { + throw new UnsupportedOperationException() + } override def commit(): WriterCommitMessage = buffer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 499721fbae4e8..ce863791659bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -323,9 +323,10 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat throw SparkException.internalError("Unexpected table relation: " + other) } - case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, _, Some(write)) => + case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, _, + Some(write)) => // use the original relation to refresh the cache - ReplaceDataExec(planLater(query), refreshCache(r), write) :: Nil + ReplaceDataExec(planLater(query), refreshCache(r), projections, write) :: Nil case WriteDelta(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, Some(write)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala index 6eede88c55bd0..05fb37674a836 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala @@ -74,7 +74,7 @@ object OptimizeMetadataOnlyDeleteFromTable extends Rule[LogicalPlan] with Predic type ReturnType = (RowLevelWrite, RowLevelOperation.Command, Expression, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case rd @ ReplaceData(_, cond, _, originalTable, _, _) => + case rd @ ReplaceData(_, cond, _, originalTable, _, _, _) => val command = rd.operation.command Some(rd, command, cond, originalTable) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index 17b2579ca873a..9d059416766a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.execution.datasources.v2 -import java.util.{Optional, UUID} +import 
java.util.UUID import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.PredicateHelper -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceData, WriteDelta} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceData, WriteDelta} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.WriteDeltaProjections import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} import org.apache.spark.sql.connector.expressions.filter.Predicate @@ -95,7 +94,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { relationOpt, table, query, queryId, options, outputMode, Some(batchId)) => val writeOptions = mergeOptions( options, relationOpt.map(r => r.options.asScala.toMap).getOrElse(Map.empty)) - val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId) + val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId = queryId) val write = buildWriteForMicroBatch(table, writeBuilder, outputMode) val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming) val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq @@ -103,14 +102,14 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt) WriteToDataSourceV2(relationOpt, microBatchWrite, newQuery, customMetrics) - case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, _, None) => - val rowSchema = DataTypeUtils.fromAttributes(rd.dataInput) + case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, projections, _, None) => + val rowSchema = projections.rowProjection.schema + val metadataSchema = projections.metadataProjection.map(_.schema) val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap) - val writeBuilder = newWriteBuilder(r.table, writeOptions, rowSchema) + val writeBuilder = newWriteBuilder(r.table, writeOptions, rowSchema, metadataSchema) val write = writeBuilder.build() val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog) - // project away any metadata columns that could be used for distribution and ordering - rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery)) + rd.copy(write = Some(write), query = newQuery) case wd @ WriteDelta(r: DataSourceV2Relation, _, query, _, projections, None) => val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap) @@ -158,9 +157,15 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { table: Table, writeOptions: Map[String, String], rowSchema: StructType, + metadataSchema: Option[StructType] = None, queryId: String = UUID.randomUUID().toString): WriteBuilder = { - val info = LogicalWriteInfoImpl(queryId, rowSchema, writeOptions.asOptions) + val info = LogicalWriteInfoImpl( + queryId, + rowSchema, + writeOptions.asOptions, + rowIdSchema = None, + metadataSchema) table.asWritable.newWriteBuilder(info) } @@ -171,15 +176,15 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { queryId: String = UUID.randomUUID().toString): DeltaWriteBuilder = { val rowSchema = projections.rowProjection.map(_.schema).getOrElse(StructType(Nil)) - val rowIdSchema = 
projections.rowIdProjection.schema + val rowIdSchema = Some(projections.rowIdProjection.schema) val metadataSchema = projections.metadataProjection.map(_.schema) val info = LogicalWriteInfoImpl( queryId, rowSchema, writeOptions.asOptions, - Optional.of(rowIdSchema), - Optional.ofNullable(metadataSchema.orNull)) + rowIdSchema, + metadataSchema) val writeBuilder = table.asWritable.newWriteBuilder(info) assert(writeBuilder.isInstanceOf[DeltaWriteBuilder], s"$writeBuilder must be DeltaWriteBuilder") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 308b1bceca12a..016d6b5411acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -22,12 +22,12 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, ProjectingInternalRow} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, TableSpec, UnaryNode} -import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, WriteDeltaProjections} -import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, UPDATE_OPERATION} +import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections} +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric @@ -283,10 +283,20 @@ case class OverwritePartitionsDynamicExec( case class ReplaceDataExec( query: SparkPlan, refreshCache: () => Unit, + projections: ReplaceDataProjections, write: Write) extends V2ExistingTableWriteExec { override val stringArgs: Iterator[Any] = Iterator(query, write) + override def writingTask: WritingSparkTask[_] = { + projections match { + case ReplaceDataProjections(dataProj, Some(metadataProj)) => + DataAndMetadataWritingSparkTask(dataProj, metadataProj) + case _ => + DataWritingSparkTask + } + } + override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = { copy(query = newChild) } @@ -542,6 +552,32 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial } } +case class DataAndMetadataWritingSparkTask( + dataProj: ProjectingInternalRow, + metadataProj: ProjectingInternalRow) extends WritingSparkTask[DataWriter[InternalRow]] { + override protected def write( + writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { + while (iter.hasNext) { + val row = iter.next() + val operation = row.getInt(0) + + operation match { + case WRITE_WITH_METADATA_OPERATION => + dataProj.project(row) + 
metadataProj.project(row) + writer.write(metadataProj, dataProj) + + case WRITE_OPERATION => + dataProj.project(row) + writer.write(dataProj) + + case other => + throw new SparkException(s"Unexpected operation ID: $other") + } + } + } +} + object DataWritingSparkTask extends WritingSparkTask[DataWriter[InternalRow]] { override protected def write( writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { @@ -571,6 +607,10 @@ case class DeltaWritingSparkTask( rowIdProjection.project(row) writer.update(null, rowIdProjection, rowProjection) + case REINSERT_OPERATION => + rowProjection.project(row) + writer.reinsert(null, rowProjection) + case INSERT_OPERATION => rowProjection.project(row) writer.insert(rowProjection) @@ -607,6 +647,11 @@ case class DeltaWithMetadataWritingSparkTask( metadataProjection.project(row) writer.update(metadataProjection, rowIdProjection, rowProjection) + case REINSERT_OPERATION => + rowProjection.project(row) + metadataProjection.project(row) + writer.reinsert(metadataProjection, rowProjection) + case INSERT_OPERATION => rowProjection.project(row) writer.insert(rowProjection) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala index be180eb89ce20..8ad713424cec5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala @@ -18,15 +18,71 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.types.StructType class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { + import testImplicits._ + override protected lazy val extraTableProps: java.util.Map[String, String] = { val props = new java.util.HashMap[String, String]() props.put("supports-deltas", "true") props } + test("delete handles metadata columns correctly") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |{ "pk": 4, "id": 4, "dep": "hr" } + |""".stripMargin) + + sql(s"DELETE FROM $tableNameAsString WHERE id IN (1, 100)") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 2, "software") :: Row(3, 3, "hr") :: Row(4, 4, "hr") :: Nil) + + checkLastWriteInfo( + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog(deleteWriteLogEntry(id = 1, metadata = Row("hr", null))) + } + + test("delete with subquery handles metadata columns correctly") { + withTempView("updated_dep") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |{ "pk": 4, "id": 4, "dep": "hr" } + |""".stripMargin) + + val updatedDepDF = Seq(Some("hr"), Some("it")).toDF() + updatedDepDF.createOrReplaceTempView("updated_dep") + + sql( + s"""DELETE FROM $tableNameAsString + |WHERE + | id IN (1, 100) + | AND + | dep IN (SELECT * FROM updated_dep) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 2, "software") :: Row(3, 3, "hr") :: Row(4, 4, "hr") :: Nil) + + checkLastWriteInfo( + expectedRowIdSchema = 
Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog(deleteWriteLogEntry(id = 1, metadata = Row("hr", null))) + } + } + test("delete with nondeterministic conditions") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala index ff0f8b82bc9ce..32d602b154952 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala @@ -17,11 +17,64 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType + class DeltaBasedMergeIntoTableSuite extends DeltaBasedMergeIntoTableSuiteBase { + import testImplicits._ + override protected lazy val extraTableProps: java.util.Map[String, String] = { val props = new java.util.HashMap[String, String]() props.put("supports-deltas", "true") props } + + test("merge handles metadata columns correctly") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |{ "pk": 4, "salary": 400, "dep": "hr" } + |{ "pk": 5, "salary": 500, "dep": "hr" } + |""".stripMargin) + + val sourceDF = Seq(3, 4, 5, 6).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET t.salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'new') + |WHEN NOT MATCHED BY SOURCE AND pk = 1 THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 301, "hr"), // update + Row(4, 401, "hr"), // update + Row(5, 501, "hr"), // update + Row(6, 0, "new"))) // insert + + checkLastWriteInfo( + expectedRowSchema = table.schema, + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog( + deleteWriteLogEntry(id = 1, metadata = Row("hr", null)), + updateWriteLogEntry(id = 3, metadata = Row("hr", null), data = Row(3, 301, "hr")), + updateWriteLogEntry(id = 4, metadata = Row("hr", null), data = Row(4, 401, "hr")), + updateWriteLogEntry(id = 5, metadata = Row("hr", null), data = Row(5, 501, "hr")), + insertWriteLogEntry(data = Row(6, 0, "new"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala index 405ee99bb6dc5..e93d4165be332 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala @@ -17,13 +17,69 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType + class DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite 
extends DeltaBasedMergeIntoTableSuiteBase { + import testImplicits._ + override protected lazy val extraTableProps: java.util.Map[String, String] = { val props = new java.util.HashMap[String, String]() props.put("supports-deltas", "true") props.put("split-updates", "true") props } + + test("merge handles metadata columns correctly") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |{ "pk": 4, "salary": 400, "dep": "hr" } + |{ "pk": 5, "salary": 500, "dep": "hr" } + |""".stripMargin) + + val sourceDF = Seq(3, 4, 5, 6).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET t.salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'new') + |WHEN NOT MATCHED BY SOURCE AND pk = 1 THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 301, "hr"), // update + Row(4, 401, "hr"), // update + Row(5, 501, "hr"), // update + Row(6, 0, "new"))) // insert + + checkLastWriteInfo( + expectedRowSchema = table.schema, + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog( + deleteWriteLogEntry(id = 1, metadata = Row("hr", null)), + deleteWriteLogEntry(id = 3, metadata = Row("hr", null)), + reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(3, 301, "hr")), + deleteWriteLogEntry(id = 4, metadata = Row("hr", null)), + reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(4, 401, "hr")), + deleteWriteLogEntry(id = 5, metadata = Row("hr", null)), + reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(5, 501, "hr")), + insertWriteLogEntry(data = Row(6, 0, "new"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala index 363b90b45b87e..612a26e756abd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala @@ -17,12 +17,81 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType + class DeltaBasedUpdateAsDeleteAndInsertTableSuite extends DeltaBasedUpdateTableSuiteBase { + import testImplicits._ + override protected lazy val extraTableProps: java.util.Map[String, String] = { val props = new java.util.HashMap[String, String]() props.put("supports-deltas", "true") props.put("split-updates", "true") props } + + test("update handles metadata columns correctly") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE id IN (1, 100)") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + + checkLastWriteInfo( + expectedRowSchema = StructType(table.schema.map { + case 
attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant + case attr => attr + }), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog( + deleteWriteLogEntry(id = 1, metadata = Row("hr", null)), + reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(1, -1, "hr"))) + } + + test("update with subquery handles metadata columns correctly") { + withTempView("updated_dep") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + val updatedIdDF = Seq(Some("hr"), Some("it")).toDF() + updatedIdDF.createOrReplaceTempView("updated_dep") + + sql( + s"""UPDATE $tableNameAsString + |SET id = -1 + |WHERE + | id IN (1, 100) + | AND + | dep IN (SELECT * FROM updated_dep) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + + checkLastWriteInfo( + expectedRowSchema = StructType(table.schema.map { + case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant + case attr => attr + }), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog( + deleteWriteLogEntry(id = 1, metadata = Row("hr", null)), + reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(1, -1, "hr"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala index bf6dcfac04cb1..c9fd5d6e3ff0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala @@ -17,11 +17,78 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType + class DeltaBasedUpdateTableSuite extends DeltaBasedUpdateTableSuiteBase { + import testImplicits._ + override protected lazy val extraTableProps: java.util.Map[String, String] = { val props = new java.util.HashMap[String, String]() props.put("supports-deltas", "true") props } + + test("update handles metadata columns correctly") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE id IN (1, 100)") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + + checkLastWriteInfo( + expectedRowSchema = StructType(table.schema.map { + case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant + case attr => attr + }), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog( + updateWriteLogEntry(id = 1, metadata = Row("hr", null), data = Row(1, -1, "hr"))) + } + + test("update with subquery handles metadata columns correctly") { + withTempView("updated_dep") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 
1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + val updatedIdDF = Seq(Some("hr"), Some("it")).toDF() + updatedIdDF.createOrReplaceTempView("updated_dep") + + sql( + s"""UPDATE $tableNameAsString + |SET id = -1 + |WHERE + | id IN (1, 20) + | AND + | dep IN (SELECT * FROM updated_dep) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + + checkLastWriteInfo( + expectedRowSchema = StructType(table.schema.map { + case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant + case attr => attr + }), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog( + updateWriteLogEntry(id = 1, metadata = Row("hr", null), data = Row(1, -1, "hr"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala index 1be318f948fd9..4dd09a2f1c831 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala @@ -19,11 +19,35 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { import testImplicits._ + test("delete preserves metadata columns for carried-over records") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |{ "pk": 4, "id": 4, "dep": "hr" } + |""".stripMargin) + + sql(s"DELETE FROM $tableNameAsString WHERE id IN (1, 100)") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 2, "software") :: Row(3, 3, "hr") :: Row(4, 4, "hr") :: Nil) + + checkLastWriteInfo( + expectedRowSchema = table.schema, + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD)))) + + checkLastWriteLog( + writeWithMetadataLogEntry(metadata = Row("hr", 1), data = Row(3, 3, "hr")), + writeWithMetadataLogEntry(metadata = Row("hr", 2), data = Row(4, 4, "hr"))) + } + test("delete with nondeterministic conditions") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } @@ -53,7 +77,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { executeAndCheckScans( s"DELETE FROM $tableNameAsString WHERE salary IN (300, 400, 500)", - primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING, index INT", groupFilterScanSchema = Some("salary INT, dep STRING")) checkAnswer( @@ -85,7 +109,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { | AND | dep IN (SELECT * FROM deleted_dep) |""".stripMargin, - primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING, index INT", groupFilterScanSchema = Some("id INT, dep STRING")) checkAnswer( @@ -133,7 +157,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { 
executeAndCheckScans( s"DELETE FROM $tableNameAsString WHERE id IN (SELECT * FROM deleted_id)", - primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING, index INT", groupFilterScanSchema = Some("id INT, dep STRING")) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala index ebc34ae006e6e..63eba256d8f27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala @@ -19,11 +19,62 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.Row import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType class GroupBasedMergeIntoTableSuite extends MergeIntoTableSuiteBase { import testImplicits._ + test("merge handles metadata columns correctly") { + withTempView("source") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |{ "pk": 4, "salary": 400, "dep": "hr" } + |{ "pk": 5, "salary": 500, "dep": "hr" } + |{ "pk": 7, "salary": 700, "dep": "hr" } + |""".stripMargin) + + val sourceDF = Seq(3, 4, 5, 6, 7).toDF("pk") + sourceDF.createOrReplaceTempView("source") + + sql( + s"""MERGE INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED AND t.pk != 7 THEN + | UPDATE SET t.salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'new') + |WHEN NOT MATCHED BY SOURCE AND pk = 1 THEN + | DELETE + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 301, "hr"), // update + Row(4, 401, "hr"), // update + Row(5, 501, "hr"), // update + Row(6, 0, "new"), // insert + Row(7, 700, "hr"))) // unchanged + + checkLastWriteInfo( + expectedRowSchema = table.schema, + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog( + writeWithMetadataLogEntry(metadata = Row("software", 0), data = Row(2, 200, "software")), + writeWithMetadataLogEntry(metadata = Row("hr", null), data = Row(3, 301, "hr")), + writeWithMetadataLogEntry(metadata = Row("hr", null), data = Row(4, 401, "hr")), + writeWithMetadataLogEntry(metadata = Row("hr", null), data = Row(5, 501, "hr")), + writeLogEntry(data = Row(6, 0, "new")), + writeWithMetadataLogEntry(metadata = Row("hr", 4), data = Row(7, 700, "hr"))) + } + } + test("merge runtime filtering is disabled with NOT MATCHED BY SOURCE clauses") { withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -48,7 +99,7 @@ class GroupBasedMergeIntoTableSuite extends MergeIntoTableSuiteBase { |WHEN NOT MATCHED BY SOURCE THEN | DELETE |""".stripMargin, - primaryScanSchema = "pk INT, salary INT, dep STRING, _partition STRING", + primaryScanSchema = "pk INT, salary INT, dep STRING, _partition STRING, index INT", groupFilterScanSchema = None) checkAnswer( @@ -109,7 +160,7 @@ class GroupBasedMergeIntoTableSuite extends MergeIntoTableSuiteBase { |WHEN NOT MATCHED THEN | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr') |""".stripMargin, - primaryScanSchema = "pk INT, salary INT, dep STRING, _partition STRING", 
+ primaryScanSchema = "pk INT, salary INT, dep STRING, _partition STRING, index INT", groupFilterScanSchema = Some("pk INT, dep STRING")) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala index 774ae97734d25..30545f5aa01aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala @@ -22,11 +22,68 @@ import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression import org.apache.spark.sql.execution.{InSubqueryExec, ReusedSubqueryExec} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType class GroupBasedUpdateTableSuite extends UpdateTableSuiteBase { import testImplicits._ + test("update handles metadata columns correctly") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE id IN (1, 100)") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + + checkLastWriteInfo( + expectedRowSchema = table.schema, + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog( + writeWithMetadataLogEntry(metadata = Row("hr", null), data = Row(1, -1, "hr")), + writeWithMetadataLogEntry(metadata = Row("hr", 1), data = Row(3, 3, "hr"))) + } + + test("update with subquery handles metadata columns correctly") { + withTempView("updated_dep") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + val updatedIdDF = Seq(Some("hr"), Some("it")).toDF() + updatedIdDF.createOrReplaceTempView("updated_dep") + + sql( + s"""UPDATE $tableNameAsString + |SET id = -1 + |WHERE + | id IN (1, 20) + | AND + | dep IN (SELECT * FROM updated_dep) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + + checkLastWriteInfo( + expectedRowSchema = table.schema, + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + + checkLastWriteLog( + writeWithMetadataLogEntry(metadata = Row("hr", null), data = Row(1, -1, "hr")), + writeWithMetadataLogEntry(metadata = Row("hr", 1), data = Row(3, 3, "hr"))) + } + } + test("update runtime group filtering") { Seq(true, false).foreach { ddpEnabled => Seq(true, false).foreach { aqeEnabled => @@ -53,7 +110,7 @@ class GroupBasedUpdateTableSuite extends UpdateTableSuiteBase { executeAndCheckScans( s"UPDATE $tableNameAsString SET salary = -1 WHERE id IN (SELECT * FROM deleted_id)", - primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING, index INT", groupFilterScanSchema = Some("id INT, dep STRING")) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 
68f996ba31367..580638230218b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -21,10 +21,13 @@ import java.util.Collections import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{DataFrame, Encoders, QueryTest} -import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression +import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row} +import org.apache.spark.sql.QueryTest.sameRows +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, GenericRowWithSchema} import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog} +import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, Update, Write} import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} @@ -32,7 +35,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType} import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -51,6 +54,30 @@ abstract class RowLevelOperationSuiteBase spark.sessionState.conf.unsetConf("spark.sql.catalog.cat") } + protected final val PK_FIELD = StructField("pk", IntegerType, nullable = false) + protected final val PARTITION_FIELD = StructField( + "_partition", + StringType, + nullable = false, + metadata = new MetadataBuilder() + .putString(METADATA_COL_ATTR_KEY, "_partition") + .putString("comment", "Partition key used to store the row") + .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, value = true) + .putBoolean(MetadataColumn.PRESERVE_ON_REINSERT, value = true) + .build()) + protected final val PARTITION_FIELD_NULLABLE = PARTITION_FIELD.copy(nullable = true) + protected final val INDEX_FIELD = StructField( + "index", + IntegerType, + nullable = false, + metadata = new MetadataBuilder() + .putString(METADATA_COL_ATTR_KEY, "index") + .putString("comment", "Metadata column used to conflict with a data column") + .putBoolean(MetadataColumn.PRESERVE_ON_DELETE, value = false) + .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, value = false) + .build()) + protected final val INDEX_FIELD_NULLABLE = INDEX_FIELD.copy(nullable = true) + protected val namespace: Array[String] = Array("ns1") protected val ident: Identifier = Identifier.of(namespace, "test_table") protected val tableNameAsString: String = "cat." 
+ ident.toString @@ -176,4 +203,77 @@ abstract class RowLevelOperationSuiteBase } assert(actualPartitions == expectedPartitions, "replaced partitions must match") } + + protected def checkLastWriteInfo( + expectedRowSchema: StructType = new StructType(), + expectedRowIdSchema: Option[StructType] = None, + expectedMetadataSchema: Option[StructType] = None): Unit = { + val info = table.lastWriteInfo + assert(info.schema == expectedRowSchema, "row schema must match") + val actualRowIdSchema = Option(info.rowIdSchema.orElse(null)) + assert(actualRowIdSchema == expectedRowIdSchema, "row ID schema must match") + val actualMetadataSchema = Option(info.metadataSchema.orElse(null)) + assert(actualMetadataSchema == expectedMetadataSchema, "metadata schema must match") + } + + protected def checkLastWriteLog(expectedEntries: WriteLogEntry*): Unit = { + val entryType = new StructType() + .add(StructField("operation", StringType)) + .add(StructField("id", IntegerType)) + .add(StructField( + "metadata", + new StructType(Array( + StructField("_partition", StringType), + StructField("_index", IntegerType))))) + .add(StructField("data", table.schema)) + + val expectedEntriesAsRows = expectedEntries.map { entry => + new GenericRowWithSchema( + values = Array( + entry.operation.toString, + entry.id.orNull, + entry.metadata.orNull, + entry.data.orNull), + schema = entryType) + } + + val encoder = ExpressionEncoder(entryType) + val deserializer = encoder.resolveAndBind().createDeserializer() + val actualEntriesAsRows = table.lastWriteLog.map(deserializer) + + sameRows(expectedEntriesAsRows, actualEntriesAsRows) match { + case Some(errMsg) => fail(s"Write log contains unexpected entries: $errMsg") + case None => // OK + } + } + + protected def writeLogEntry(data: Row): WriteLogEntry = { + WriteLogEntry(operation = Write, data = Some(data)) + } + + protected def writeWithMetadataLogEntry(metadata: Row, data: Row): WriteLogEntry = { + WriteLogEntry(operation = Write, metadata = Some(metadata), data = Some(data)) + } + + protected def deleteWriteLogEntry(id: Int, metadata: Row): WriteLogEntry = { + WriteLogEntry(operation = Delete, id = Some(id), metadata = Some(metadata)) + } + + protected def updateWriteLogEntry(id: Int, metadata: Row, data: Row): WriteLogEntry = { + WriteLogEntry(operation = Update, id = Some(id), metadata = Some(metadata), data = Some(data)) + } + + protected def reinsertWriteLogEntry(metadata: Row, data: Row): WriteLogEntry = { + WriteLogEntry(operation = Reinsert, metadata = Some(metadata), data = Some(data)) + } + + protected def insertWriteLogEntry(data: Row): WriteLogEntry = { + WriteLogEntry(operation = Insert, data = Some(data)) + } + + case class WriteLogEntry( + operation: Operation, + id: Option[Int] = None, + metadata: Option[Row] = None, + data: Option[Row] = None) }
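
The new ReplaceDataExec path above tags each incoming row with an operation ID and, for WRITE_WITH_METADATA_OPERATION, projects it into a data part and a metadata part before calling the two-argument write. A minimal connector-side sketch of a writer that consumes that metadata; the class name, the buffering, and the assumption that _partition is the first metadata field are illustrative, not part of this change:

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}

// Hypothetical commit message; any Serializable marker works for the sketch.
case object SketchCommit extends WriterCommitMessage

// Hypothetical group-based writer: copied and updated rows arrive through the
// two-argument write together with their metadata projection, while rows
// inserted by MERGE arrive through the one-argument write with no metadata.
class PartitionAwareWriter extends DataWriter[InternalRow] {
  private val buffered = ArrayBuffer.empty[(String, InternalRow)]

  override def write(metadata: InternalRow, record: InternalRow): Unit = {
    // Assumes _partition is the first field of the metadata projection,
    // matching the metadata schema asserted in the suites above.
    val partition =
      if (metadata.isNullAt(0)) "unassigned" else metadata.getUTF8String(0).toString
    buffered += partition -> record.copy()
  }

  override def write(record: InternalRow): Unit = {
    // New records carry no metadata; assign them a placeholder partition.
    buffered += "unassigned" -> record.copy()
  }

  override def commit(): WriterCommitMessage = SketchCommit

  override def abort(): Unit = buffered.clear()

  override def close(): Unit = ()
}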
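
With split updates enabled, the delta path emits DELETE followed by REINSERT for each updated row, which is why the write logs asserted above contain delete/reinsert pairs rather than update entries. A sketch of a DeltaWriter wired for that shape, assuming a hypothetical external sink; only the method signatures come from this change:

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DeltaWriter, WriterCommitMessage}

// Hypothetical delta writer for a table that reports
// representUpdateAsDeleteAndInsert() = true: updates never reach update(),
// they arrive as delete(metadata, id) followed by reinsert(metadata, row).
class SplitUpdateDeltaWriter extends DeltaWriter[InternalRow] {
  private val deletedIds = ArrayBuffer.empty[InternalRow]
  private val addedRows = ArrayBuffer.empty[InternalRow]

  override def delete(metadata: InternalRow, id: InternalRow): Unit =
    deletedIds += id.copy()

  override def update(metadata: InternalRow, id: InternalRow, row: InternalRow): Unit =
    throw new UnsupportedOperationException("updates arrive as delete + reinsert")

  override def reinsert(metadata: InternalRow, row: InternalRow): Unit =
    addedRows += row.copy()

  override def insert(row: InternalRow): Unit =
    addedRows += row.copy()

  override def write(row: InternalRow): Unit =
    throw new UnsupportedOperationException("delta writers receive typed operations")

  override def commit(): WriterCommitMessage = new WriterCommitMessage {}

  override def abort(): Unit = {
    deletedIds.clear()
    addedRows.clear()
  }

  override def close(): Unit = ()
}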
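
The write logs above show Row("hr", null) for deleted and updated rows because the index metadata column opts out of preservation while _partition opts in. One plausible way for a connector to declare those flags, assuming it surfaces column metadata through metadataInJSON(); the object name is illustrative:

import org.apache.spark.sql.connector.catalog.MetadataColumn
import org.apache.spark.sql.types.{DataType, IntegerType, MetadataBuilder}

// Hypothetical metadata column mirroring the "index" column used by these
// suites: its value is nulled out for deleted and updated rows, so the
// metadata passed back to the writer becomes Row(<partition>, null).
object IndexMetadataColumn extends MetadataColumn {
  override def name(): String = "index"

  override def dataType(): DataType = IntegerType

  override def metadataInJSON(): String =
    new MetadataBuilder()
      .putBoolean(MetadataColumn.PRESERVE_ON_DELETE, false)
      .putBoolean(MetadataColumn.PRESERVE_ON_UPDATE, false)
      .build()
      .json
}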