diff --git a/core/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala b/core/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala index 3b247dbb51b..280e24c88f9 100644 --- a/core/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala +++ b/core/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala @@ -61,8 +61,8 @@ case class PreprocessTableMerge(override val conf: SQLConf) (matched ++ notMatched).filter(_.condition.nonEmpty).foreach { clause => checkCondition(clause.condition.get, clause.clauseType.toUpperCase(Locale.ROOT)) } - - val shouldAutoMigrate = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) && migrateSchema + val canMergeSchema = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) + val shouldAutoMigrate = canMergeSchema && migrateSchema val finalSchema = if (shouldAutoMigrate) { // The implicit conversions flag allows any type to be merged from source to target if Spark // SQL considers the source type implicitly castable to the target. Normally, mergeSchemas @@ -208,6 +208,6 @@ case class PreprocessTableMerge(override val conf: SQLConf) MergeIntoCommand( source, target, tahoeFileIndex, condition, - processedMatched, processedNotMatched, Some(finalSchema)) + processedMatched, processedNotMatched, Some(finalSchema), canMergeSchema) } } diff --git a/core/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala b/core/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala index 3c2e03e9247..f112aece629 100644 --- a/core/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala +++ b/core/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala @@ -197,6 +197,7 @@ object MergeStats { * @param matchedClauses All info related to matched clauses. * @param notMatchedClauses All info related to not matched clause. * @param migratedSchema The final schema of the target - may be changed by schema evolution. + * @param autoSchemaMergeEnabled Auto schema merge enabled config used in PreprocessTableMerge. */ case class MergeIntoCommand( @transient source: LogicalPlan, @@ -205,13 +206,14 @@ case class MergeIntoCommand( condition: Expression, matchedClauses: Seq[DeltaMergeIntoMatchedClause], notMatchedClauses: Seq[DeltaMergeIntoInsertClause], - migratedSchema: Option[StructType]) extends RunnableCommand + migratedSchema: Option[StructType], + autoSchemaMergeEnabled: Boolean) extends RunnableCommand with DeltaCommand with PredicateHelper with AnalysisHelper with ImplicitMetadataOperation { import SQLMetrics._ import MergeIntoCommand._ - override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) + override val canMergeSchema: Boolean = autoSchemaMergeEnabled override val canOverwriteSchema: Boolean = false @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() @@ -265,13 +267,14 @@ case class MergeIntoCommand( isOverwriteMode = false, rearrangeOnly = false) } + val targetOutputCols = getTargetOutputCols(deltaTxn, spark) val deltaActions = { if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { - writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn) + writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn, targetOutputCols) } else { - val filesToRewrite = findTouchedFiles(spark, deltaTxn) + val filesToRewrite = findTouchedFiles(spark, deltaTxn, targetOutputCols) val newWrittenFiles = withStatusCode("DELTA", "Writing merged data") { - writeAllChanges(spark, deltaTxn, filesToRewrite) + writeAllChanges(spark, deltaTxn, filesToRewrite, targetOutputCols) } filesToRewrite.map(_.remove) ++ newWrittenFiles } @@ -309,9 +312,9 @@ case class MergeIntoCommand( */ private def findTouchedFiles( spark: SparkSession, - deltaTxn: OptimisticTransaction + deltaTxn: OptimisticTransaction, + targetOutputCols: Seq[NamedExpression] ): Seq[AddFile] = recordMergeOperation(sqlMetricName = "scanTimeMs") { - // Accumulator to collect all the distinct touched files val touchedFilesAccum = new SetAccumulator[String]() spark.sparkContext.register(touchedFilesAccum, TOUCHED_FILES_ACCUM_NAME) @@ -334,7 +337,9 @@ case class MergeIntoCommand( // - the target file name the row is from to later identify the files touched by matched rows val joinToFindTouchedFiles = { val sourceDF = Dataset.ofRows(spark, source) - val targetDF = Dataset.ofRows(spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles)) + val targetDF = Dataset + .ofRows(spark, + buildTargetPlanWithFiles(deltaTxn, spark, dataSkippedFiles, targetOutputCols)) .withColumn(ROW_ID_COL, monotonically_increasing_id()) .withColumn(FILE_NAME_COL, input_file_name()) sourceDF.join(targetDF, new Column(condition), "inner") @@ -396,14 +401,15 @@ case class MergeIntoCommand( */ private def writeInsertsOnlyWhenNoMatchedClauses( spark: SparkSession, - deltaTxn: OptimisticTransaction + deltaTxn: OptimisticTransaction, + targetOutputCols: Seq[NamedExpression] ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") { // UDFs to update metrics val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows") val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted") - val outputColNames = getTargetOutputCols(deltaTxn).map(_.name) + val outputColNames = targetOutputCols.map(_.name) // we use head here since we know there is only a single notMatchedClause val outputExprs = notMatchedClauses.head.resolvedActions.map(_.expr) :+ incrInsertedCountExpr val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) => @@ -423,7 +429,7 @@ case class MergeIntoCommand( // target DataFrame val targetDF = Dataset.ofRows( - spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles)) + spark, buildTargetPlanWithFiles(deltaTxn, spark, dataSkippedFiles, targetOutputCols)) val insertDf = sourceDF.join(targetDF, new Column(condition), "leftanti") .select(outputCols: _*) @@ -456,13 +462,13 @@ case class MergeIntoCommand( private def writeAllChanges( spark: SparkSession, deltaTxn: OptimisticTransaction, - filesToRewrite: Seq[AddFile] + filesToRewrite: Seq[AddFile], + targetOutputCols: Seq[NamedExpression] ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") { - val targetOutputCols = getTargetOutputCols(deltaTxn) // Generate a new logical plan that has same output attributes exprIds as the target plan. // This allows us to apply the existing resolved update/insert expressions. - val newTarget = buildTargetPlanWithFiles(deltaTxn, filesToRewrite) + val newTarget = buildTargetPlanWithFiles(deltaTxn, spark, filesToRewrite, targetOutputCols) val joinType = if (isMatchedOnly && spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) { "rightOuter" @@ -568,8 +574,9 @@ case class MergeIntoCommand( */ private def buildTargetPlanWithFiles( deltaTxn: OptimisticTransaction, - files: Seq[AddFile]): LogicalPlan = { - val targetOutputCols = getTargetOutputCols(deltaTxn) + spark: SparkSession, + files: Seq[AddFile], + targetOutputCols: Seq[NamedExpression]): LogicalPlan = { val plan = { // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. // In cases of schema evolution, they may not be the same type as the original attributes. @@ -590,12 +597,12 @@ case class MergeIntoCommand( // create an alias val aliases = plan.output.map { case newAttrib: AttributeReference => - val existingTargetAttrib = getTargetOutputCols(deltaTxn).find { col => - conf.resolver(col.name, newAttrib.name) + val existingTargetAttrib = targetOutputCols.find { col => + spark.sessionState.conf.resolver(col.name, newAttrib.name) }.getOrElse { throw new AnalysisException( s"Could not find ${newAttrib.name} among the existing target output " + - s"${getTargetOutputCols(deltaTxn)}") + s"$targetOutputCols") }.asInstanceOf[AttributeReference] if (existingTargetAttrib.exprId == newAttrib.exprId) { @@ -619,9 +626,11 @@ case class MergeIntoCommand( private def seqToString(exprs: Seq[Expression]): String = exprs.map(_.sql).mkString("\n\t") - private def getTargetOutputCols(txn: OptimisticTransaction): Seq[NamedExpression] = { + private def getTargetOutputCols( + txn: OptimisticTransaction, + spark: SparkSession): Seq[NamedExpression] = { txn.metadata.schema.map { col => - target.output.find(attr => conf.resolver(attr.name, col.name)).map { a => + target.output.find(attr => spark.sessionState.conf.resolver(attr.name, col.name)).map { a => AttributeReference(col.name, col.dataType, col.nullable)(a.exprId) }.getOrElse( Alias(Literal(null), col.name)()) @@ -712,7 +721,7 @@ object MergeIntoCommand { val outputProj = UnsafeProjection.create(outputRowEncoder.schema) def shouldDeleteRow(row: InternalRow): Boolean = - row.getBoolean(outputRowEncoder.schema.fields.size) + row.getBoolean(row.numFields - 2) def processRow(inputRow: InternalRow): InternalRow = { if (targetRowHasNoMatchPred.eval(inputRow)) {