From 3a829d5acda87b56e6b006b4aa62819f129883fc Mon Sep 17 00:00:00 2001 From: shenjiayu17 Date: Wed, 8 Jun 2022 11:58:58 +0800 Subject: [PATCH] [HUDI-4219] Merge Into when update expression "col=s.col+2" on precombine cause exception --- .../command/MergeIntoHoodieTableCommand.scala | 40 +++- .../spark/sql/hudi/TestMergeIntoTable.scala | 181 ++++++++++++++++++ 2 files changed, 215 insertions(+), 6 deletions(-) diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala index 636599ce0cf4..94fe7ea8bcb2 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala @@ -181,14 +181,14 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie var sourceDF = Dataset.ofRows(sparkSession, mergeInto.sourceTable) targetKey2SourceExpression.foreach { case (targetColumn, sourceExpression) - if !isEqualToTarget(targetColumn, sourceExpression) => + if !containsPrimaryKeyFieldReference(targetColumn, sourceExpression) => sourceDF = sourceDF.withColumn(targetColumn, new Column(sourceExpression)) sourceDFOutput = sourceDFOutput :+ AttributeReference(targetColumn, sourceExpression.dataType)() case _=> } target2SourcePreCombineFiled.foreach { case (targetPreCombineField, sourceExpression) - if !isEqualToTarget(targetPreCombineField, sourceExpression) => + if !containsPreCombineFieldReference(targetPreCombineField, sourceExpression) => sourceDF = sourceDF.withColumn(targetPreCombineField, new Column(sourceExpression)) sourceDFOutput = sourceDFOutput :+ AttributeReference(targetPreCombineField, sourceExpression.dataType)() case _=> @@ -196,23 +196,51 @@ case class 
MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie sourceDF } - private def isEqualToTarget(targetColumnName: String, sourceExpression: Expression): Boolean = { - val sourceColumnName = sourceDFOutput.map(_.name) + /** + * Check whether the source expression is a (possibly cast) reference to a source column with the same name as the target column; an attribute that matches no source column yields false. + * + * Merge condition cases that return true: + * 1) merge into .. on h0.id = s0.id .. + * 2) merge into .. on h0.id = cast(s0.id as int) .. + * "id" is the primaryKey field of h0. + */ + private def containsPrimaryKeyFieldReference(targetColumnName: String, sourceExpression: Expression): Boolean = { + val sourceColumnNames = sourceDFOutput.map(_.name) val resolver = sparkSession.sessionState.conf.resolver sourceExpression match { - case attr: AttributeReference if sourceColumnName.find(resolver(_, attr.name)).get.equals(targetColumnName) => true + case attr: AttributeReference if sourceColumnNames.find(resolver(_, attr.name)).contains(targetColumnName) => true // SPARK-35857: the definition of Cast has been changed in Spark3.2. // Match the class type instead of call the `unapply` method. case cast: Cast => cast.child match { - case attr: AttributeReference if sourceColumnName.find(resolver(_, attr.name)).get.equals(targetColumnName) => true + case attr: AttributeReference if sourceColumnNames.find(resolver(_, attr.name)).contains(targetColumnName) => true case _ => false } case _=> false } } + /** + * Check whether the source expression on the preCombine field contains a reference to a column with the same name as the target column. + * + * Merge expression cases that return true: + * 1) merge into .. on .. update set ts = s0.ts + * 2) merge into .. on .. update set ts = cast(s0.ts as int) + * 3) merge into .. on .. update set ts = s0.ts+1 (expressions like this, where a sub node has the same column name as the target) + * "ts" is the preCombine field of h0. 
+ */ + private def containsPreCombineFieldReference(targetColumnName: String, sourceExpression: Expression): Boolean = { + val sourceColumnNames = sourceDFOutput.map(_.name) + val resolver = sparkSession.sessionState.conf.resolver + + // any sub node may reference a source column named like the target column; attributes matching no source column are ignored instead of throwing + sourceExpression.find { + case attr: AttributeReference => sourceColumnNames.find(resolver(_, attr.name)).contains(targetColumnName) + case _ => false + }.isDefined + } + /** * Compare a [[Attribute]] to another, return true if they have the same column name(by resolver) and exprId */ diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala index 992a442f4fda..ac11f83d5311 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala @@ -427,6 +427,187 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase { } } + test("Test MergeInto with preCombine field expression") { + withTempDir { tmp => + Seq("cow", "mor").foreach { tableType => + val tableName1 = generateTableName + spark.sql( + s""" + | create table $tableName1 ( + | id int, + | name string, + | price double, + | v long, + | dt string + | ) using hudi + | tblproperties ( + | type = '$tableType', + | primaryKey = 'id', + | preCombineField = 'v' + | ) + | partitioned by(dt) + | location '${tmp.getCanonicalPath}/$tableName1' + """.stripMargin) + // Insert data + spark.sql(s"""insert into $tableName1 values(1, 'a1', 10, 1000, '2021-03-21')""") + + // Update data with a value expression on preCombine field + // 1) set source column name to be same as target column + spark.sql( + s""" + | merge into $tableName1 as t0 + | using ( + | select 1 as id, 'a1' as name, 11 as price, 999 as v, 
'2021-03-21' as dt + | ) as s0 + | on t0.id = s0.id + | when matched then update set id=s0.id, name=s0.name, price=s0.price*2, v=s0.v+2, dt=s0.dt + """.stripMargin + ) + // Update success as new value 1001 is bigger than original value 1000 + checkAnswer(s"select id,name,price,dt,v from $tableName1")( + Seq(1, "a1", 22, "2021-03-21", 1001) + ) + + // 2) set source column name to be different with target column + spark.sql( + s""" + | merge into $tableName1 as t0 + | using ( + | select 1 as s_id, 'a1' as s_name, 12 as s_price, 1000 as s_v, '2021-03-21' as dt + | ) as s0 + | on t0.id = s0.s_id + | when matched then update set id=s0.s_id, name=s0.s_name, price=s0.s_price*2, v=s0.s_v+2, dt=s0.dt + """.stripMargin + ) + // Update success as new value 1002 is bigger than original value 1001 + checkAnswer(s"select id,name,price,dt,v from $tableName1")( + Seq(1, "a1", 24, "2021-03-21", 1002) + ) + } + } + } + + test("Test MergeInto with primaryKey expression") { + withTempDir { tmp => + val tableName1 = generateTableName + spark.sql( + s""" + | create table $tableName1 ( + | id int, + | name string, + | price double, + | v long, + | dt string + | ) using hudi + | tblproperties ( + | type = 'cow', + | primaryKey = 'id', + | preCombineField = 'v' + | ) + | partitioned by(dt) + | location '${tmp.getCanonicalPath}/$tableName1' + """.stripMargin) + // Insert data + spark.sql(s"""insert into $tableName1 values(3, 'a3', 30, 3000, '2021-03-21')""") + spark.sql(s"""insert into $tableName1 values(2, 'a2', 20, 2000, '2021-03-21')""") + spark.sql(s"""insert into $tableName1 values(1, 'a1', 10, 1000, '2021-03-21')""") + + // Delete data with a condition expression on primaryKey field + // 1) set source column name to be same as target column + spark.sql( + s""" + | merge into $tableName1 t0 + | using ( + | select 1 as id, 'a1' as name, 15 as price, 1001 as v, '2021-03-21' as dt + | ) s0 + | on t0.id = s0.id + 1 + | when matched then delete + """.stripMargin + ) + checkAnswer(s"select 
id,name,price,v,dt from $tableName1 order by id")( + Seq(1, "a1", 10, 1000, "2021-03-21"), + Seq(3, "a3", 30, 3000, "2021-03-21") + ) + + // 2) set source column name to be different with target column + spark.sql( + s""" + | merge into $tableName1 t0 + | using ( + | select 2 as s_id, 'a1' as s_name, 15 as s_price, 1001 as s_v, '2021-03-21' as dt + | ) s0 + | on t0.id = s0.s_id + 1 + | when matched then delete + """.stripMargin + ) + checkAnswer(s"select id,name,price,v,dt from $tableName1 order by id")( + Seq(1, "a1", 10, 1000, "2021-03-21") + ) + } + } + + test("Test MergeInto with combination of delete update insert") { + withTempDir { tmp => + val sourceTable = generateTableName + val targetTable = generateTableName + // Create source table + spark.sql( + s""" + | create table $sourceTable ( + | id int, + | name string, + | price double, + | ts long, + | dt string + | ) using parquet + | location '${tmp.getCanonicalPath}/$sourceTable' + """.stripMargin) + spark.sql(s"insert into $sourceTable values(8, 's8', 80, 2008, '2021-03-21')") + spark.sql(s"insert into $sourceTable values(9, 's9', 90, 2009, '2021-03-21')") + spark.sql(s"insert into $sourceTable values(10, 's10', 100, 2010, '2021-03-21')") + spark.sql(s"insert into $sourceTable values(11, 's11', 110, 2011, '2021-03-21')") + spark.sql(s"insert into $sourceTable values(12, 's12', 120, 2012, '2021-03-21')") + // Create target table + spark.sql( + s""" + |create table $targetTable ( + | id int, + | name string, + | price double, + | ts long, + | dt string + |) using hudi + | tblproperties ( + | primaryKey ='id', + | preCombineField = 'ts' + | ) + | partitioned by(dt) + | location '${tmp.getCanonicalPath}/$targetTable' + """.stripMargin) + spark.sql(s"insert into $targetTable values(7, 'a7', 70, 1007, '2021-03-21')") + spark.sql(s"insert into $targetTable values(8, 'a8', 80, 1008, '2021-03-21')") + spark.sql(s"insert into $targetTable values(9, 'a9', 90, 1009, '2021-03-21')") + spark.sql(s"insert into 
$targetTable values(10, 'a10', 100, 1010, '2021-03-21')") + + spark.sql( + s""" + | merge into $targetTable as t0 + | using $sourceTable as s0 + | on t0.id = s0.id + | when matched and id = 10 then delete + | when matched and id < 10 then update set name='sxx', price=s0.price*2, ts=s0.ts+10000, dt=s0.dt + | when not matched and id > 10 then insert * + """.stripMargin) + checkAnswer(s"select id,name,price,ts,dt from $targetTable order by id")( + Seq(7, "a7", 70, 1007, "2021-03-21"), + Seq(8, "sxx", 160, 12008, "2021-03-21"), + Seq(9, "sxx", 180, 12009, "2021-03-21"), + Seq(11, "s11", 110, 2011, "2021-03-21"), + Seq(12, "s12", 120, 2012, "2021-03-21") + ) + } + } + test("Merge Hudi to Hudi") { withTempDir { tmp => Seq("cow", "mor").foreach { tableType =>