@@ -30,8 +30,8 @@ import org.apache.hudi.avro.HoodieAvroUtils
import org.apache.hudi.client.HoodieWriteResult
import org.apache.hudi.client.SparkRDDWriteClient
import org.apache.hudi.common.config.TypedProperties
import org.apache.hudi.common.model.{HoodieRecordPayload, HoodieTableType, WriteOperationType}
import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient}
import org.apache.hudi.common.model.{HoodieRecordPayload, HoodieTableType, OverwriteNonDefaultsWithLatestAvroPayload, WriteOperationType}
import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient, TableSchemaResolver}
import org.apache.hudi.common.table.timeline.HoodieActiveTimeline
import org.apache.hudi.common.util.ReflectionUtils
import org.apache.hudi.config.HoodieBootstrapConfig.{BOOTSTRAP_BASE_PATH_PROP, BOOTSTRAP_INDEX_CLASS_PROP, DEFAULT_BOOTSTRAP_INDEX_CLASS}
@@ -107,13 +107,44 @@ private[hudi] object HoodieSparkSqlWriter {
// Handle various save modes
handleSaveModes(mode, basePath, tableConfig, tblName, operation, fs)
// Create the table if not present
if (!tableExists) {
val dfFull = if (!tableExists) {
val archiveLogFolder = parameters.getOrElse(
HoodieTableConfig.HOODIE_ARCHIVELOG_FOLDER_PROP_NAME, "archived")
val tableMetaClient = HoodieTableMetaClient.initTableType(sparkContext.hadoopConfiguration, path.get,
tableType, tblName, archiveLogFolder, parameters(PAYLOAD_CLASS_OPT_KEY),
null.asInstanceOf[String])
tableConfig = tableMetaClient.getTableConfig
df
} else {
val tableMetaClient = new HoodieTableMetaClient(sparkContext.hadoopConfiguration, path.get)
val tableSchemaResolver = new TableSchemaResolver(tableMetaClient)
val oldSchema = tableSchemaResolver.getTableAvroSchemaWithoutMetadataFields
val oldStructType = AvroConversionUtils.convertAvroSchemaToStructType(oldSchema)
val dfFields = df.schema.fields.map(sf => sf.name).toList
val missingField = oldStructType.fields.exists(f => !dfFields.contains(f.name))

val recordKeyFields = parameters(RECORDKEY_FIELD_OPT_KEY).split(",").map(_.trim).filter(_.nonEmpty).toList
val partitionPathFields = parameters(PARTITIONPATH_FIELD_OPT_KEY).split(",").map(_.trim).filter(_.nonEmpty).toList
val precombineField = parameters(PRECOMBINE_FIELD_OPT_KEY).trim
val keyFields = recordKeyFields ++ partitionPathFields :+ precombineField
// every record key, partition path and precombine field must be present in the incoming dataframe
val allKeysExist = keyFields.forall(dfFields.contains)

val isUpsert = UPSERT_OPERATION_OPT_VAL.equals(parameters(OPERATION_OPT_KEY))
val isRequiredPayload = classOf[OverwriteNonDefaultsWithLatestAvroPayload].getName.equals(parameters(PAYLOAD_CLASS_OPT_KEY))
if (isUpsert && isRequiredPayload && allKeysExist && missingField) {
// non-key columns are missing from the incoming dataset: fill them from the table schema's defaults
val selectExprs = oldStructType.fields.map(f => {
if (dfFields.contains(f.name))
f.name
else
// column absent from the incoming dataset: substitute the Avro field's default, cast to the table column's type
s"cast(${oldSchema.getField(f.name).defaultVal()} as ${f.dataType.typeName}) as ${f.name}"
}).toList

df.selectExpr(selectExprs:_*)
} else {
// key fields missing, or the operation/payload does not qualify: fall back to the original dataframe
df
}
}

val commitActionType = DataSourceUtils.getCommitActionType(operation, tableConfig.getTableType)
@@ -122,7 +153,7 @@ private[hudi] object HoodieSparkSqlWriter {
// scalastyle:off
if (parameters(ENABLE_ROW_WRITER_OPT_KEY).toBoolean &&
operation == WriteOperationType.BULK_INSERT) {
val (success, commitTime: common.util.Option[String]) = bulkInsertAsRow(sqlContext, parameters, df, tblName,
val (success, commitTime: common.util.Option[String]) = bulkInsertAsRow(sqlContext, parameters, dfFull, tblName,
basePath, path, instantTime)
return (success, commitTime, common.util.Option.empty(), hoodieWriteClient.orNull, tableConfig)
}
@@ -135,13 +166,13 @@ private[hudi] object HoodieSparkSqlWriter {
sparkContext.getConf.registerKryoClasses(
Array(classOf[org.apache.avro.generic.GenericData],
classOf[org.apache.avro.Schema]))
val schema = AvroConversionUtils.convertStructTypeToAvroSchema(df.schema, structName, nameSpace)
val schema = AvroConversionUtils.convertStructTypeToAvroSchema(dfFull.schema, structName, nameSpace)
sparkContext.getConf.registerAvroSchemas(schema)
log.info(s"Registered avro schema : ${schema.toString(true)}")

// Convert to RDD[HoodieRecord]
val keyGenerator = DataSourceUtils.createKeyGenerator(toProperties(parameters))
val genericRecords: RDD[GenericRecord] = AvroConversionUtils.createRdd(df, schema, structName, nameSpace)
val genericRecords: RDD[GenericRecord] = AvroConversionUtils.createRdd(dfFull, schema, structName, nameSpace)
val shouldCombine = parameters(INSERT_DROP_DUPS_OPT_KEY).toBoolean || operation.equals(WriteOperationType.UPSERT);
val hoodieAllIncomingRecords = genericRecords.map(gr => {
val hoodieRecord = if (shouldCombine) {
@@ -188,7 +219,7 @@ private[hudi] object HoodieSparkSqlWriter {

// Convert to RDD[HoodieKey]
val keyGenerator = DataSourceUtils.createKeyGenerator(toProperties(parameters))
val genericRecords: RDD[GenericRecord] = AvroConversionUtils.createRdd(df, structName, nameSpace)
val genericRecords: RDD[GenericRecord] = AvroConversionUtils.createRdd(dfFull, structName, nameSpace)
val hoodieKeysToDelete = genericRecords.map(gr => keyGenerator.getKey(gr)).toJavaRDD()

if (!tableExists) {
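
As context for the change above, here is a minimal standalone sketch of the column-defaulting technique the dfFull branch applies, written against plain Spark. The helper name fillMissingColumns and the defaults map are hypothetical stand-ins for the Avro field defaults that the patch resolves through TableSchemaResolver.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

// Hypothetical helper: project an incoming frame onto the full table schema,
// filling each absent column with a pre-resolved default cast to that column's type.
def fillMissingColumns(df: DataFrame, tableSchema: StructType, defaults: Map[String, Any]): DataFrame = {
  val incoming = df.schema.fieldNames.toSet
  val selectExprs = tableSchema.fields.map { f =>
    if (incoming.contains(f.name)) f.name
    else s"cast(${defaults.getOrElse(f.name, "null")} as ${f.dataType.typeName}) as ${f.name}"
  }
  df.selectExpr(selectExprs: _*)
}

Note that string-valued defaults would need quoting before being interpolated into the cast expression; the sketch assumes numeric or null defaults.
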
@@ -24,14 +24,14 @@ import java.util.{Collections, Date, UUID}
import org.apache.commons.io.FileUtils
import org.apache.hudi.DataSourceWriteOptions._
import org.apache.hudi.client.{SparkRDDWriteClient, TestBootstrap}
import org.apache.hudi.common.model.{HoodieRecord, HoodieRecordPayload}
import org.apache.hudi.common.model.{HoodieRecord, HoodieRecordPayload, OverwriteNonDefaultsWithLatestAvroPayload}
import org.apache.hudi.common.testutils.HoodieTestDataGenerator
import org.apache.hudi.config.{HoodieBootstrapConfig, HoodieWriteConfig}
import org.apache.hudi.exception.HoodieException
import org.apache.hudi.keygen.{NonpartitionedKeyGenerator, SimpleKeyGenerator}
import org.apache.hudi.testutils.DataSourceTestUtils
import org.apache.hudi.{AvroConversionUtils, DataSourceUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, HoodieWriterUtils}
import org.apache.spark.SparkContext
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.{Row, SQLContext, SaveMode, SparkSession}
import org.mockito.ArgumentMatchers.any
@@ -342,6 +342,65 @@ class HoodieSparkSqlWriterSuite extends FunSuite with Matchers {
}
})

test("test upsert dataset with specified columns") {
initSparkContext("test_upsert_with_specified_columns")
val path = java.nio.file.Files.createTempDirectory("hoodie_test_path")
try {

val sqlContext = spark.sqlContext
val hoodieFooTableName = "hoodie_foo_tbl"
val sc = spark.sparkContext

//create a new table
val fooTableModifier = Map("path" -> path.toAbsolutePath.toString,
HoodieWriteConfig.TABLE_NAME -> hoodieFooTableName,
"hoodie.upsert.shuffle.parallelism" -> "4",
DataSourceWriteOptions.OPERATION_OPT_KEY -> DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL,
DataSourceWriteOptions.RECORDKEY_FIELD_OPT_KEY -> "id",
DataSourceWriteOptions.PARTITIONPATH_FIELD_OPT_KEY -> "dt",
DataSourceWriteOptions.PRECOMBINE_FIELD_OPT_KEY -> "ts",
DataSourceWriteOptions.PAYLOAD_CLASS_OPT_KEY -> classOf[OverwriteNonDefaultsWithLatestAvroPayload].getName,
DataSourceWriteOptions.KEYGENERATOR_CLASS_OPT_KEY -> "org.apache.hudi.keygen.SimpleKeyGenerator")
val fooTableParams = HoodieWriterUtils.parametersWithWriteDefaults(fooTableModifier)

val data = List(
"""{"id" : 1, "name": "Jack", "age" : 10, "ts" : 1, "dt" : "20191212"}""",
"""{"id" : 2, "name": "Tom", "age" : 11, "ts" : 1, "dt" : "20191213"}""",
"""{"id" : 3, "name": "Bill", "age" : 12, "ts" : 1, "dt" : "20191212"}""")
val df = spark.read.json(sc.parallelize(data, 2))

// write to Hudi
HoodieSparkSqlWriter.write(sqlContext, SaveMode.Append, fooTableParams, df)

val update = List(
"""{"id" : 1, "age" : 22, "ts" : 2, "dt" : "20191212"}""")
val dfUpdate = spark.read.json(sc.parallelize(update, 2))
HoodieSparkSqlWriter.write(sqlContext, SaveMode.Append, fooTableParams, dfUpdate)

val dfSaved = spark.read.format("org.apache.hudi").load(path.toAbsolutePath.toString + "/*")
assert(1 == dfSaved.filter("name = 'Jack' and age = 22").count())

// a plain insert (not upsert) falls back to the original logic and fails on the missing column
var e = intercept[SparkException](HoodieSparkSqlWriter.write(sqlContext, SaveMode.Append,
fooTableParams ++ Map(DataSourceWriteOptions.OPERATION_OPT_KEY -> DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL), dfUpdate))
assert(e.getMessage.contains("Parquet/Avro schema mismatch: Avro field 'name' not found"))
// a payload class other than OverwriteNonDefaultsWithLatestAvroPayload also falls back and fails
e = intercept[SparkException](HoodieSparkSqlWriter.write(sqlContext, SaveMode.Append,
fooTableParams ++ Map(DataSourceWriteOptions.PAYLOAD_CLASS_OPT_KEY -> DataSourceWriteOptions.DEFAULT_PAYLOAD_OPT_VAL), dfUpdate))
assert(e.getMessage.contains("Parquet/Avro schema mismatch: Avro field 'name' not found"))
// a dataset missing the record key column cannot be defaulted and fails key generation
val update1 = List(
"""{"age" : 22, "ts" : 2, "dt" : "20191212"}""")
val dfUpdate1 = spark.read.json(sc.parallelize(update1, 2))
e = intercept[SparkException](HoodieSparkSqlWriter.write(sqlContext, SaveMode.Append, fooTableParams, dfUpdate1))
assert(e.getMessage.contains("\"id\" cannot be null or empty"))
}
finally {
spark.stop()
FileUtils.deleteDirectory(path.toFile)
}
}

List(DataSourceWriteOptions.COW_TABLE_TYPE_OPT_VAL, DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL)
.foreach(tableType => {
test("test HoodieSparkSqlWriter functionality with datasource bootstrap for " + tableType) {
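
To round this out, a hypothetical end-user sketch of the same partial-column upsert through the datasource API, mirroring the options exercised in the test above; the table name, base path, and the in-scope spark/sc handles are placeholders.

import org.apache.hudi.DataSourceWriteOptions
import org.apache.hudi.common.model.OverwriteNonDefaultsWithLatestAvroPayload
import org.apache.hudi.config.HoodieWriteConfig
import org.apache.spark.sql.SaveMode

// Assumes a SparkSession `spark` and SparkContext `sc` are in scope and the table was already
// created with the full schema (id, name, age, ts, dt).
val partialUpdate = spark.read.json(sc.parallelize(Seq(
  """{"id" : 1, "age" : 23, "ts" : 3, "dt" : "20191212"}"""), 1))

partialUpdate.write.format("org.apache.hudi")
  .option(HoodieWriteConfig.TABLE_NAME, "hoodie_foo_tbl")
  .option(DataSourceWriteOptions.OPERATION_OPT_KEY, DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL)
  .option(DataSourceWriteOptions.RECORDKEY_FIELD_OPT_KEY, "id")
  .option(DataSourceWriteOptions.PARTITIONPATH_FIELD_OPT_KEY, "dt")
  .option(DataSourceWriteOptions.PRECOMBINE_FIELD_OPT_KEY, "ts")
  .option(DataSourceWriteOptions.PAYLOAD_CLASS_OPT_KEY, classOf[OverwriteNonDefaultsWithLatestAvroPayload].getName)
  .mode(SaveMode.Append)
  .save("/tmp/hoodie_foo_tbl")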