MetadataColumn.java
@@ -36,6 +36,45 @@
*/
@Evolving
public interface MetadataColumn {
/**
* Indicates whether a row-level operation should preserve the value of the metadata column
* for deleted rows. If set to true, the metadata value will be retained and passed back to
* the writer. If false, the metadata value will be replaced with {@code null}.
* <p>
* This flag applies only to row-level operations working with deltas of rows. Group-based
* operations handle deletes by discarding matching records.
*
* @since 4.0.0
*/
String PRESERVE_ON_DELETE = "__preserve_on_delete";
boolean PRESERVE_ON_DELETE_DEFAULT = true;

/**
* Indicates whether a row-level operation should preserve the value of the metadata column
* for updated rows. If set to true, the metadata value will be retained and passed back to
* the writer. If false, the metadata value will be replaced with {@code null}.
* <p>
* This flag applies to both group-based and delta-based row-level operations.
*
* @since 4.0.0
*/
String PRESERVE_ON_UPDATE = "__preserve_on_update";
boolean PRESERVE_ON_UPDATE_DEFAULT = true;

/**
* Indicates whether a row-level operation should preserve the value of the metadata column
* for reinserted rows generated by splitting updates into deletes and inserts. If true,
* the metadata value will be retained and passed back to the writer. If false, the metadata
* value will be replaced with {@code null}.
* <p>
* This flag applies only to row-level operations working with deltas of rows. Group-based
* operations do not represent updates as deletes and inserts.
*
* @since 4.0.0
*/
String PRESERVE_ON_REINSERT = "__preserve_on_reinsert";
boolean PRESERVE_ON_REINSERT_DEFAULT = false;

/**
* The name of this metadata column.
*
@@ -74,4 +113,13 @@ default String comment() {
default Transform transform() {
return null;
}

/**
* Returns the column metadata in JSON format.
*
* @since 4.0.0
*/
default String metadataInJSON() {
[Inline review comment, Contributor (PR author): "Follows exactly what we have in Column."]
return null;
}
}
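For illustration only (this example is not part of the PR): a connector metadata column could declare how row-level operations should treat its value. The sketch below assumes the preservation flags are carried in the column metadata returned by metadataInJSON(), which is what the constants above suggest; the class name PartitionMetadataColumn and the column name _partition are hypothetical.

```java
// Hypothetical metadata column that keeps its value for deleted and updated rows
// but not for reinserted rows. Assumes the flags are read from the column metadata
// JSON; this is a sketch, not part of the PR.
import org.apache.spark.sql.connector.catalog.MetadataColumn;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

public class PartitionMetadataColumn implements MetadataColumn {
  @Override
  public String name() {
    return "_partition";
  }

  @Override
  public DataType dataType() {
    return DataTypes.StringType;
  }

  @Override
  public String metadataInJSON() {
    // keep the value on DELETE and UPDATE (the defaults); drop it on REINSERT
    // (also the default, spelled out here only to make the intent explicit)
    return "{\n" +
        "  \"" + PRESERVE_ON_DELETE + "\": true,\n" +
        "  \"" + PRESERVE_ON_UPDATE + "\": true,\n" +
        "  \"" + PRESERVE_ON_REINSERT + "\": false\n" +
        "}";
  }
}
```

With a declaration like this, deletes and updates would keep passing the _partition value back to the writer, while the insert half of a split update would receive null for it.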
DataWriter.java
@@ -64,6 +64,23 @@
*/
@Evolving
public interface DataWriter<T> extends Closeable {
/**
* Writes one record with metadata.
* <p>
* This method is used by group-based row-level operations to pass back metadata for records
* that are updated or copied. New records added during a MERGE operation are written using
* {@link #write(Object)} as there is no metadata associated with those records.
* <p>
* If this method fails (by throwing an exception), {@link #abort()} will be called and this
* data writer is considered to have failed.
*
* @throws IOException if failure happens during disk/network IO like writing files.
*
* @since 4.0.0
*/
default void write(T metadata, T record) throws IOException {
write(record);
}

/**
* Writes one record.
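As a rough sketch (not from this PR), a writer could override the new metadata-aware method to keep the passed-back metadata next to each record; BufferingMetadataWriter and its buffering strategy are hypothetical.

```java
// Hypothetical writer sketch: buffers records together with any metadata passed
// back by a group-based row-level operation. Error handling and actual
// persistence are omitted.
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;

public class BufferingMetadataWriter implements DataWriter<InternalRow> {
  private final List<InternalRow> records = new ArrayList<>();
  private final List<InternalRow> metadataRows = new ArrayList<>();

  @Override
  public void write(InternalRow metadata, InternalRow record) {
    // updated or copied records arrive here together with their metadata row
    metadataRows.add(metadata.copy());
    records.add(record.copy());
  }

  @Override
  public void write(InternalRow record) {
    // brand-new records (e.g. MERGE inserts) carry no metadata
    metadataRows.add(null);
    records.add(record.copy());
  }

  @Override
  public WriterCommitMessage commit() {
    // a real writer would flush the buffered records (and preserved metadata) here
    return new WriterCommitMessage() {};
  }

  @Override
  public void abort() {
    records.clear();
    metadataRows.clear();
  }

  @Override
  public void close() {
    // nothing to release in this sketch
  }
}
```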
DeltaWriter.java
@@ -48,6 +48,21 @@ public interface DeltaWriter<T> extends DataWriter<T> {
*/
void update(T metadata, T id, T row) throws IOException;

/**
* Reinserts a row with metadata.
* <p>
* This method handles the insert portion of updated rows split into deletes and inserts.
*
* @param metadata values for metadata columns
* @param row a row to reinsert
* @throws IOException if failure happens during disk/network IO like writing files
*
* @since 4.0.0
*/
default void reinsert(T metadata, T row) throws IOException {
insert(row);
}

/**
* Inserts a new row.
*
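Again purely as an illustration (not part of the PR), a delta writer might override reinsert to tell the insert half of a split update apart from a brand-new row; ReinsertAwareDeltaWriter and appendWithMetadata are hypothetical names.

```java
// Hypothetical sketch of a delta writer that keeps split-update reinserts separate
// from plain inserts. Only the insert-side methods are shown; delete/update and
// the DataWriter lifecycle methods would still need to be implemented.
import java.io.IOException;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DeltaWriter;

public abstract class ReinsertAwareDeltaWriter implements DeltaWriter<InternalRow> {

  @Override
  public void reinsert(InternalRow metadata, InternalRow row) throws IOException {
    // insert half of an update that was split into a delete and an insert;
    // metadata values are nulled out for columns not flagged as preserved on reinsert
    appendWithMetadata(metadata, row);
  }

  @Override
  public void insert(InternalRow row) throws IOException {
    // a genuinely new row: there is no prior metadata to carry forward
    appendWithMetadata(null, row);
  }

  // hypothetical helper that persists a row plus optional carried-over metadata
  protected abstract void appendWithMetadata(InternalRow metadata, InternalRow row)
      throws IOException;
}
```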
RewriteDeleteFromTable.scala
@@ -86,7 +86,9 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand {

// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
ReplaceData(writeRelation, cond, remainingRowsPlan, relation, Some(cond))
val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, remainingRowsPlan)
val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs)
ReplaceData(writeRelation, cond, query, relation, projections, Some(cond))
}

// build a rewrite plan for sources that support row deltas
@@ -106,7 +108,7 @@
// construct a plan that only contains records to delete
val deletedRowsPlan = Filter(cond, readRelation)
val operationType = Alias(Literal(DELETE_OPERATION), OPERATION_COLUMN)()
val requiredWriteAttrs = dedupAttrs(rowIdAttrs ++ metadataAttrs)
val requiredWriteAttrs = nullifyMetadataOnDelete(dedupAttrs(rowIdAttrs ++ metadataAttrs))
val project = Project(operationType +: requiredWriteAttrs, deletedRowsPlan)

// build a plan to write deletes to the table
RewriteMergeIntoTable.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE
@@ -180,7 +180,8 @@

// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
ReplaceData(writeRelation, pushableCond, mergeRowsPlan, relation, groupFilterCond)
val projections = buildReplaceDataProjections(mergeRowsPlan, relation.output, metadataAttrs)
ReplaceData(writeRelation, pushableCond, mergeRowsPlan, relation, projections, groupFilterCond)
}

private def buildReplaceDataMergeRowsPlan(
@@ -197,7 +198,8 @@
// that's why an extra unconditional instruction that would produce the original row is added
// as the last MATCHED and NOT MATCHED BY SOURCE instruction
// this logic is specific to data sources that replace groups of data
val keepCarryoverRowsInstruction = Keep(TrueLiteral, targetTable.output)
val carryoverRowsOutput = Literal(WRITE_WITH_METADATA_OPERATION) +: targetTable.output
val keepCarryoverRowsInstruction = Keep(TrueLiteral, carryoverRowsOutput)

val matchedInstructions = matchedActions.map { action =>
toInstruction(action, metadataAttrs)
@@ -218,7 +220,8 @@
notMatchedInstructions.flatMap(_.outputs) ++
notMatchedBySourceInstructions.flatMap(_.outputs)

val attrs = targetTable.output
val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)()
val attrs = operationTypeAttr +: targetTable.output

MergeRows(
isSourceRowPresent = IsNotNull(rowFromSourceAttr),
@@ -430,15 +433,18 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
private def toInstruction(action: MergeAction, metadataAttrs: Seq[Attribute]): Instruction = {
action match {
case UpdateAction(cond, assignments) =>
val output = assignments.map(_.value) ++ metadataAttrs
val rowValues = assignments.map(_.value)
val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues
Keep(cond.getOrElse(TrueLiteral), output)

case DeleteAction(cond) =>
Discard(cond.getOrElse(TrueLiteral))

case InsertAction(cond, assignments) =>
val rowValues = assignments.map(_.value)
val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType))
val output = assignments.map(_.value) ++ metadataValues
val output = Seq(Literal(WRITE_OPERATION)) ++ rowValues ++ metadataValues
Keep(cond.getOrElse(TrueLiteral), output)

case other =>
@@ -460,7 +466,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
action match {
case UpdateAction(cond, assignments) if splitUpdates =>
val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues)
val otherOutput = deltaInsertOutput(assignments, metadataAttrs, originalRowIdValues)
val otherOutput = deltaReinsertOutput(assignments, metadataAttrs, originalRowIdValues)
Split(cond.getOrElse(TrueLiteral), output, otherOutput)

case UpdateAction(cond, assignments) =>
RewriteRowLevelCommand.scala
@@ -19,24 +19,32 @@ package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ProjectingInternalRow
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable, SupportsDelta}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._

trait RewriteRowLevelCommand extends Rule[LogicalPlan] {

private final val DELTA_OPERATIONS_WITH_ROW =
Set(UPDATE_OPERATION, REINSERT_OPERATION, INSERT_OPERATION)
private final val DELTA_OPERATIONS_WITH_METADATA =
Set(DELETE_OPERATION, UPDATE_OPERATION, REINSERT_OPERATION)
private final val DELTA_OPERATIONS_WITH_ROW_ID =
Set(DELETE_OPERATION, UPDATE_OPERATION)

protected def buildOperationTable(
table: SupportsRowLevelOperations,
command: Command,
@@ -103,7 +111,31 @@
metadataAttrs: Seq[Attribute],
originalRowIdValues: Seq[Expression] = Seq.empty): Seq[Expression] = {
val rowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs)
Seq(Literal(DELETE_OPERATION)) ++ rowValues ++ metadataAttrs ++ originalRowIdValues
val metadataValues = nullifyMetadataOnDelete(metadataAttrs)
Seq(Literal(DELETE_OPERATION)) ++ rowValues ++ metadataValues ++ originalRowIdValues
}

protected def nullifyMetadataOnDelete(attrs: Seq[Attribute]): Seq[NamedExpression] = {
nullifyMetadata(attrs, MetadataAttribute.isPreservedOnDelete)
}

protected def nullifyMetadataOnUpdate(attrs: Seq[Attribute]): Seq[NamedExpression] = {
nullifyMetadata(attrs, MetadataAttribute.isPreservedOnUpdate)
}

private def nullifyMetadataOnReinsert(attrs: Seq[Attribute]): Seq[NamedExpression] = {
nullifyMetadata(attrs, MetadataAttribute.isPreservedOnReinsert)
}

private def nullifyMetadata(
attrs: Seq[Attribute],
shouldPreserve: Attribute => Boolean): Seq[NamedExpression] = {
attrs.map {
case MetadataAttribute(attr) if !shouldPreserve(attr) =>
Alias(Literal(null, attr.dataType), attr.name)(explicitMetadata = Some(attr.metadata))
case attr =>
attr
}
}

private def buildDeltaDeleteRowValues(
@@ -132,52 +164,119 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
metadataAttrs: Seq[Attribute],
originalRowIdValues: Seq[Expression]): Seq[Expression] = {
val rowValues = assignments.map(_.value)
Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataAttrs ++ originalRowIdValues
val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataValues ++ originalRowIdValues
}

protected def deltaReinsertOutput(
assignments: Seq[Assignment],
metadataAttrs: Seq[Attribute],
originalRowIdValues: Seq[Expression] = Seq.empty): Seq[Expression] = {
val rowValues = assignments.map(_.value)
val metadataValues = nullifyMetadataOnReinsert(metadataAttrs)
val extraNullValues = originalRowIdValues.map(e => Literal(null, e.dataType))
Seq(Literal(REINSERT_OPERATION)) ++ rowValues ++ metadataValues ++ extraNullValues
}

protected def addOperationColumn(operation: Int, plan: LogicalPlan): LogicalPlan = {
val operationType = Alias(Literal(operation, IntegerType), OPERATION_COLUMN)()
Project(operationType +: plan.output, plan)
}

protected def buildReplaceDataProjections(
plan: LogicalPlan,
rowAttrs: Seq[Attribute],
metadataAttrs: Seq[Attribute]): ReplaceDataProjections = {
val outputs = extractOutputs(plan)

val outputsWithRow = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION))
val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs)

val metadataProjection = if (metadataAttrs.nonEmpty) {
val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION))
Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
} else {
None
}

ReplaceDataProjections(rowProjection, metadataProjection)
}

protected def buildWriteDeltaProjections(
plan: LogicalPlan,
rowAttrs: Seq[Attribute],
rowIdAttrs: Seq[Attribute],
metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
val outputs = extractOutputs(plan)

val rowProjection = if (rowAttrs.nonEmpty) {
Some(newLazyProjection(plan, rowAttrs))
val outputsWithRow = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW)
Some(newLazyProjection(plan, outputsWithRow, rowAttrs))
} else {
None
}

val rowIdProjection = newLazyRowIdProjection(plan, rowIdAttrs)
val outputsWithRowId = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW_ID)
val rowIdProjection = newLazyRowIdProjection(plan, outputsWithRowId, rowIdAttrs)

val metadataProjection = if (metadataAttrs.nonEmpty) {
Some(newLazyProjection(plan, metadataAttrs))
val outputsWithMetadata = filterOutputs(outputs, DELTA_OPERATIONS_WITH_METADATA)
Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
} else {
None
}

WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection)
}

private def extractOutputs(plan: LogicalPlan): Seq[Seq[Expression]] = {
plan match {
case p: Project => Seq(p.projectList)
case e: Expand => e.projections
case m: MergeRows => m.outputs
case _ => throw SparkException.internalError("Can't extract outputs from plan: " + plan)
}
}

private def filterOutputs(
outputs: Seq[Seq[Expression]],
operations: Set[Int]): Seq[Seq[Expression]] = {
outputs.filter {
case Literal(operation: Integer, _) +: _ => operations.contains(operation)
case Alias(Literal(operation: Integer, _), _) +: _ => operations.contains(operation)
case other => throw SparkException.internalError("Can't determine operation: " + other)
}
}

private def newLazyProjection(
plan: LogicalPlan,
outputs: Seq[Seq[Expression]],
attrs: Seq[Attribute]): ProjectingInternalRow = {

val colOrdinals = attrs.map(attr => findColOrdinal(plan, attr.name))
val schema = DataTypeUtils.fromAttributes(attrs)
ProjectingInternalRow(schema, colOrdinals)
createProjectingInternalRow(outputs, colOrdinals, attrs)
}

// if there are assignments to row ID attributes, original values are projected as special columns
// this method honors such special columns if present
private def newLazyRowIdProjection(
plan: LogicalPlan,
outputs: Seq[Seq[Expression]],
rowIdAttrs: Seq[Attribute]): ProjectingInternalRow = {

val colOrdinals = rowIdAttrs.map { attr =>
val originalValueIndex = findColOrdinal(plan, ORIGINAL_ROW_ID_VALUE_PREFIX + attr.name)
if (originalValueIndex != -1) originalValueIndex else findColOrdinal(plan, attr.name)
}
val schema = DataTypeUtils.fromAttributes(rowIdAttrs)
createProjectingInternalRow(outputs, colOrdinals, rowIdAttrs)
}

private def createProjectingInternalRow(
outputs: Seq[Seq[Expression]],
colOrdinals: Seq[Int],
attrs: Seq[Attribute]): ProjectingInternalRow = {
val schema = StructType(attrs.zipWithIndex.map { case (attr, index) =>
val nullable = outputs.exists(output => output(colOrdinals(index)).nullable)
[Inline review comment, Contributor: "@aokolnychyi Do we only need this for metadata columns? For regular columns, shall we use attr.nullable instead?"]
StructField(attr.name, attr.dataType, nullable, attr.metadata)
})
ProjectingInternalRow(schema, colOrdinals)
}
