diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieInternalRowUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieInternalRowUtils.scala
index 4d2ee33d154e9..3ea801177fbd9 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieInternalRowUtils.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieInternalRowUtils.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.HoodieUnsafeRowUtils.{NestedFieldPath, composeNested
 import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.types.Decimal.ROUND_HALF_EVEN
+import org.apache.spark.sql.types.Decimal.ROUND_HALF_UP
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -286,7 +286,7 @@ object HoodieInternalRowUtils {
         (fieldUpdater, ordinal, value) =>
           val scale = newDecimal.scale
           // TODO this has to be revisited to avoid loss of precision (for fps)
-          fieldUpdater.setDecimal(ordinal, Decimal.fromDecimal(BigDecimal(value.toString).setScale(scale, ROUND_HALF_EVEN)))
+          fieldUpdater.setDecimal(ordinal, Decimal.fromDecimal(BigDecimal(value.toString).setScale(scale, ROUND_HALF_UP)))
 
       case _: DecimalType =>
         (fieldUpdater, ordinal, value) =>
diff --git a/hudi-common/src/main/java/org/apache/hudi/avro/HoodieAvroUtils.java b/hudi-common/src/main/java/org/apache/hudi/avro/HoodieAvroUtils.java
index b66d22b046a04..98cd4e290b595 100644
--- a/hudi-common/src/main/java/org/apache/hudi/avro/HoodieAvroUtils.java
+++ b/hudi-common/src/main/java/org/apache/hudi/avro/HoodieAvroUtils.java
@@ -62,6 +62,7 @@
 import java.io.IOException;
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.math.RoundingMode;
 import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.sql.Date;
@@ -1017,15 +1018,8 @@ private static Object rewritePrimaryTypeWithDiffSchemaType(Object oldValue, Sche
           || oldSchema.getType() == Schema.Type.LONG
           || oldSchema.getType() == Schema.Type.FLOAT) {
         LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) newSchema.getLogicalType();
-        BigDecimal bigDecimal = null;
-        if (oldSchema.getType() == Schema.Type.STRING) {
-          bigDecimal = new java.math.BigDecimal(oldValue.toString())
-              .setScale(decimal.getScale());
-        } else {
-          // Due to Java, there will be precision problems in direct conversion, we should use string instead of use double
-          bigDecimal = new java.math.BigDecimal(oldValue.toString())
-              .setScale(decimal.getScale());
-        }
+        // Converting a float/double directly to BigDecimal can introduce precision errors, so convert via the string representation instead
+        BigDecimal bigDecimal = new java.math.BigDecimal(oldValue.toString()).setScale(decimal.getScale(), RoundingMode.HALF_UP);
         return DECIMAL_CONVERSION.toFixed(bigDecimal, newSchema, newSchema.getLogicalType());
       }
     }
