diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala
index a9ae7c6d2dc45..8ea9e71a07ddd 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala
@@ -21,7 +21,9 @@ import java.util.{Base64, Properties}
 import java.util.concurrent.Callable
 import scala.collection.JavaConverters._
 import com.google.common.cache.CacheBuilder
-import org.apache.avro.Schema
+import org.apache.avro.Conversions.DecimalConversion
+import org.apache.avro.Schema.Type
+import org.apache.avro.{LogicalTypes, Schema}
 import org.apache.avro.generic.{GenericData, GenericRecord, IndexedRecord}
 import org.apache.avro.util.Utf8
 import org.apache.hudi.DataSourceWriteOptions._
@@ -33,9 +35,9 @@ import org.apache.hudi.config.HoodieWriteConfig
 import org.apache.hudi.io.HoodieWriteHandle
 import org.apache.hudi.sql.IExpressionEvaluator
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.Assignment
 import org.apache.spark.sql.hudi.SerDeUtils
 import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.getEvaluator
+import org.apache.spark.sql.types.Decimal
 import org.apache.spark.unsafe.types.UTF8String

 import scala.collection.mutable.ArrayBuffer
@@ -84,7 +86,8 @@ class ExpressionPayload(record: GenericRecord,
     var resultRecordOpt: HOption[IndexedRecord] = null

     // Get the Evaluator for each condition and update assignments.
-    val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString)
+    initWriteSchemaIfNeed(properties)
+    val updateConditionAndAssignments = getEvaluator(updateConditionAndAssignmentsText.toString, writeSchema)
     for ((conditionEvaluator, assignmentEvaluator) <- updateConditionAndAssignments
          if resultRecordOpt == null) {
       val conditionVal = evaluate(conditionEvaluator, joinSqlRecord).head.asInstanceOf[Boolean]
@@ -92,7 +95,6 @@ class ExpressionPayload(record: GenericRecord,
       // to compute final record to update. We will return the first matched record.
       if (conditionVal) {
         val results = evaluate(assignmentEvaluator, joinSqlRecord)
-        initWriteSchemaIfNeed(properties)
         val resultRecord = convertToRecord(results, writeSchema)

         if (needUpdatingPersistedRecord(targetRecord, resultRecord, properties)) {
@@ -108,7 +110,7 @@ class ExpressionPayload(record: GenericRecord,
     // Process delete
     val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
     if (deleteConditionText != null) {
-      val deleteCondition = getEvaluator(deleteConditionText.toString).head._1
+      val deleteCondition = getEvaluator(deleteConditionText.toString, writeSchema).head._1
       val deleteConditionVal = evaluate(deleteCondition, joinSqlRecord).head.asInstanceOf[Boolean]
       if (deleteConditionVal) {
         resultRecordOpt = HOption.empty()
@@ -134,8 +136,9 @@ class ExpressionPayload(record: GenericRecord,
     // Process insert
     val sqlTypedRecord = new SqlTypedRecord(incomingRecord)
     // Get the evaluator for each condition and insert assignment.
+    initWriteSchemaIfNeed(properties)
     val insertConditionAndAssignments =
-      ExpressionPayload.getEvaluator(insertConditionAndAssignmentsText.toString)
+      ExpressionPayload.getEvaluator(insertConditionAndAssignmentsText.toString, writeSchema)
     var resultRecordOpt: HOption[IndexedRecord] = null
     for ((conditionEvaluator, assignmentEvaluator) <- insertConditionAndAssignments
          if resultRecordOpt == null) {
@@ -144,7 +147,6 @@ class ExpressionPayload(record: GenericRecord,
       // result record. We will return the first matched record.
       if (conditionVal) {
         val results = evaluate(assignmentEvaluator, sqlTypedRecord)
-        initWriteSchemaIfNeed(properties)
         resultRecordOpt = HOption.of(convertToRecord(results, writeSchema))
       }
     }
@@ -153,7 +155,7 @@ class ExpressionPayload(record: GenericRecord,
     if (resultRecordOpt == null && isMORTable(properties)) {
       val deleteConditionText = properties.get(ExpressionPayload.PAYLOAD_DELETE_CONDITION)
       if (deleteConditionText != null) {
-        val deleteCondition = getEvaluator(deleteConditionText.toString).head._1
+        val deleteCondition = getEvaluator(deleteConditionText.toString, writeSchema).head._1
         val deleteConditionVal = evaluate(deleteCondition, sqlTypedRecord).head.asInstanceOf[Boolean]
         if (deleteConditionVal) {
           resultRecordOpt = HOption.empty()
@@ -269,19 +271,19 @@ object ExpressionPayload {
    * @return
    */
  def getEvaluator(
-    serializedConditionAssignments: String): Map[IExpressionEvaluator, IExpressionEvaluator] = {
+    serializedConditionAssignments: String, writeSchema: Schema): Map[IExpressionEvaluator, IExpressionEvaluator] = {
    cache.get(serializedConditionAssignments,
      new Callable[Map[IExpressionEvaluator, IExpressionEvaluator]] {

        override def call(): Map[IExpressionEvaluator, IExpressionEvaluator] = {
          val serializedBytes = Base64.getDecoder.decode(serializedConditionAssignments)
          val conditionAssignments = SerDeUtils.toObject(serializedBytes)
-            .asInstanceOf[Map[Expression, Seq[Assignment]]]
+            .asInstanceOf[Map[Expression, Seq[Expression]]]
          // Do the CodeGen for condition expression and assignment expression
          conditionAssignments.map {
            case (condition, assignments) =>
              val conditionEvaluator = ExpressionCodeGen.doCodeGen(Seq(condition))
-              val assignmentEvaluator = StringConvertEvaluator(ExpressionCodeGen.doCodeGen(assignments))
+              val assignmentEvaluator = AvroTypeConvertEvaluator(ExpressionCodeGen.doCodeGen(assignments), writeSchema)
              conditionEvaluator -> assignmentEvaluator
          }
        }
@@ -289,17 +291,29 @@ object ExpressionPayload {
  }

  /**
-   * As the "baseEvaluator" return "UTF8String" for the string type which cannot be process by
-   * the Avro, The StringConvertEvaluator will convert the "UTF8String" to "Utf8".
+   * An IExpressionEvaluator that wraps the base evaluator and converts its results
+   * to Avro-typed values.
   */
-  case class StringConvertEvaluator(baseEvaluator: IExpressionEvaluator) extends IExpressionEvaluator {
+  case class AvroTypeConvertEvaluator(baseEvaluator: IExpressionEvaluator, writeSchema: Schema) extends IExpressionEvaluator {
+    private lazy val decimalConversions = new DecimalConversion()
+
    /**
-     * Convert the UTF8String to Utf8
+     * Convert to the Avro-typed value,
+     * e.g. UTF8String -> Utf8, Decimal -> GenericFixed.
     */
    override def eval(record: IndexedRecord): Array[AnyRef] = {
-      baseEvaluator.eval(record).map {
-        case s: UTF8String => new Utf8(s.toString)
-        case o => o
+      baseEvaluator.eval(record).zipWithIndex.map {
+        case (s: UTF8String, _) => new Utf8(s.toString)
+        case (d: Decimal, i) =>
+          val schema = writeSchema.getFields.get(i).schema()
+          val fixedSchema = if (schema.getType == Type.UNION) {
+            schema.getTypes.asScala.filter(s => s.getType != Type.NULL).head
+          } else {
+            schema
+          }
+          decimalConversions.toFixed(d.toJavaBigDecimal, fixedSchema,
+            LogicalTypes.decimal(d.precision, d.scale))
+        case (o, _) => o
      }
    }
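Reviewer note, not part of the patch: the `Decimal` branch above delegates to Avro's `DecimalConversion`. Here is a minimal, self-contained sketch of what `toFixed` produces for a fixed-backed decimal schema; the schema name, namespace, and size are illustrative only:

```scala
import java.math.BigDecimal

import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.generic.GenericFixed
import org.apache.avro.{LogicalTypes, Schema}

object DecimalToFixedSketch extends App {
  // decimal(5, 2): any 5-digit unscaled value fits in 3 bytes (2^23 > 99999).
  val decimalType = LogicalTypes.decimal(5, 2)
  val fixedSchema: Schema = decimalType.addToSchema(
    Schema.createFixed("value", null, "illustrative.ns", 3))

  // toFixed encodes the unscaled value (1011 for 10.11) as a big-endian,
  // sign-extended byte array of exactly the fixed size.
  val fixed: GenericFixed =
    new DecimalConversion().toFixed(new BigDecimal("10.11"), fixedSchema, decimalType)

  println(fixed.bytes().map(b => f"$b%02x").mkString(" ")) // prints: 00 03 f3
}
```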
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala
index d906f0c502cd8..6895ca8117840 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/SqlTypedRecord.scala
@@ -27,7 +27,6 @@ import org.apache.avro.generic.{GenericFixed, IndexedRecord}
 import org.apache.avro.util.Utf8
 import org.apache.avro.{LogicalTypes, Schema}
 import org.apache.spark.sql.avro.{IncompatibleSchemaException, SchemaConverters}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestHoodieSqlBase.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestHoodieSqlBase.scala
index c8291db83af84..67b52754cb833 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestHoodieSqlBase.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestHoodieSqlBase.scala
@@ -86,4 +86,11 @@ class TestHoodieSqlBase extends FunSuite with BeforeAndAfterAll {
       assertResult(errorMsg)(e.getMessage)
     }
   }
+
+  protected def removeQuotes(value: Any): Any = {
+    value match {
+      case s: String => s.stripPrefix("'").stripSuffix("'")
+      case _ => value
+    }
+  }
 }
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
index 945ccf5382915..8ae82d68638c7 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
@@ -277,11 +277,4 @@ class TestInsertTable extends TestHoodieSqlBase {
       " count: 3,columns: (1,a1,10)"
     )
   }
-
-  private def removeQuotes(value: Any): Any = {
-    value match {
-      case s: String => s.stripPrefix("'").stripSuffix("'")
-      case _=> value
-    }
-  }
 }
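The `removeQuotes` helper moves from `TestInsertTable` up into the shared `TestHoodieSqlBase` so the new MERGE INTO test below can reuse it. A quick standalone sketch (not part of the patch) of its behavior on the kinds of literals that test feeds it:

```scala
object RemoveQuotesSketch extends App {
  // Same logic as the shared helper: strip the single quotes that mark
  // SQL string literals; leave everything else untouched.
  def removeQuotes(value: Any): Any = value match {
    case s: String => s.stripPrefix("'").stripSuffix("'")
    case _ => value
  }

  assert(removeQuotes("'2021-05-20'") == "2021-05-20") // quoted date literal
  assert(removeQuotes("10.11") == "10.11")             // unquoted numeric literal
  assert(removeQuotes(10) == 10)                       // non-string passes through
}
```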
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala
index 1c20a82fd0cce..403d93e1109fc 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.hudi

 import org.apache.hudi.{DataSourceReadOptions, HoodieDataSourceHelpers}
 import org.apache.hudi.common.fs.FSUtils
-
 class TestMergeIntoTable extends TestHoodieSqlBase {

   test("Test MergeInto Basic") {
@@ -687,4 +686,50 @@ class TestMergeIntoTable extends TestHoodieSqlBase {
       }
     }
   }
+
+  test("Test MergeInto With All Kinds Of DataType") {
+    withTempDir { tmp =>
+      val dataAndTypes = Seq(
+        ("string", "'a1'"),
+        ("int", "10"),
+        ("bigint", "10"),
+        ("double", "10.0"),
+        ("float", "10.0"),
+        ("decimal(5,2)", "10.11"),
+        ("decimal(5,0)", "10"),
+        ("timestamp", "'2021-05-20 00:00:00'"),
+        ("date", "'2021-05-20'")
+      )
+      dataAndTypes.foreach { case (dataType, dataValue) =>
+        val tableName = generateTableName
+        spark.sql(
+          s"""
+             |create table $tableName (
+             | id int,
+             | name string,
+             | value $dataType,
+             | ts long
+             |) using hudi
+             | location '${tmp.getCanonicalPath}/$tableName'
+             | options (
+             |  primaryKey = 'id',
+             |  preCombineField = 'ts'
+             | )
+       """.stripMargin)
+
+        spark.sql(
+          s"""
+             |merge into $tableName h0
+             |using (
+             | select 1 as id, 'a1' as name, cast($dataValue as $dataType) as value, 1000 as ts
+             | ) s0
+             | on h0.id = s0.id
+             | when not matched then insert *
+             |""".stripMargin)
+        checkAnswer(s"select id, name, cast(value as string), ts from $tableName")(
+          Seq(1, "a1", removeQuotes(dataValue), 1000)
+        )
+      }
+    }
+  }
 }
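A closing note on the union handling in `AvroTypeConvertEvaluator`: Spark's Avro schema converter represents a nullable column as a `["null", type]` union, so the evaluator has to unwrap the union before calling `toFixed`. A small sketch of that unwrap, using an illustrative field schema:

```scala
import scala.collection.JavaConverters._

import org.apache.avro.Schema.Type
import org.apache.avro.{LogicalTypes, Schema}

object UnwrapUnionSketch extends App {
  // An illustrative nullable decimal(5, 2) field: union { null, fixed(3) }.
  val fixed = LogicalTypes.decimal(5, 2)
    .addToSchema(Schema.createFixed("value", null, "illustrative.ns", 3))
  val nullableDecimal = Schema.createUnion(List(Schema.create(Type.NULL), fixed).asJava)

  // Same unwrap as AvroTypeConvertEvaluator.eval: take the first non-null branch.
  val fixedSchema =
    if (nullableDecimal.getType == Type.UNION) {
      nullableDecimal.getTypes.asScala.filter(s => s.getType != Type.NULL).head
    } else {
      nullableDecimal
    }

  // fixedSchema is the fixed(3) decimal schema DecimalConversion.toFixed expects.
  println(fixedSchema.getFullName) // prints: illustrative.ns.value
}
```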