diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala
index 09e0314ff5ca0..87cbb8a7f0306 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala
@@ -125,6 +125,7 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
     case mergeInto @ MergeIntoTable(target, source, mergeCondition, matchedActions, notMatchedActions)
       if isHoodieTable(target, sparkSession) && target.resolved =>
+      val resolver = sparkSession.sessionState.conf.resolver
       val resolvedSource = analyzer.execute(source)
       def isInsertOrUpdateStar(assignments: Seq[Assignment]): Boolean = {
         if (assignments.isEmpty) {
@@ -161,23 +162,21 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
           val resolvedCondition = condition.map(resolveExpressionFrom(resolvedSource)(_))
           val resolvedAssignments = if (isInsertOrUpdateStar(assignments)) {
             // assignments is empty means insert * or update set *
-            val resolvedSourceOutputWithoutMetaFields = resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
-            val targetOutputWithoutMetaFields = target.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
-            val resolvedSourceColumnNamesWithoutMetaFields = resolvedSourceOutputWithoutMetaFields.map(_.name)
-            val targetColumnNamesWithoutMetaFields = targetOutputWithoutMetaFields.map(_.name)
+            val resolvedSourceOutput = resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
+            val targetOutput = target.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
+            val resolvedSourceColumnNames = resolvedSourceOutput.map(_.name)
-            if(targetColumnNamesWithoutMetaFields.toSet.subsetOf(resolvedSourceColumnNamesWithoutMetaFields.toSet)){
+            if (targetOutput.filter(attr => resolvedSourceColumnNames.exists(resolver(_, attr.name))).equals(targetOutput)) {
               // If the source table's columns contain all of the target table's columns,
               // we will assign the source fields to the target fields by column name matching.
-              val sourceColNameAttrMap = resolvedSourceOutputWithoutMetaFields.map(attr => (attr.name, attr)).toMap
-              targetOutputWithoutMetaFields.map(targetAttr => {
-                val sourceAttr = sourceColNameAttrMap(targetAttr.name)
+              targetOutput.map(targetAttr => {
+                val sourceAttr = resolvedSourceOutput.find(f => resolver(f.name, targetAttr.name)).get
                 Assignment(targetAttr, sourceAttr)
               })
             } else {
               // Otherwise we assign the source fields to the target fields by position.
-              targetOutputWithoutMetaFields
-                .zip(resolvedSourceOutputWithoutMetaFields)
+              targetOutput
+                .zip(resolvedSourceOutput)
                 .map { case (targetAttr, sourceAttr) => Assignment(targetAttr, sourceAttr) }
             }
           } else {
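For readers tracing the new guard above: sparkSession.sessionState.conf.resolver is Spark's name-equality function, which is equalsIgnoreCase unless spark.sql.caseSensitive=true, and targetOutput.filter(...).equals(targetOutput) is just a forall over the target columns. Below is a minimal, self-contained sketch of that rule, assuming only the Resolver contract ((String, String) => Boolean); the object and helper names are hypothetical, not Hudi code.

// Sketch of the matching rule the guard encodes, written with the
// equivalent forall: every target column must resolve against some
// source column for by-name assignment to apply.
import org.apache.spark.sql.catalyst.analysis.Resolver

object ResolverMatchSketch {
  // Hypothetical helper: true when the source output covers the target
  // output under the session's resolution rule.
  def targetCoveredBySource(targetCols: Seq[String], sourceCols: Seq[String],
                            resolver: Resolver): Boolean =
    targetCols.forall(t => sourceCols.exists(resolver(_, t)))

  def main(args: Array[String]): Unit = {
    // Default Spark behavior: case-insensitive name resolution.
    val resolver: Resolver = (a, b) => a.equalsIgnoreCase(b)
    // "PRICE" in the source now satisfies the target column "price".
    println(targetCoveredBySource(Seq("id", "price"), Seq("ID", "PRICE"), resolver)) // true
  }
}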
@@ -214,8 +213,9 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
         }.toMap
         // Validate if there are incorrect target attributes.
+        val targetColumnNames = removeMetaFields(target.output).map(_.name)
         val unKnowTargets = target2Values.keys
-          .filterNot(removeMetaFields(target.output).map(_.name).contains(_))
+          .filterNot(name => targetColumnNames.exists(resolver(_, name)))
         if (unKnowTargets.nonEmpty) {
           throw new AnalysisException(s"Cannot find target attributes: ${unKnowTargets.mkString(",")}.")
         }
@@ -224,19 +224,20 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
         // e.g. If the update action is missing the 'id' attribute, we add an "id = target.id" to the update action.
         val newAssignments = removeMetaFields(target.output)
           .map(attr => {
+            val valueOption = target2Values.find(f => resolver(f._1, attr.name))
             // TODO support partial update for MOR.
-            if (!target2Values.contains(attr.name) && targetTableType == MOR_TABLE_TYPE_OPT_VAL) {
+            if (valueOption.isEmpty && targetTableType == MOR_TABLE_TYPE_OPT_VAL) {
              throw new AnalysisException(s"Missing specify the value for target field: '${attr.name}' in merge into update action" +
                s" for MOR table. Currently we cannot support partial update for MOR," +
                s" please complete all the target fields just like '...update set id = s0.id, name = s0.name ....'")
             }
             if (preCombineField.isDefined && preCombineField.get.equalsIgnoreCase(attr.name)
-              && !target2Values.contains(attr.name)) {
+              && valueOption.isEmpty) {
               throw new AnalysisException(s"Missing specify value for the preCombineField:" +
                 s" ${preCombineField.get} in merge-into update action. You should add" +
                 s" '... update set ${preCombineField.get} = xx....' to the when-matched clause.")
             }
-            Assignment(attr, target2Values.getOrElse(attr.name, attr))
+            Assignment(attr, if (valueOption.isEmpty) attr else valueOption.get._2)
           })
         UpdateAction(resolvedCondition, newAssignments)
       case DeleteAction(condition) =>
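The hunks above replace Map lookups keyed by exact name (target2Values.contains / getOrElse) with a resolver-driven scan. A small sketch of why that matters, using plain strings in place of Catalyst expressions; target2Values here is a hypothetical stand-in for the real assignment map:

// Map.getOrElse matches keys case-sensitively, while a resolver-driven
// find honors spark.sql.caseSensitive.
import org.apache.spark.sql.catalyst.analysis.Resolver

object AssignmentLookupSketch {
  def main(args: Array[String]): Unit = {
    val resolver: Resolver = (a, b) => a.equalsIgnoreCase(b)
    // Key spelled as written in the MERGE statement.
    val target2Values = Map("PRICE" -> "s0.price")

    // Old behavior: a target attribute named "price" misses the "PRICE" key.
    println(target2Values.get("price")) // None

    // New behavior: the resolver-based scan finds it.
    println(target2Values.find(f => resolver(f._1, "price")).map(_._2)) // Some(s0.price)
  }
}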
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
index b22c60792b749..8ee47804951f3 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
@@ -26,6 +26,7 @@ import org.apache.hudi.hive.ddl.HiveSyncMode
 import org.apache.hudi.{AvroConversionUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, HoodieWriterUtils, SparkAdapterSupport}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BoundReference, Cast, EqualTo, Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.command.RunnableCommand
@@ -90,6 +91,7 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
    * TODO Currently Non-equivalent conditions are not supported.
    */
   private lazy val targetKey2SourceExpression: Map[String, Expression] = {
+    val resolver = sparkSession.sessionState.conf.resolver
     val conditions = splitByAnd(mergeInto.mergeCondition)
     val allEqs = conditions.forall(p => p.isInstanceOf[EqualTo])
     if (!allEqs) {
@@ -101,11 +103,11 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
     val target2Source = conditions.map(_.asInstanceOf[EqualTo])
       .map {
         case EqualTo(left: AttributeReference, right)
-          if targetAttrs.indexOf(left) >= 0 => // left is the target field
-          left.name -> right
+          if targetAttrs.exists(f => attributeEqual(f, left, resolver)) => // left is the target field
+          targetAttrs.find(f => resolver(f.name, left.name)).get.name -> right
         case EqualTo(left, right: AttributeReference)
-          if targetAttrs.indexOf(right) >= 0 => // right is the target field
-          right.name -> left
+          if targetAttrs.exists(f => attributeEqual(f, right, resolver)) => // right is the target field
+          targetAttrs.find(f => resolver(f.name, right.name)).get.name -> left
         case eq =>
           throw new AnalysisException(s"Invalidate Merge-On condition: ${eq.sql}." +
             "The validate condition should be 'targetColumn = sourceColumnExpression', e.g." +
@@ -196,15 +198,24 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
   }

   private def isEqualToTarget(targetColumnName: String, sourceExpression: Expression): Boolean = {
-    val sourceColNameMap = sourceDFOutput.map(attr => (attr.name.toLowerCase, attr.name)).toMap
+    val sourceColumnName = sourceDFOutput.map(_.name)
+    val resolver = sparkSession.sessionState.conf.resolver
     sourceExpression match {
-      case attr: AttributeReference if sourceColNameMap(attr.name.toLowerCase).equals(targetColumnName) => true
-      case Cast(attr: AttributeReference, _, _) if sourceColNameMap(attr.name.toLowerCase).equals(targetColumnName) => true
+      case attr: AttributeReference if sourceColumnName.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
+      case Cast(attr: AttributeReference, _, _) if sourceColumnName.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
       case _=> false
     }
   }

+  /**
+   * Compare an [[Attribute]] to another; return true if they have the same column name (by resolver) and the same exprId.
+   */
+  private def attributeEqual(
+      attr: Attribute, other: Attribute, resolver: Resolver): Boolean = {
+    resolver(attr.name, other.name) && attr.exprId == other.exprId
+  }
+
   /**
    * Execute the update and delete action. All the matched and not-matched actions will
   * execute in one upsert write operation. We pushed down the matched condition and assignment
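The added attributeEqual helper above deliberately checks both the resolver and the exprId, so two attributes only match when they are the same column of the same relation, not merely columns that happen to share a name. A sketch of that contract using Catalyst attributes directly; the object and variable names are illustrative:

// Same exprId plus resolver-equal names => match; a fresh attribute with
// the same name but a new exprId must not match.
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

object AttributeEqualSketch {
  def main(args: Array[String]): Unit = {
    val resolver: Resolver = (a, b) => a.equalsIgnoreCase(b)
    val id = AttributeReference("ID", IntegerType)()
    val idSameRef = id.withName("id")                         // same exprId, different case
    val idOtherRel = AttributeReference("id", IntegerType)()  // fresh exprId

    def attributeEqual(a: AttributeReference, b: AttributeReference): Boolean =
      resolver(a.name, b.name) && a.exprId == b.exprId

    println(attributeEqual(id, idSameRef))  // true
    println(attributeEqual(id, idOtherRel)) // false
  }
}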
@@ -361,9 +372,9 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
     mergeInto.targetTable.output
       .filterNot(attr => isMetaField(attr.name))
       .map(attr => {
-        val assignment = attr2Assignment.getOrElse(attr,
-          throw new IllegalArgumentException(s"Cannot find related assignment for field: ${attr.name}"))
-        castIfNeeded(assignment, attr.dataType, sparkSession.sqlContext.conf)
+        val assignment = attr2Assignment.find(f => attributeEqual(f._1, attr, sparkSession.sessionState.conf.resolver))
+          .getOrElse(throw new IllegalArgumentException(s"Cannot find related assignment for field: ${attr.name}"))
+        castIfNeeded(assignment._2, attr.dataType, sparkSession.sqlContext.conf)
       })
   }
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
index 30a2448f0a5e4..185d1fe15e687 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
@@ -444,4 +444,115 @@ class TestMergeIntoTable2 extends TestHoodieSqlBase {
     }
   }

+  test("Test ignoring case") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create table
+      spark.sql(
+        s"""
+           |create table $tableName (
+           |  ID int,
+           |  name string,
+           |  price double,
+           |  TS long,
+           |  DT string
+           |) using hudi
+           | location '${tmp.getCanonicalPath}/$tableName'
+           | options (
+           |  primaryKey = 'ID',
+           |  preCombineField = 'TS'
+           | )
+       """.stripMargin)
+
+      // First merge with an extra input field 'flag' (insert a new record)
+      spark.sql(
+        s"""
+           | merge into $tableName
+           | using (
+           |  select 1 as id, 'a1' as name, 10 as PRICE, 1000 as ts, '2021-05-05' as dt, '1' as flag
+           | ) s0
+           | on s0.id = $tableName.id
+           | when matched and flag = '1' then update set
+           | id = s0.id, name = s0.name, PRICE = s0.price, ts = s0.ts, dt = s0.dt
+           | when not matched and flag = '1' then insert *
+       """.stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 10.0, 1000, "2021-05-05")
+      )
+
+      // Second merge (update the record)
+      spark.sql(
+        s"""
+           | merge into $tableName
+           | using (
+           |  select 1 as id, 'a1' as name, 20 as PRICE, '2021-05-05' as dt, 1001 as ts
+           | ) s0
+           | on s0.id = $tableName.id
+           | when matched then update set
+           | id = s0.id, name = s0.name, PRICE = s0.price, ts = s0.ts, dt = s0.dt
+           | when not matched then insert *
+       """.stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 20.0, 1001, "2021-05-05")
+      )
+
+      // Test ignoring case when column names match
+      spark.sql(
+        s"""
+           | merge into $tableName as t0
+           | using (
+           |  select 1 as id, 'a1' as name, 1111 as ts, '2021-05-05' as dt, 111 as PRICE union all
+           |  select 2 as id, 'a2' as name, 1112 as ts, '2021-05-05' as dt, 112 as PRICE
+           | ) as s0
+           | on t0.id = s0.id
+           | when matched then update set *
+           | when not matched then insert *
+           |""".stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 111.0, 1111, "2021-05-05"),
+        Seq(2, "a2", 112.0, 1112, "2021-05-05")
+      )
+    }
+  }
+
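The test above exercises the analyzer change from HoodieAnalysis.scala: the source projects PRICE while the target declares price, so update set * and insert * only expand by name if the resolver bridges the case difference; otherwise Hudi falls back to positional assignment. A sketch of that decision with the test's column lists (the names mirror the test schemas, the object itself is hypothetical):

// Decide by-name vs positional expansion the way the resolved rule does.
object StarExpansionSketch {
  def main(args: Array[String]): Unit = {
    val resolver = (a: String, b: String) => a.equalsIgnoreCase(b)
    val targetCols = Seq("ID", "name", "price", "TS", "DT")
    val sourceCols = Seq("id", "name", "ts", "dt", "PRICE")

    val byName = targetCols.forall(t => sourceCols.exists(resolver(_, t)))
    val assignments =
      if (byName) targetCols.map(t => t -> sourceCols.find(resolver(_, t)).get)
      else targetCols.zip(sourceCols) // positional fallback
    assignments.foreach { case (t, s) => println(s"$t = s0.$s") }
  }
}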
+  test("Test ignoring case for MOR table") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create a MOR partitioned table.
+      spark.sql(
+        s"""
+           | create table $tableName (
+           |  ID int,
+           |  NAME string,
+           |  price double,
+           |  TS long,
+           |  dt string
+           | ) using hudi
+           | options (
+           |  type = 'mor',
+           |  primaryKey = 'ID',
+           |  preCombineField = 'TS'
+           | )
+           | partitioned by(dt)
+           | location '${tmp.getCanonicalPath}/$tableName'
+       """.stripMargin)
+
+      // Test ignoring case when column names match
+      spark.sql(
+        s"""
+           | merge into $tableName as t0
+           | using (
+           |  select 1 as id, 'a1' as NAME, 1111 as ts, '2021-05-05' as DT, 111 as price
+           | ) as s0
+           | on t0.id = s0.id
+           | when matched then update set *
+           | when not matched then insert *
+       """.stripMargin
+      )
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 111.0, 1111, "2021-05-05")
+      )
+    }
+  }
 }
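Finally, everything in this patch hinges on which resolver the session hands out. A sketch showing how spark.sql.caseSensitive selects it, assuming only a local SparkSession; sparkSession.sessionState.conf.resolver is the same accessor the patch threads through both the analysis and command paths:

// Default resolution is case-insensitive; flipping spark.sql.caseSensitive
// switches the resolver the session returns.
import org.apache.spark.sql.SparkSession

object ResolverSelectionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("resolver-selection")
      .getOrCreate()

    // Default: spark.sql.caseSensitive=false, so names match ignoring case.
    println(spark.sessionState.conf.resolver("PRICE", "price")) // true

    spark.conf.set("spark.sql.caseSensitive", "true")
    println(spark.sessionState.conf.resolver("PRICE", "price")) // false

    spark.stop()
  }
}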