@@ -125,6 +125,7 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[LogicalPlan]
     case mergeInto @ MergeIntoTable(target, source, mergeCondition, matchedActions, notMatchedActions)
       if isHoodieTable(target, sparkSession) && target.resolved =>

+      val resolver = sparkSession.sessionState.conf.resolver
       val resolvedSource = analyzer.execute(source)
       def isInsertOrUpdateStar(assignments: Seq[Assignment]): Boolean = {
         if (assignments.isEmpty) {
@@ -161,23 +162,21 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[LogicalPlan]
         val resolvedCondition = condition.map(resolveExpressionFrom(resolvedSource)(_))
         val resolvedAssignments = if (isInsertOrUpdateStar(assignments)) {
           // assignments being empty means insert * or update set *
-          val resolvedSourceOutputWithoutMetaFields = resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
-          val targetOutputWithoutMetaFields = target.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
-          val resolvedSourceColumnNamesWithoutMetaFields = resolvedSourceOutputWithoutMetaFields.map(_.name)
-          val targetColumnNamesWithoutMetaFields = targetOutputWithoutMetaFields.map(_.name)
+          val resolvedSourceOutput = resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
+          val targetOutput = target.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
+          val resolvedSourceColumnNames = resolvedSourceOutput.map(_.name)

-          if (targetColumnNamesWithoutMetaFields.toSet.subsetOf(resolvedSourceColumnNamesWithoutMetaFields.toSet)) {
+          if (targetOutput.filter(attr => resolvedSourceColumnNames.exists(resolver(_, attr.name))).equals(targetOutput)) {
             // If the source table's columns contain all of the target table's columns,
             // we will assign the source fields to the target fields by column name matching.
-            val sourceColNameAttrMap = resolvedSourceOutputWithoutMetaFields.map(attr => (attr.name, attr)).toMap
-            targetOutputWithoutMetaFields.map(targetAttr => {
-              val sourceAttr = sourceColNameAttrMap(targetAttr.name)
+            targetOutput.map(targetAttr => {
+              val sourceAttr = resolvedSourceOutput.find(f => resolver(f.name, targetAttr.name)).get
               Assignment(targetAttr, sourceAttr)
             })
           } else {
             // Otherwise we will assign the source fields to the target fields by order.
-            targetOutputWithoutMetaFields
-              .zip(resolvedSourceOutputWithoutMetaFields)
+            targetOutput
+              .zip(resolvedSourceOutput)
               .map { case (targetAttr, sourceAttr) => Assignment(targetAttr, sourceAttr) }
           }
         } else {
@@ -214,8 +213,9 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[LogicalPlan]
           }.toMap

           // Validate if there are incorrect target attributes.
+          val targetColumnNames = removeMetaFields(target.output).map(_.name)
           val unKnowTargets = target2Values.keys
-            .filterNot(removeMetaFields(target.output).map(_.name).contains(_))
+            .filterNot(name => targetColumnNames.exists(resolver(_, name)))
           if (unKnowTargets.nonEmpty) {
             throw new AnalysisException(s"Cannot find target attributes: ${unKnowTargets.mkString(",")}.")
           }
@@ -224,19 +224,20 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[LogicalPlan]
           // e.g. if the update action is missing the 'id' attribute, we add an "id = target.id" assignment to the update action.
           val newAssignments = removeMetaFields(target.output)
             .map(attr => {
+              val valueOption = target2Values.find(f => resolver(f._1, attr.name))
               // TODO support partial update for MOR.
-              if (!target2Values.contains(attr.name) && targetTableType == MOR_TABLE_TYPE_OPT_VAL) {
+              if (valueOption.isEmpty && targetTableType == MOR_TABLE_TYPE_OPT_VAL) {
                 throw new AnalysisException(s"Missing the value for target field: '${attr.name}' in the merge-into update action" +
                   s" for a MOR table. Partial update is not yet supported for MOR;" +
                   s" please set all the target fields, e.g. '... update set id = s0.id, name = s0.name ...'")
               }
               if (preCombineField.isDefined && preCombineField.get.equalsIgnoreCase(attr.name)
-                && !target2Values.contains(attr.name)) {
+                && valueOption.isEmpty) {
                 throw new AnalysisException(s"Missing the value for the preCombineField:" +
                   s" ${preCombineField.get} in the merge-into update action. You should add" +
                   s" '... update set ${preCombineField.get} = xx ...' to the when-matched clause.")
               }
-              Assignment(attr, target2Values.getOrElse(attr.name, attr))
+              Assignment(attr, if (valueOption.isEmpty) attr else valueOption.get._2)
             })
           UpdateAction(resolvedCondition, newAssignments)
         case DeleteAction(condition) =>
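The rewrite above leans on Spark's Resolver, which is just a (String, String) => Boolean; sparkSession.sessionState.conf.resolver returns a case-sensitive or case-insensitive comparison depending on spark.sql.caseSensitive (false by default). A minimal sketch of the coverage check the rule now performs, with made-up column names for illustration:

import org.apache.spark.sql.catalyst.analysis.{Resolver, caseInsensitiveResolution}

// With the default spark.sql.caseSensitive=false, the session resolver is
// caseInsensitiveResolution, i.e. a case-insensitive name comparison.
val resolver: Resolver = caseInsensitiveResolution

// Hypothetical target and source schemas that differ only in case.
val targetCols = Seq("ID", "name", "PRICE")
val sourceCols = Seq("id", "NAME", "price", "flag")

// The coverage check: every target column needs a resolver-equal source
// column before "insert * / update set *" can be matched by name rather
// than by position.
val coveredByName = targetCols.forall(t => sourceCols.exists(resolver(_, t)))
// coveredByName == true, so ID -> id, name -> NAME, PRICE -> price.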
@@ -26,6 +26,7 @@ import org.apache.hudi.hive.ddl.HiveSyncMode
 import org.apache.hudi.{AvroConversionUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, HoodieWriterUtils, SparkAdapterSupport}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BoundReference, Cast, EqualTo, Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.command.RunnableCommand
@@ -90,6 +91,7 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends RunnableCommand
    * TODO Currently non-equivalent conditions are not supported.
    */
   private lazy val targetKey2SourceExpression: Map[String, Expression] = {
+    val resolver = sparkSession.sessionState.conf.resolver
     val conditions = splitByAnd(mergeInto.mergeCondition)
     val allEqs = conditions.forall(p => p.isInstanceOf[EqualTo])
     if (!allEqs) {
@@ -101,11 +103,11 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends RunnableCommand
     val target2Source = conditions.map(_.asInstanceOf[EqualTo])
       .map {
         case EqualTo(left: AttributeReference, right)
-          if targetAttrs.indexOf(left) >= 0 => // left is the target field
-          left.name -> right
+          if targetAttrs.exists(f => attributeEqual(f, left, resolver)) => // left is the target field
+          targetAttrs.find(f => resolver(f.name, left.name)).get.name -> right
         case EqualTo(left, right: AttributeReference)
-          if targetAttrs.indexOf(right) >= 0 => // right is the target field
-          right.name -> left
+          if targetAttrs.exists(f => attributeEqual(f, right, resolver)) => // right is the target field
+          targetAttrs.find(f => resolver(f.name, right.name)).get.name -> left
         case eq =>
           throw new AnalysisException(s"Invalid merge-on condition: ${eq.sql}." +
             " A valid condition should be 'targetColumn = sourceColumnExpression', e.g."
@@ -196,15 +198,24 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends RunnableCommand
   }

   private def isEqualToTarget(targetColumnName: String, sourceExpression: Expression): Boolean = {
-    val sourceColNameMap = sourceDFOutput.map(attr => (attr.name.toLowerCase, attr.name)).toMap
+    val sourceColumnName = sourceDFOutput.map(_.name)
+    val resolver = sparkSession.sessionState.conf.resolver

     sourceExpression match {
-      case attr: AttributeReference if sourceColNameMap(attr.name.toLowerCase).equals(targetColumnName) => true
-      case Cast(attr: AttributeReference, _, _) if sourceColNameMap(attr.name.toLowerCase).equals(targetColumnName) => true
+      case attr: AttributeReference if sourceColumnName.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
+      case Cast(attr: AttributeReference, _, _) if sourceColumnName.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
       case _ => false
     }
   }

+  /**
+   * Compare an [[Attribute]] to another; return true if they have the same column name (per the resolver) and exprId.
+   */
+  private def attributeEqual(
+      attr: Attribute, other: Attribute, resolver: Resolver): Boolean = {
+    resolver(attr.name, other.name) && attr.exprId == other.exprId
+  }
+
   /**
    * Execute the update and delete action. All the matched and not-matched actions will
    * execute in one upsert write operation. We pushed down the matched condition and assignment
@@ -361,9 +372,9 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends RunnableCommand
     mergeInto.targetTable.output
       .filterNot(attr => isMetaField(attr.name))
       .map(attr => {
-        val assignment = attr2Assignment.getOrElse(attr,
-          throw new IllegalArgumentException(s"Cannot find related assignment for field: ${attr.name}"))
-        castIfNeeded(assignment, attr.dataType, sparkSession.sqlContext.conf)
+        val assignment = attr2Assignment.find(f => attributeEqual(f._1, attr, sparkSession.sessionState.conf.resolver))
+          .getOrElse(throw new IllegalArgumentException(s"Cannot find related assignment for field: ${attr.name}"))
+        castIfNeeded(assignment._2, attr.dataType, sparkSession.sqlContext.conf)
       })
   }

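The new attributeEqual helper exists because Seq.indexOf relies on AttributeReference equality, which compares names case-sensitively even when the two references share an exprId. A small sketch of the difference, using an illustrative attribute:

import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// Two references to the same column that differ only in name case;
// withName preserves the exprId.
val target = AttributeReference("ID", IntegerType)()
val probe = target.withName("id")

// AttributeReference equality is case-sensitive on the name, so indexOf
// misses the match:
Seq(target).indexOf(probe) // -1

// The resolver-plus-exprId comparison used by attributeEqual still matches:
caseInsensitiveResolution(target.name, probe.name) && target.exprId == probe.exprId // true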
@@ -444,4 +444,115 @@ class TestMergeIntoTable2 extends TestHoodieSqlBase {
     }
   }

+  test("Test ignoring case") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create table
+      spark.sql(
+        s"""
+           |create table $tableName (
+           |  ID int,
+           |  name string,
+           |  price double,
+           |  TS long,
+           |  DT string
+           |) using hudi
+           | location '${tmp.getCanonicalPath}/$tableName'
+           | options (
+           |  primaryKey = 'ID',
+           |  preCombineField = 'TS'
+           | )
+       """.stripMargin)
+
+      // First merge with an extra input field 'flag' (insert a new record)
+      spark.sql(
+        s"""
+           | merge into $tableName
+           | using (
+           |  select 1 as id, 'a1' as name, 10 as PRICE, 1000 as ts, '2021-05-05' as dt, '1' as flag
+           | ) s0
+           | on s0.id = $tableName.id
+           | when matched and flag = '1' then update set
+           |  id = s0.id, name = s0.name, PRICE = s0.price, ts = s0.ts, dt = s0.dt
+           | when not matched and flag = '1' then insert *
+       """.stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 10.0, 1000, "2021-05-05")
+      )
+
+      // Second merge (update the record)
+      spark.sql(
+        s"""
+           | merge into $tableName
+           | using (
+           |  select 1 as id, 'a1' as name, 20 as PRICE, '2021-05-05' as dt, 1001 as ts
+           | ) s0
+           | on s0.id = $tableName.id
+           | when matched then update set
+           |  id = s0.id, name = s0.name, PRICE = s0.price, ts = s0.ts, dt = s0.dt
+           | when not matched then insert *
+       """.stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 20.0, 1001, "2021-05-05")
+      )
+
+      // Test ignoring case when column names match
+      spark.sql(
+        s"""
+           | merge into $tableName as t0
+           | using (
+           |  select 1 as id, 'a1' as name, 1111 as ts, '2021-05-05' as dt, 111 as PRICE union all
+           |  select 2 as id, 'a2' as name, 1112 as ts, '2021-05-05' as dt, 112 as PRICE
+           | ) as s0
+           | on t0.id = s0.id
+           | when matched then update set *
+           | when not matched then insert *
+           |""".stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 111.0, 1111, "2021-05-05"),
+        Seq(2, "a2", 112.0, 1112, "2021-05-05")
+      )
+    }
+  }
+
+  test("Test ignoring case for MOR table") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create a MOR partitioned table.
+      spark.sql(
+        s"""
+           | create table $tableName (
+           |  ID int,
+           |  NAME string,
+           |  price double,
+           |  TS long,
+           |  dt string
+           | ) using hudi
+           | options (
+           |  type = 'mor',
+           |  primaryKey = 'ID',
+           |  preCombineField = 'TS'
+           | )
+           | partitioned by(dt)
+           | location '${tmp.getCanonicalPath}/$tableName'
+       """.stripMargin)
+
+      // Test ignoring case when column names match
+      spark.sql(
+        s"""
+           | merge into $tableName as t0
+           | using (
+           |  select 1 as id, 'a1' as NAME, 1111 as ts, '2021-05-05' as DT, 111 as price
+           | ) as s0
+           | on t0.id = s0.id
+           | when matched then update set *
+           | when not matched then insert *
+       """.stripMargin
+      )
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 111.0, 1111, "2021-05-05")
+      )
+    }
+  }
+
 }
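Both tests rely on Spark's default spark.sql.caseSensitive=false, under which the session resolver ignores case. For contrast, a sketch of the flip side: with case-sensitive analysis enabled, the resolver compares names exactly, and the by-name matching exercised above would no longer pair PRICE with price.

import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}

// The resolver SQLConf.resolver returns is picked by spark.sql.caseSensitive:
// caseSensitiveResolution when true, caseInsensitiveResolution when false
// (the default).
caseSensitiveResolution("PRICE", "price")   // false
caseInsensitiveResolution("PRICE", "price") // true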