@@ -181,38 +181,66 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie
var sourceDF = Dataset.ofRows(sparkSession, mergeInto.sourceTable)
targetKey2SourceExpression.foreach {
case (targetColumn, sourceExpression)
if !isEqualToTarget(targetColumn, sourceExpression) =>
if !containsPrimaryKeyFieldReference(targetColumn, sourceExpression) =>
sourceDF = sourceDF.withColumn(targetColumn, new Column(sourceExpression))
sourceDFOutput = sourceDFOutput :+ AttributeReference(targetColumn, sourceExpression.dataType)()
case _=>
}
target2SourcePreCombineFiled.foreach {
case (targetPreCombineField, sourceExpression)
if !isEqualToTarget(targetPreCombineField, sourceExpression) =>
if !containsPreCombineFieldReference(targetPreCombineField, sourceExpression) =>
sourceDF = sourceDF.withColumn(targetPreCombineField, new Column(sourceExpression))
sourceDFOutput = sourceDFOutput :+ AttributeReference(targetPreCombineField, sourceExpression.dataType)()
case _=>
}
sourceDF
}

private def isEqualToTarget(targetColumnName: String, sourceExpression: Expression): Boolean = {
val sourceColumnName = sourceDFOutput.map(_.name)
/**
* Check whether the source expression has the same column name as the target column.
*
* Merge condition cases that return true:
* 1) merge into .. on h0.id = s0.id ..
* 2) merge into .. on h0.id = cast(s0.id as int) ..
* "id" is primaryKey field of h0.
*/
private def containsPrimaryKeyFieldReference(targetColumnName: String, sourceExpression: Expression): Boolean = {
val sourceColumnNames = sourceDFOutput.map(_.name)
val resolver = sparkSession.sessionState.conf.resolver

sourceExpression match {
case attr: AttributeReference if sourceColumnName.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
case attr: AttributeReference if sourceColumnNames.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
// SPARK-35857: the definition of Cast has been changed in Spark 3.2.
// Match the class type instead of calling the `unapply` method.
case cast: Cast =>
cast.child match {
case attr: AttributeReference if sourceColumnName.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
case attr: AttributeReference if sourceColumnNames.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
case _ => false
}
case _=> false
}
}

/**
* Check whether the source expression on the preCombine field contains a reference to the same column name as the target column.
*
* Merge expression cases that return true:
* 1) merge into .. on .. update set ts = s0.ts
* 2) merge into .. on .. update set ts = cast(s0.ts as int)
* 3) merge into .. on .. update set ts = s0.ts+1 (any expression whose sub-nodes reference the same column name as the target)
* "ts" is the preCombine field of h0.
*/
private def containsPreCombineFieldReference(targetColumnName: String, sourceExpression: Expression): Boolean = {
val sourceColumnNames = sourceDFOutput.map(_.name)
val resolver = sparkSession.sessionState.conf.resolver

// A sub-node of the expression may reference the same column name as the target column
sourceExpression.find {
case attr: AttributeReference => sourceColumnNames.find(resolver(_, attr.name)).get.equals(targetColumnName)
case _ => false
}.isDefined
}

/**
* Compare an [[Attribute]] to another; return true if they have the same column name (by resolver) and exprId
*/
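For reference, here is a minimal standalone sketch (not part of this change) of the column-reference check the new helpers rely on: Catalyst's Expression.find walks the whole expression tree, so an attribute wrapped in a cast (or nested deeper, as in s0.ts+1) is still detected. The ColumnReferenceSketch object and its case-insensitive comparison are illustrative assumptions; the actual helpers use the session's resolver and sourceDFOutput.

// Illustrative sketch only -- not the implementation added in this PR.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression}
import org.apache.spark.sql.types.{IntegerType, LongType}

object ColumnReferenceSketch {
  // Does any node in the expression tree reference the given column name?
  // (Simplified: plain case-insensitive match instead of the session resolver.)
  def referencesColumn(expr: Expression, columnName: String): Boolean =
    expr.find {
      case attr: AttributeReference => attr.name.equalsIgnoreCase(columnName)
      case _ => false
    }.isDefined

  def main(args: Array[String]): Unit = {
    val id = AttributeReference("id", LongType)()
    // "on h0.id = cast(s0.id as int)" style: the reference sits under a Cast node.
    println(referencesColumn(Cast(id, IntegerType), "id"))   // true
    println(referencesColumn(Cast(id, IntegerType), "name")) // false
  }
}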
@@ -427,6 +427,187 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase
}
}

test("Test MergeInto with preCombine field expression") {
withTempDir { tmp =>
Seq("cow", "mor").foreach { tableType =>
val tableName1 = generateTableName
spark.sql(
s"""
| create table $tableName1 (
| id int,
| name string,
| price double,
| v long,
| dt string
| ) using hudi
| tblproperties (
| type = '$tableType',
| primaryKey = 'id',
| preCombineField = 'v'
| )
| partitioned by(dt)
| location '${tmp.getCanonicalPath}/$tableName1'
""".stripMargin)
// Insert data
spark.sql(s"""insert into $tableName1 values(1, 'a1', 10, 1000, '2021-03-21')""")

// Update data with a value expression on the preCombine field
// 1) source column names are the same as the target column names
spark.sql(
s"""
| merge into $tableName1 as t0
| using (
| select 1 as id, 'a1' as name, 11 as price, 999 as v, '2021-03-21' as dt
| ) as s0
| on t0.id = s0.id
| when matched then update set id=s0.id, name=s0.name, price=s0.price*2, v=s0.v+2, dt=s0.dt
""".stripMargin
)
// The update succeeds as the new value 1001 is greater than the original value 1000
checkAnswer(s"select id,name,price,dt,v from $tableName1")(
Seq(1, "a1", 22, "2021-03-21", 1001)
)

// 2) source column names differ from the target column names
spark.sql(
s"""
| merge into $tableName1 as t0
| using (
| select 1 as s_id, 'a1' as s_name, 12 as s_price, 1000 as s_v, '2021-03-21' as dt
| ) as s0
| on t0.id = s0.s_id
| when matched then update set id=s0.s_id, name=s0.s_name, price=s0.s_price*2, v=s0.s_v+2, dt=s0.dt
""".stripMargin
)
// The update succeeds as the new value 1002 is greater than the original value 1001
checkAnswer(s"select id,name,price,dt,v from $tableName1")(
Seq(1, "a1", 24, "2021-03-21", 1002)
)
}
}
}

test("Test MergeInto with primaryKey expression") {
withTempDir { tmp =>
val tableName1 = generateTableName
spark.sql(
s"""
| create table $tableName1 (
| id int,
| name string,
| price double,
| v long,
| dt string
| ) using hudi
| tblproperties (
| type = 'cow',
| primaryKey = 'id',
| preCombineField = 'v'
| )
| partitioned by(dt)
| location '${tmp.getCanonicalPath}/$tableName1'
""".stripMargin)
// Insert data
spark.sql(s"""insert into $tableName1 values(3, 'a3', 30, 3000, '2021-03-21')""")
spark.sql(s"""insert into $tableName1 values(2, 'a2', 20, 2000, '2021-03-21')""")
spark.sql(s"""insert into $tableName1 values(1, 'a1', 10, 1000, '2021-03-21')""")

// Delete data with a condition expression on the primaryKey field
// 1) source column names are the same as the target column names
spark.sql(
s"""
| merge into $tableName1 t0
| using (
| select 1 as id, 'a1' as name, 15 as price, 1001 as v, '2021-03-21' as dt
| ) s0
| on t0.id = s0.id + 1
| when matched then delete
""".stripMargin
)
checkAnswer(s"select id,name,price,v,dt from $tableName1 order by id")(
Seq(1, "a1", 10, 1000, "2021-03-21"),
Seq(3, "a3", 30, 3000, "2021-03-21")
)

// 2) source column names differ from the target column names
spark.sql(
s"""
| merge into $tableName1 t0
| using (
| select 2 as s_id, 'a1' as s_name, 15 as s_price, 1001 as s_v, '2021-03-21' as dt
| ) s0
| on t0.id = s0.s_id + 1
| when matched then delete
""".stripMargin
)
checkAnswer(s"select id,name,price,v,dt from $tableName1 order by id")(
Seq(1, "a1", 10, 1000, "2021-03-21")
)
}
}

test("Test MergeInto with combination of delete update insert") {
withTempDir { tmp =>
val sourceTable = generateTableName
val targetTable = generateTableName
// Create source table
spark.sql(
s"""
| create table $sourceTable (
| id int,
| name string,
| price double,
| ts long,
| dt string
| ) using parquet
| location '${tmp.getCanonicalPath}/$sourceTable'
""".stripMargin)
spark.sql(s"insert into $sourceTable values(8, 's8', 80, 2008, '2021-03-21')")
spark.sql(s"insert into $sourceTable values(9, 's9', 90, 2009, '2021-03-21')")
spark.sql(s"insert into $sourceTable values(10, 's10', 100, 2010, '2021-03-21')")
spark.sql(s"insert into $sourceTable values(11, 's11', 110, 2011, '2021-03-21')")
spark.sql(s"insert into $sourceTable values(12, 's12', 120, 2012, '2021-03-21')")
// Create target table
spark.sql(
s"""
|create table $targetTable (
| id int,
| name string,
| price double,
| ts long,
| dt string
|) using hudi
| tblproperties (
| primaryKey ='id',
| preCombineField = 'ts'
| )
| partitioned by(dt)
| location '${tmp.getCanonicalPath}/$targetTable'
""".stripMargin)
spark.sql(s"insert into $targetTable values(7, 'a7', 70, 1007, '2021-03-21')")
spark.sql(s"insert into $targetTable values(8, 'a8', 80, 1008, '2021-03-21')")
spark.sql(s"insert into $targetTable values(9, 'a9', 90, 1009, '2021-03-21')")
spark.sql(s"insert into $targetTable values(10, 'a10', 100, 1010, '2021-03-21')")

spark.sql(
s"""
| merge into $targetTable as t0
| using $sourceTable as s0
| on t0.id = s0.id
| when matched and id = 10 then delete
| when matched and id < 10 then update set name='sxx', price=s0.price*2, ts=s0.ts+10000, dt=s0.dt
| when not matched and id > 10 then insert *
""".stripMargin)
checkAnswer(s"select id,name,price,ts,dt from $targetTable order by id")(
Seq(7, "a7", 70, 1007, "2021-03-21"),
Seq(8, "sxx", 160, 12008, "2021-03-21"),
Seq(9, "sxx", 180, 12009, "2021-03-21"),
Seq(11, "s11", 110, 2011, "2021-03-21"),
Seq(12, "s12", 120, 2012, "2021-03-21")
)
}
}

test("Merge Hudi to Hudi") {
withTempDir { tmp =>
Seq("cow", "mor").foreach { tableType =>