diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
index c178d1b84919e..a8fe758bf2e85 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.avro
 import org.apache.avro.LogicalTypes.{Date, Decimal, TimestampMicros, TimestampMillis}
 import org.apache.avro.Schema.Type._
 import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
+import org.apache.hudi.avro.AvroSchemaUtils.isNullable
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.types.Decimal.minBytesForPrecision
 import org.apache.spark.sql.types._
@@ -202,7 +203,12 @@ private[sql] object SchemaConverters {
         st.foreach { f =>
           val fieldAvroType =
             toAvroType(f.dataType, f.nullable, f.name, childNameSpace)
-          fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault()
+          val fieldBuilder = fieldsAssembler.name(f.name).`type`(fieldAvroType)
+          if (isNullable(fieldAvroType)) {
+            fieldBuilder.withDefault(null)
+          } else {
+            fieldBuilder.noDefault()
+          }
         }
         fieldsAssembler.endRecord()
       }
@@ -212,7 +218,7 @@
     }
 
     if (nullable && catalystType != NullType && schema.getType != Schema.Type.UNION) {
-      Schema.createUnion(schema, nullSchema)
+      Schema.createUnion(nullSchema, schema)
     } else {
       schema
     }
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestUpdateTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestUpdateTable.scala
index 8937e8595d389..cd0818d7efe4b 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestUpdateTable.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestUpdateTable.scala
@@ -204,4 +204,52 @@ class TestUpdateTable extends HoodieSparkSqlTestBase {
       }
     })
   }
+
+  test("Test Add Column and Update Table") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+
+      spark.sql("SET hoodie.datasource.read.extract.partition.values.from.path=true")
+
+      // create table
+      spark.sql(
+        s"""
+           |create table $tableName (
+           |  id int,
+           |  name string,
+           |  price double,
+           |  ts long
+           |) using hudi
+           | location '${tmp.getCanonicalPath}/$tableName'
+           | tblproperties (
+           |  type = 'mor',
+           |  primaryKey = 'id',
+           |  preCombineField = 'ts'
+           | )
+       """.stripMargin)
+
+      // insert data to table
+      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
+      checkAnswer(s"select id, name, price, ts from $tableName")(
+        Seq(1, "a1", 10.0, 1000)
+      )
+
+      spark.sql(s"update $tableName set price = 22 where id = 1")
+      checkAnswer(s"select id, name, price, ts from $tableName")(
+        Seq(1, "a1", 22.0, 1000)
+      )
+
+      spark.sql(s"alter table $tableName add column new_col1 int")
+
+      checkAnswer(s"select id, name, price, ts, new_col1 from $tableName")(
+        Seq(1, "a1", 22.0, 1000, null)
+      )
+
+      // update and check
+      spark.sql(s"update $tableName set price = price * 2 where id = 1")
+      checkAnswer(s"select id, name, price, ts, new_col1 from $tableName")(
+        Seq(1, "a1", 44.0, 1000, null)
+      )
+    }
+  }
 }
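Note (editorial, not part of the patch): the two SchemaConverters.scala hunks are two halves of one fix. Avro validates a field's default value against the first branch of its union type, so withDefault(null) on a nullable field is only accepted once the null branch comes first in the union; with the old createUnion(schema, nullSchema) order a null default would be rejected, and without a default, readers cannot fill in newly added columns (such as new_col1 in the test) for files written before the column existed. Below is a minimal standalone Scala sketch of this Avro behavior using only the plain Avro API; the NullDefaultSketch object name is made up for illustration, and the exception behavior is as observed in recent Avro versions.

import org.apache.avro.{AvroTypeException, Schema, SchemaBuilder}

object NullDefaultSketch extends App {
  val nullSchema   = Schema.create(Schema.Type.NULL)
  val stringSchema = Schema.create(Schema.Type.STRING)

  // Null-first union: Avro checks the default against the first branch
  // (NULL here), so withDefault(null) is accepted.
  val nullFirst = Schema.createUnion(nullSchema, stringSchema)
  val ok = SchemaBuilder.record("Example").fields()
    .name("new_col1").`type`(nullFirst).withDefault(null)
    .endRecord()
  println(ok.toString(true))

  // Null-last union: the first branch is STRING, so a null default is
  // invalid; recent Avro versions throw AvroTypeException here.
  val nullLast = Schema.createUnion(stringSchema, nullSchema)
  try {
    SchemaBuilder.record("Broken").fields()
      .name("new_col1").`type`(nullLast).withDefault(null)
      .endRecord()
  } catch {
    case e: AvroTypeException => println(s"rejected as expected: ${e.getMessage}")
  }
}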