diff --git a/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala b/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala
index fe0f3f5e47316..372ec127d9cf7 100644
--- a/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala
+++ b/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala
@@ -30,8 +30,8 @@ import org.apache.hudi.avro.HoodieAvroUtils
 import org.apache.hudi.client.HoodieWriteResult
 import org.apache.hudi.client.SparkRDDWriteClient
 import org.apache.hudi.common.config.TypedProperties
-import org.apache.hudi.common.model.{HoodieRecordPayload, HoodieTableType, WriteOperationType}
-import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient}
+import org.apache.hudi.common.model.{HoodieRecordPayload, HoodieTableType, OverwriteNonDefaultsWithLatestAvroPayload, WriteOperationType}
+import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient, TableSchemaResolver}
 import org.apache.hudi.common.table.timeline.HoodieActiveTimeline
 import org.apache.hudi.common.util.ReflectionUtils
 import org.apache.hudi.config.HoodieBootstrapConfig.{BOOTSTRAP_BASE_PATH_PROP, BOOTSTRAP_INDEX_CLASS_PROP, DEFAULT_BOOTSTRAP_INDEX_CLASS}
@@ -107,13 +107,44 @@ private[hudi] object HoodieSparkSqlWriter {
     // Handle various save modes
     handleSaveModes(mode, basePath, tableConfig, tblName, operation, fs)
     // Create the table if not present
-    if (!tableExists) {
+    val dfFull = if (!tableExists) {
       val archiveLogFolder = parameters.getOrElse(
         HoodieTableConfig.HOODIE_ARCHIVELOG_FOLDER_PROP_NAME, "archived")
       val tableMetaClient = HoodieTableMetaClient.initTableType(sparkContext.hadoopConfiguration, path.get, tableType,
         tblName, archiveLogFolder, parameters(PAYLOAD_CLASS_OPT_KEY), null.asInstanceOf[String])
       tableConfig = tableMetaClient.getTableConfig
+      df
+    } else {
+      val tableMetaClient = new HoodieTableMetaClient(sparkContext.hadoopConfiguration, path.get)
+      val tableSchemaResolver = new TableSchemaResolver(tableMetaClient)
+      val oldSchema = tableSchemaResolver.getTableAvroSchemaWithoutMetadataFields
+      val oldStructType = AvroConversionUtils.convertAvroSchemaToStructType(oldSchema)
+      val dfFields = df.schema.fields.map(sf => sf.name).toList
+      val missingField = oldStructType.fields.exists(f => !dfFields.contains(f.name))
+
+      val recordKeyFields = parameters(RECORDKEY_FIELD_OPT_KEY).split(",").map { f => f.trim }.filter { p => !p.isEmpty }.toList
+      val partitionPathFields = parameters(PARTITIONPATH_FIELD_OPT_KEY).split(",").map { f => f.trim }.filter { p => !p.isEmpty }.toList
+      val precombineField = parameters(PRECOMBINE_FIELD_OPT_KEY).trim
+      val keyFields = recordKeyFields ++ partitionPathFields :+ precombineField
+      val allKeysExist = keyFields.forall(dfFields.contains)
+
+      val isUpsert = UPSERT_OPERATION_OPT_VAL.equals(parameters(OPERATION_OPT_KEY))
+      val isRequiredPayload = classOf[OverwriteNonDefaultsWithLatestAvroPayload].getName.equals(parameters(PAYLOAD_CLASS_OPT_KEY))
+      if (isUpsert && isRequiredPayload && allKeysExist && missingField) {
+        // all key fields are present but some non-key fields are missing: fill them from the table schema
+        val selectExprs = oldStructType.fields.map(f => {
+          if (dfFields.contains(f.name))
+            f.name
+          else
+            s"cast(${oldSchema.getField(f.name).defaultVal()} as ${f.dataType.typeName}) as ${f.name}"
+        }).toList
+
+        df.selectExpr(selectExprs: _*)
+      } else {
+        // otherwise (key fields missing, different operation or payload, or nothing missing) keep the original DataFrame
+        df
+      }
     }

     val commitActionType = DataSourceUtils.getCommitActionType(operation, tableConfig.getTableType)
@@ -122,7 +153,7 @@ private[hudi] object HoodieSparkSqlWriter {
     // scalastyle:off
     if (parameters(ENABLE_ROW_WRITER_OPT_KEY).toBoolean &&
       operation == WriteOperationType.BULK_INSERT) {
-      val (success, commitTime: common.util.Option[String]) = bulkInsertAsRow(sqlContext, parameters, df, tblName,
+      val (success, commitTime: common.util.Option[String]) = bulkInsertAsRow(sqlContext, parameters, dfFull, tblName,
         basePath, path, instantTime)
       return (success, commitTime, common.util.Option.empty(), hoodieWriteClient.orNull, tableConfig)
     }
@@ -135,13 +166,13 @@ private[hudi] object HoodieSparkSqlWriter {
         sparkContext.getConf.registerKryoClasses(
           Array(classOf[org.apache.avro.generic.GenericData],
             classOf[org.apache.avro.Schema]))
-        val schema = AvroConversionUtils.convertStructTypeToAvroSchema(df.schema, structName, nameSpace)
+        val schema = AvroConversionUtils.convertStructTypeToAvroSchema(dfFull.schema, structName, nameSpace)
         sparkContext.getConf.registerAvroSchemas(schema)
         log.info(s"Registered avro schema : ${schema.toString(true)}")

         // Convert to RDD[HoodieRecord]
         val keyGenerator = DataSourceUtils.createKeyGenerator(toProperties(parameters))
-        val genericRecords: RDD[GenericRecord] = AvroConversionUtils.createRdd(df, schema, structName, nameSpace)
+        val genericRecords: RDD[GenericRecord] = AvroConversionUtils.createRdd(dfFull, schema, structName, nameSpace)
         val shouldCombine = parameters(INSERT_DROP_DUPS_OPT_KEY).toBoolean || operation.equals(WriteOperationType.UPSERT);
         val hoodieAllIncomingRecords = genericRecords.map(gr => {
           val hoodieRecord = if (shouldCombine) {
@@ -188,7 +219,7 @@ private[hudi] object HoodieSparkSqlWriter {

         // Convert to RDD[HoodieKey]
         val keyGenerator = DataSourceUtils.createKeyGenerator(toProperties(parameters))
-        val genericRecords: RDD[GenericRecord] = AvroConversionUtils.createRdd(df, structName, nameSpace)
+        val genericRecords: RDD[GenericRecord] = AvroConversionUtils.createRdd(dfFull, structName, nameSpace)
         val hoodieKeysToDelete = genericRecords.map(gr => keyGenerator.getKey(gr)).toJavaRDD()

         if (!tableExists) {
diff --git a/hudi-spark/src/test/scala/org/apache/hudi/functional/HoodieSparkSqlWriterSuite.scala b/hudi-spark/src/test/scala/org/apache/hudi/functional/HoodieSparkSqlWriterSuite.scala
index 41a45b2b1f951..a0d517da18318 100644
--- a/hudi-spark/src/test/scala/org/apache/hudi/functional/HoodieSparkSqlWriterSuite.scala
+++ b/hudi-spark/src/test/scala/org/apache/hudi/functional/HoodieSparkSqlWriterSuite.scala
@@ -24,14 +24,14 @@ import java.util.{Collections, Date, UUID}
 import org.apache.commons.io.FileUtils
 import org.apache.hudi.DataSourceWriteOptions._
 import org.apache.hudi.client.{SparkRDDWriteClient, TestBootstrap}
-import org.apache.hudi.common.model.{HoodieRecord, HoodieRecordPayload}
+import org.apache.hudi.common.model.{HoodieRecord, HoodieRecordPayload, OverwriteNonDefaultsWithLatestAvroPayload}
 import org.apache.hudi.common.testutils.HoodieTestDataGenerator
 import org.apache.hudi.config.{HoodieBootstrapConfig, HoodieWriteConfig}
 import org.apache.hudi.exception.HoodieException
 import org.apache.hudi.keygen.{NonpartitionedKeyGenerator, SimpleKeyGenerator}
 import org.apache.hudi.testutils.DataSourceTestUtils
 import org.apache.hudi.{AvroConversionUtils, DataSourceUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, HoodieWriterUtils}
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkException}
 import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.sql.{Row, SQLContext, SaveMode, SparkSession}
 import org.mockito.ArgumentMatchers.any
@@ -342,6 +342,65 @@ class HoodieSparkSqlWriterSuite extends FunSuite with Matchers {
       }
     })

+  test("test upsert dataset with specified columns") {
+    initSparkContext("test_upsert_with_specified_columns")
+    val path = java.nio.file.Files.createTempDirectory("hoodie_test_path")
+    try {
+
+      val sqlContext = spark.sqlContext
+      val hoodieFooTableName = "hoodie_foo_tbl"
+      val sc = spark.sparkContext
+
+      // create a new table
+      val fooTableModifier = Map("path" -> path.toAbsolutePath.toString,
+        HoodieWriteConfig.TABLE_NAME -> hoodieFooTableName,
+        "hoodie.upsert.shuffle.parallelism" -> "4",
+        DataSourceWriteOptions.OPERATION_OPT_KEY -> DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL,
+        DataSourceWriteOptions.RECORDKEY_FIELD_OPT_KEY -> "id",
+        DataSourceWriteOptions.PARTITIONPATH_FIELD_OPT_KEY -> "dt",
+        DataSourceWriteOptions.PRECOMBINE_FIELD_OPT_KEY -> "ts",
+        DataSourceWriteOptions.PAYLOAD_CLASS_OPT_KEY -> classOf[OverwriteNonDefaultsWithLatestAvroPayload].getName,
+        DataSourceWriteOptions.KEYGENERATOR_CLASS_OPT_KEY -> "org.apache.hudi.keygen.SimpleKeyGenerator")
+      val fooTableParams = HoodieWriterUtils.parametersWithWriteDefaults(fooTableModifier)
+
+      val data = List(
+        """{"id" : 1, "name": "Jack", "age" : 10, "ts" : 1, "dt" : "20191212"}""",
+        """{"id" : 2, "name": "Tom", "age" : 11, "ts" : 1, "dt" : "20191213"}""",
+        """{"id" : 3, "name": "Bill", "age" : 12, "ts" : 1, "dt" : "20191212"}""")
+      val df = spark.read.json(sc.parallelize(data, 2))
+
+      // write to Hudi
+      HoodieSparkSqlWriter.write(sqlContext, SaveMode.Append, fooTableParams, df)
+
+      // upsert with only the key fields and "age"; the missing "name" column should be filled in
+      val update = List(
+        """{"id" : 1, "age" : 22, "ts" : 2, "dt" : "20191212"}""")
+      val dfUpdate = spark.read.json(sc.parallelize(update, 2))
+      HoodieSparkSqlWriter.write(sqlContext, SaveMode.Append, fooTableParams, dfUpdate)
+
+      val dfSaved = spark.read.format("org.apache.hudi").load(path.toAbsolutePath + "/*")
+      assert(1 == dfSaved.filter("name = 'Jack' and age = 22").count())
+
+      // not an upsert operation
+      var e = intercept[SparkException](HoodieSparkSqlWriter.write(sqlContext, SaveMode.Append,
+        fooTableParams ++ Map(DataSourceWriteOptions.OPERATION_OPT_KEY -> DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL), dfUpdate))
+      assert(e.getMessage.contains("Parquet/Avro schema mismatch: Avro field 'name' not found"))
+      // not the required payload class
+      e = intercept[SparkException](HoodieSparkSqlWriter.write(sqlContext, SaveMode.Append,
+        fooTableParams ++ Map(DataSourceWriteOptions.PAYLOAD_CLASS_OPT_KEY -> DataSourceWriteOptions.DEFAULT_PAYLOAD_OPT_VAL), dfUpdate))
+      assert(e.getMessage.contains("Parquet/Avro schema mismatch: Avro field 'name' not found"))
+      // record key field missing
+      val update1 = List(
+        """{"age" : 22, "ts" : 2, "dt" : "20191212"}""")
+      val dfUpdate1 = spark.read.json(sc.parallelize(update1, 2))
+      e = intercept[SparkException](HoodieSparkSqlWriter.write(sqlContext, SaveMode.Append, fooTableParams, dfUpdate1))
+      assert(e.getMessage.contains("\"id\" cannot be null or empty"))
+    }
+    finally {
+      spark.stop()
+      FileUtils.deleteDirectory(path.toFile)
+    }
+  }
+
   List(DataSourceWriteOptions.COW_TABLE_TYPE_OPT_VAL, DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL)
     .foreach(tableType => {
       test("test HoodieSparkSqlWriter functionality with datasource bootstrap for " + tableType) {
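
For reference, below is a minimal standalone sketch (not part of the patch) of the column-completion idea the writer change implements: a partial upsert DataFrame is padded back to the table's column list via selectExpr with typed nulls. The object name, local SparkSession, and the hard-coded table schema here are illustrative assumptions only; the real code derives the schema from TableSchemaResolver and uses the Avro field defaults.

import org.apache.spark.sql.SparkSession

object ColumnCompletionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("column-completion-sketch").getOrCreate()
    import spark.implicits._

    // Assumed table schema: id, name, age, ts, dt. The incoming upsert omits "name".
    val tableFields = Seq("id" -> "bigint", "name" -> "string", "age" -> "bigint", "ts" -> "bigint", "dt" -> "string")
    val partial = Seq((1L, 22L, 2L, "20191212")).toDF("id", "age", "ts", "dt")

    val present = partial.schema.fieldNames.toSet
    val selectExprs = tableFields.map { case (name, dataType) =>
      if (present.contains(name)) name
      else s"cast(null as $dataType) as $name" // null stands in for the Avro default used by the patch
    }

    partial.selectExpr(selectExprs: _*).show() // "name" shows up as a typed null column
    spark.stop()
  }
}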