diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala index 947291d10373b..cd5b201f91eb9 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionCodeGen.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.avro.AvroSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, LeafExpression, UnsafeArrayData, UnsafeMapData, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast, Expression, GenericInternalRow, LeafExpression, UnsafeArrayData, UnsafeMapData, UnsafeRow} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.hudi.command.payload.ExpressionCodeGen.RECORD_NAME import org.apache.spark.sql.types.{DataType, Decimal} @@ -122,7 +122,8 @@ object ExpressionCodeGen extends Logging { classOf[IndexedRecord].getName, classOf[AvroSerializer].getName, classOf[GenericRecord].getName, - classOf[GenericInternalRow].getName + classOf[GenericInternalRow].getName, + classOf[Cast].getName ) evaluator.setImplementedInterfaces(Array(classOf[IExpressionEvaluator])) try { diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala index b77b5c3dbdbf8..8e6acd1be58c6 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala @@ -673,4 +673,44 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { ) } } + + test ("Test Merge into with String cast to Double") { + withTempDir { tmp => + val tableName = generateTableName + // Create a cow partitioned table. + spark.sql( + s""" + | create table $tableName ( + | id int, + | name string, + | price double, + | ts long, + | dt string + | ) using hudi + | tblproperties ( + | type = 'cow', + | primaryKey = 'id', + | preCombineField = 'ts' + | ) + | partitioned by(dt) + | location '${tmp.getCanonicalPath}' + """.stripMargin) + // Insert data + spark.sql(s"insert into $tableName select 1, 'a1', cast(10.0 as double), 999, '2021-03-21'") + spark.sql( + s""" + | merge into $tableName as t0 + | using ( + | select 'a1' as name, 1 as id, '10.1' as price, 1000 as ts, '2021-03-21' as dt + | ) as s0 + | on t0.id = s0.id + | when matched then update set t0.price = s0.price, t0.ts = s0.ts + | when not matched then insert * + """.stripMargin + ) + checkAnswer(s"select id,name,price,dt from $tableName")( + Seq(1, "a1", 10.1, "2021-03-21") + ) + } + } }