@@ -1077,7 +1071,7 @@ private static Schema getActualSchemaFromUnion(Schema schema, Object data) {
     } else if (schema.getTypes().size() == 1) {
       actualSchema = schema.getTypes().get(0);
    } else {
-      // deal complex union. this should not happened in hoodie,
+      // deal with complex union. this should not happen in hoodie,
       // since flink/spark do not write this type.
       int i = GenericData.get().resolveUnion(schema, data);
       actualSchema = schema.getTypes().get(i);
@@ -1110,10 +1104,10 @@ public static HoodieRecord createHoodieRecordFromAvro(
   /**
    * Given avro records, rewrites them with new schema.
    *
-   * @param oldRecords oldRecords to be rewrite
-   * @param newSchema newSchema used to rewrite oldRecord
+   * @param oldRecords oldRecords to be rewritten
+   * @param newSchema newSchema used to rewrite oldRecord
    * @param renameCols a map store all rename cols, (k, v)-> (colNameFromNewSchema, colNameFromOldSchema)
-   * @return an iterator of rewrote {@link GenericRecord}
+   * @return an iterator of rewritten GenericRecords
    */
  public static Iterator<GenericRecord> rewriteRecordWithNewSchema(Iterator<GenericRecord> oldRecords, Schema newSchema, Map<String, String> renameCols, boolean validate) {
    if (oldRecords == null || newSchema == null) {
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestSpark3DDL.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestSpark3DDL.scala
index b41f7dc67da71..b1f9347d7e8d9 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestSpark3DDL.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestSpark3DDL.scala
@@ -25,6 +25,7 @@ import org.apache.hudi.common.model.HoodieRecord
 import org.apache.hudi.common.model.HoodieRecord.HoodieRecordType
 import org.apache.hudi.common.testutils.{HoodieTestDataGenerator, RawTripTestPayload}
 import org.apache.hudi.config.HoodieWriteConfig
+import org.apache.hudi.testutils.DataSourceTestUtils
 import org.apache.hudi.{DataSourceWriteOptions, HoodieSparkRecordMerger, HoodieSparkUtils}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.functions.{arrays_zip, col, expr, lit}
@@ -812,4 +813,171 @@ class TestSpark3DDL extends HoodieSparkSqlTestBase {
       }
     }
   }
+
+  test("Test FLOAT to DECIMAL schema evolution (lost in scale)") {
+    Seq("cow", "mor").foreach { tableType =>
+      withTempDir { tmp =>
+        // Using INMEMORY index for mor table so that log files will be created instead of parquet
+        val tableName = generateTableName
+        if (HoodieSparkUtils.gteqSpark3_1) {
+          spark.sql(
+            s"""
+               |create table $tableName (
+               |  id int,
+               |  name string,
+               |  price float,
+               |  ts long
+               |) using hudi
+               | location '${tmp.getCanonicalPath}'
+               | tblproperties (
+               |  primaryKey = 'id',
+               |  type = '$tableType',
+               |  preCombineField = 'ts'
+               |  ${if (tableType.equals("mor")) ", hoodie.index.type = 'INMEMORY'" else ""}
+               | )
+             """.stripMargin)
+
+          spark.sql(s"insert into $tableName values (1, 'a1', 10.024, 1000)")
+
+          assertResult(tableType.equals("mor"))(DataSourceTestUtils.isLogFileOnly(tmp.getCanonicalPath))
+
+          spark.sql("set hoodie.schema.on.read.enable=true")
+          spark.sql(s"alter table $tableName alter column price type decimal(4, 2)")
+
+          // Not checking answer as this is an unsafe casting operation, just need to make sure that no error is thrown
+          spark.sql(s"select id, name, cast(price as string), ts from $tableName")
+        }
+      }
+    }
+  }
+
+  test("Test DOUBLE to DECIMAL schema evolution (lost in scale)") {
+    Seq("cow", "mor").foreach { tableType =>
+      withTempDir { tmp =>
+        // Using INMEMORY index for mor table so that log files will be created instead of parquet
+        val tableName = generateTableName
+        if (HoodieSparkUtils.gteqSpark3_1) {
+          spark.sql(
+            s"""
+               |create table $tableName (
+               |  id int,
+               |  name string,
+               |  price double,
+               |  ts long
+               |) using hudi
+               | location '${tmp.getCanonicalPath}'
+               | tblproperties (
+               |  primaryKey = 'id',
+               |  type = '$tableType',
+               |  preCombineField = 'ts'
+               |  ${if (tableType.equals("mor")) ", hoodie.index.type = 'INMEMORY'" else ""}
+               | )
+             """.stripMargin)
+
+          spark.sql(s"insert into $tableName values " +
+            // testing the rounding behaviour to ensure that HALF_UP is used for positive values
+            "(1, 'a1', 10.024, 1000)," +
+            "(2, 'a2', 10.025, 1000)," +
+            "(3, 'a3', 10.026, 1000)," +
+            // testing the rounding behaviour to ensure that HALF_UP is used for negative values
+            "(4, 'a4', -10.024, 1000)," +
+            "(5, 'a5', -10.025, 1000)," +
+            "(6, 'a6', -10.026, 1000)," +
+            // testing the GENERAL rounding behaviour (HALF_UP and HALF_EVEN will retain the same result)
+            "(7, 'a7', 10.034, 1000)," +
+            "(8, 'a8', 10.035, 1000)," +
+            "(9, 'a9', 10.036, 1000)," +
+            // testing the GENERAL rounding behaviour (HALF_UP and HALF_EVEN will retain the same result)
+            "(10, 'a10', -10.034, 1000)," +
+            "(11, 'a11', -10.035, 1000)," +
+            "(12, 'a12', -10.036, 1000)")
+
+          assertResult(tableType.equals("mor"))(DataSourceTestUtils.isLogFileOnly(tmp.getCanonicalPath))
+
+          spark.sql("set hoodie.schema.on.read.enable=true")
+          spark.sql(s"alter table $tableName alter column price type decimal(4, 2)")
+
+          checkAnswer(s"select id, name, cast(price as string), ts from $tableName order by id")(
+            Seq(1, "a1", "10.02", 1000),
+            Seq(2, "a2", "10.03", 1000),
+            Seq(3, "a3", "10.03", 1000),
+            Seq(4, "a4", "-10.02", 1000),
+            Seq(5, "a5", "-10.03", 1000),
+            Seq(6, "a6", "-10.03", 1000),
+            Seq(7, "a7", "10.03", 1000),
+            Seq(8, "a8", "10.04", 1000),
+            Seq(9, "a9", "10.04", 1000),
+            Seq(10, "a10", "-10.03", 1000),
+            Seq(11, "a11", "-10.04", 1000),
+            Seq(12, "a12", "-10.04", 1000)
+          )
+        }
+      }
+    }
+  }
+
+  test("Test STRING to DECIMAL schema evolution (lost in scale)") {
+    Seq("cow", "mor").foreach { tableType =>
+      withTempDir { tmp =>
+        // Using INMEMORY index for mor table so that log files will be created instead of parquet
+        val tableName = generateTableName
+        if (HoodieSparkUtils.gteqSpark3_1) {
+          spark.sql(
+            s"""
+               |create table $tableName (
+               |  id int,
+               |  name string,
+               |  price string,
+               |  ts long
+               |) using hudi
+               | location '${tmp.getCanonicalPath}'
+               | tblproperties (
+               |  primaryKey = 'id',
+               |  type = '$tableType',
+               |  preCombineField = 'ts'
+               |  ${if (tableType.equals("mor")) ", hoodie.index.type = 'INMEMORY'" else ""}
+               | )
+             """.stripMargin)
+
+          spark.sql(s"insert into $tableName values " +
+            // testing the rounding behaviour to ensure that HALF_UP is used for positive values
+            "(1, 'a1', '10.024', 1000)," +
+            "(2, 'a2', '10.025', 1000)," +
+            "(3, 'a3', '10.026', 1000)," +
+            // testing the rounding behaviour to ensure that HALF_UP is used for negative values
+            "(4, 'a4', '-10.024', 1000)," +
+            "(5, 'a5', '-10.025', 1000)," +
+            "(6, 'a6', '-10.026', 1000)," +
+            // testing the GENERAL rounding behaviour (HALF_UP and HALF_EVEN will retain the same result)
+            "(7, 'a7', '10.034', 1000)," +
+            "(8, 'a8', '10.035', 1000)," +
+            "(9, 'a9', '10.036', 1000)," +
+            // testing the GENERAL rounding behaviour (HALF_UP and HALF_EVEN will retain the same result)
+            "(10, 'a10', '-10.034', 1000)," +
+            "(11, 'a11', '-10.035', 1000)," +
+            "(12, 'a12', '-10.036', 1000)")
+
+          assertResult(tableType.equals("mor"))(DataSourceTestUtils.isLogFileOnly(tmp.getCanonicalPath))
+
+          spark.sql("set hoodie.schema.on.read.enable=true")
+          spark.sql(s"alter table $tableName alter column price type decimal(4, 2)")
+
+          checkAnswer(s"select id, name, cast(price as string), ts from $tableName order by id")(
+            Seq(1, "a1", "10.02", 1000),
"a2", "10.03", 1000), + Seq(3, "a3", "10.03", 1000), + Seq(4, "a4", "-10.02", 1000), + Seq(5, "a5", "-10.03", 1000), + Seq(6, "a6", "-10.03", 1000), + Seq(7, "a7", "10.03", 1000), + Seq(8, "a8", "10.04", 1000), + Seq(9, "a9", "10.04", 1000), + Seq(10, "a10", "-10.03", 1000), + Seq(11, "a11", "-10.04", 1000), + Seq(12, "a12", "-10.04", 1000) + ) + } + } + } + } }