diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/DataSourceOptions.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/DataSourceOptions.scala
index 5151fe93db5b8..965b35c7d1649 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/DataSourceOptions.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/DataSourceOptions.scala
@@ -23,9 +23,11 @@ import org.apache.hudi.common.model.WriteOperationType
 import org.apache.hudi.config.HoodieWriteConfig
 import org.apache.hudi.hive.HiveSyncTool
 import org.apache.hudi.hive.SlashEncodedDayPartitionValueExtractor
-import org.apache.hudi.keygen.SimpleKeyGenerator
+import org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator.Config
+import org.apache.hudi.keygen.{CustomKeyGenerator, SimpleKeyGenerator}
 import org.apache.hudi.keygen.constant.KeyGeneratorOptions
 import org.apache.log4j.LogManager
+import org.apache.spark.sql.execution.datasources.{DataSourceUtils => SparkDataSourceUtils}
 
 /**
  * List of options that can be passed to the Hoodie datasource,
@@ -192,6 +194,42 @@ object DataSourceWriteOptions {
     }
   }
 
+  /**
+   * Translate Spark parameters to Hudi parameters.
+   *
+   * @param optParams Parameters to be translated
+   * @return Parameters after translation
+   */
+  def translateSqlOptions(optParams: Map[String, String]): Map[String, String] = {
+    var translatedOptParams = optParams
+    // Translate the partitionBy columns of Spark's DataFrameWriter API into PARTITIONPATH_FIELD_OPT_KEY
+    if (optParams.contains(SparkDataSourceUtils.PARTITIONING_COLUMNS_KEY)) {
+      val partitionColumns = optParams.get(SparkDataSourceUtils.PARTITIONING_COLUMNS_KEY)
+        .map(SparkDataSourceUtils.decodePartitioningColumns)
+        .getOrElse(Nil)
+      val keyGeneratorClass = optParams.getOrElse(DataSourceWriteOptions.KEYGENERATOR_CLASS_OPT_KEY,
+        DataSourceWriteOptions.DEFAULT_KEYGENERATOR_CLASS_OPT_VAL)
+
+      val partitionPathField =
+        keyGeneratorClass match {
+          // Only CustomKeyGenerator needs special treatment, because its partition path field has to be
+          // specified as "field1:PartitionKeyType1,field2:PartitionKeyType2".
+          // partitionBy can mix both forms, e.g. partitionBy("p1", "p2:SIMPLE", "p3:TIMESTAMP")
+          case c if c == classOf[CustomKeyGenerator].getName =>
+            partitionColumns.map(e => {
+              if (e.contains(":")) {
+                e
+              } else {
+                s"$e:SIMPLE"
+              }
+            }).mkString(",")
+          case _ =>
+            partitionColumns.mkString(",")
+        }
+      translatedOptParams = optParams ++ Map(PARTITIONPATH_FIELD_OPT_KEY -> partitionPathField)
+    }
+    translatedOptParams
+  }
 
   /**
    * Hive table name, to register the table into.
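
For reference (not part of the patch), the CustomKeyGenerator branch above normalizes each partitionBy column into the "field:PartitionKeyType" form, defaulting the type to SIMPLE when none is given. A minimal standalone sketch of that mapping, using the same column names the tests use:

    // Columns as they would arrive from DataFrameWriter.partitionBy(...)
    val partitionColumns = Seq("driver", "rider:SIMPLE", "current_ts:TIMESTAMP")
    // Same normalization as the CustomKeyGenerator case in translateSqlOptions
    val partitionPathField = partitionColumns
      .map(e => if (e.contains(":")) e else s"$e:SIMPLE")
      .mkString(",")
    // partitionPathField == "driver:SIMPLE,rider:SIMPLE,current_ts:TIMESTAMP"
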
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/DefaultSource.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/DefaultSource.scala
index d26390d635a6e..27f922e362575 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/DefaultSource.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/DefaultSource.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.streaming.Sink
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode, SparkSession}
 
 import scala.collection.JavaConverters._
 
@@ -46,6 +46,14 @@ class DefaultSource extends RelationProvider
   with StreamSinkProvider
   with Serializable {
 
+  SparkSession.getActiveSession.foreach { spark =>
+    val sparkVersion = spark.version
+    if (sparkVersion.startsWith("0.") || sparkVersion.startsWith("1.") || sparkVersion.startsWith("2.")) {
+      // Enable "passPartitionByAsOptions" to support "write.partitionBy(...)"
+      spark.conf.set("spark.sql.legacy.sources.write.passPartitionByAsOptions", "true")
+    }
+  }
+
   private val log = LogManager.getLogger(classOf[DefaultSource])
 
   override def createRelation(sqlContext: SQLContext,
@@ -126,12 +134,13 @@ class DefaultSource extends RelationProvider
                              optParams: Map[String, String],
                              df: DataFrame): BaseRelation = {
     val parameters = HoodieWriterUtils.parametersWithWriteDefaults(optParams)
+    val translatedOptions = DataSourceWriteOptions.translateSqlOptions(parameters)
     val dfWithoutMetaCols = df.drop(HoodieRecord.HOODIE_META_COLUMNS.asScala:_*)
 
-    if (parameters(OPERATION_OPT_KEY).equals(BOOTSTRAP_OPERATION_OPT_VAL)) {
-      HoodieSparkSqlWriter.bootstrap(sqlContext, mode, parameters, dfWithoutMetaCols)
+    if (translatedOptions(OPERATION_OPT_KEY).equals(BOOTSTRAP_OPERATION_OPT_VAL)) {
+      HoodieSparkSqlWriter.bootstrap(sqlContext, mode, translatedOptions, dfWithoutMetaCols)
     } else {
-      HoodieSparkSqlWriter.write(sqlContext, mode, parameters, dfWithoutMetaCols)
+      HoodieSparkSqlWriter.write(sqlContext, mode, translatedOptions, dfWithoutMetaCols)
     }
     new HoodieEmptyRelation(sqlContext, dfWithoutMetaCols.schema)
   }
@@ -141,9 +150,10 @@ class DefaultSource extends RelationProvider
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = {
     val parameters = HoodieWriterUtils.parametersWithWriteDefaults(optParams)
+    val translatedOptions = DataSourceWriteOptions.translateSqlOptions(parameters)
     new HoodieStreamingSink(
       sqlContext,
-      parameters,
+      translatedOptions,
       partitionColumns,
       outputMode)
   }
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
index b15a7d470a6cf..7203b39e3d83d 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
@@ -25,12 +25,16 @@ import org.apache.hudi.common.table.timeline.HoodieInstant
 import org.apache.hudi.common.testutils.HoodieTestDataGenerator
 import org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings
 import org.apache.hudi.config.HoodieWriteConfig
+import org.apache.hudi.keygen._
+import org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator.Config
 import org.apache.hudi.testutils.HoodieClientTestBase
 import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions, HoodieDataSourceHelpers}
 import org.apache.spark.sql._
-import org.apache.spark.sql.functions.{col, lit}
-import org.apache.spark.sql.types.{DataTypes, DateType, IntegerType, StringType, StructField, StructType, TimestampType}
-import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
+import org.apache.spark.sql.functions.{col, concat, lit, udf}
+import org.apache.spark.sql.types._
+import org.joda.time.DateTime
+import org.joda.time.format.DateTimeFormat
+import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue, fail}
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
@@ -348,4 +352,151 @@ class TestCOWDataSource extends HoodieClientTestBase {
     assertTrue(HoodieDataSourceHelpers.hasNewCommits(fs, basePath, "000"))
   }
+
+  private def getDataFrameWriter(keyGenerator: String): DataFrameWriter[Row] = {
+    val records = recordsToStrings(dataGen.generateInserts("000", 100)).toList
+    val inputDF = spark.read.json(spark.sparkContext.parallelize(records, 2))
+
+    inputDF.write.format("hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.KEYGENERATOR_CLASS_OPT_KEY, keyGenerator)
+      .mode(SaveMode.Overwrite)
+  }
+
+  @Test def testSparkPartitonByWithCustomKeyGenerator(): Unit = {
+    // Without a fieldType, the default is SIMPLE
+    var writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
+    writer.partitionBy("current_ts")
+      .save(basePath)
+
+    var recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= col("current_ts").cast("string")).count() == 0)
+
+    // Specify fieldType as TIMESTAMP
+    writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
+    writer.partitionBy("current_ts:TIMESTAMP")
+      .option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
+      .option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
+      .save(basePath)
+
+    recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    val udf_date_format = udf((data: Long) => new DateTime(data).toString(DateTimeFormat.forPattern("yyyyMMdd")))
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= udf_date_format(col("current_ts"))).count() == 0)
+
+    // Mixed fieldType
+    writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
+    writer.partitionBy("driver", "rider:SIMPLE", "current_ts:TIMESTAMP")
+      .option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
+      .option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
+      .save(basePath)
+
+    recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*/*")
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!=
+      concat(col("driver"), lit("/"), col("rider"), lit("/"), udf_date_format(col("current_ts")))).count() == 0)
+
+    // Test invalid partitionKeyType
+    writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName)
+    writer = writer.partitionBy("current_ts:DUMMY")
+      .option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
+      .option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
+    try {
+      writer.save(basePath)
+      fail("should fail when an invalid PartitionKeyType is provided!")
+    } catch {
+      case e: Exception =>
+        assertTrue(e.getMessage.contains("No enum constant org.apache.hudi.keygen.CustomAvroKeyGenerator.PartitionKeyType.DUMMY"))
+    }
+  }
+
+  @Test def testSparkPartitonByWithSimpleKeyGenerator() {
+    // Use the `driver` field as the partition key
+    var writer = getDataFrameWriter(classOf[SimpleKeyGenerator].getName)
+    writer.partitionBy("driver")
+      .save(basePath)
+
+    var recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= col("driver")).count() == 0)
+
+    // SimpleKeyGenerator treats `driver,rider` as a single field name; since no such field exists,
+    // the partition path falls back to the default value `default`
+    writer = getDataFrameWriter(classOf[SimpleKeyGenerator].getName)
+    writer.partitionBy("driver", "rider")
+      .save(basePath)
+
+    recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= lit("default")).count() == 0)
+  }
+
+  @Test def testSparkPartitonByWithComplexKeyGenerator() {
+    // Use the `driver` field as the partition key
+    var writer = getDataFrameWriter(classOf[ComplexKeyGenerator].getName)
+    writer.partitionBy("driver")
+      .save(basePath)
+
+    var recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= col("driver")).count() == 0)
+
+    // Use the `driver` and `rider` fields as the partition key
+    writer = getDataFrameWriter(classOf[ComplexKeyGenerator].getName)
+    writer.partitionBy("driver", "rider")
+      .save(basePath)
+
+    recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= concat(col("driver"), lit("/"), col("rider"))).count() == 0)
+  }
+
+  @Test def testSparkPartitonByWithTimestampBasedKeyGenerator() {
+    val writer = getDataFrameWriter(classOf[TimestampBasedKeyGenerator].getName)
+    writer.partitionBy("current_ts")
+      .option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
+      .option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
+      .save(basePath)
+
+    val recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*/*")
+    val udf_date_format = udf((data: Long) => new DateTime(data).toString(DateTimeFormat.forPattern("yyyyMMdd")))
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= udf_date_format(col("current_ts"))).count() == 0)
+  }
+
+  @Test def testSparkPartitonByWithGlobalDeleteKeyGenerator() {
+    val writer = getDataFrameWriter(classOf[GlobalDeleteKeyGenerator].getName)
+    writer.partitionBy("driver")
+      .save(basePath)
+
+    val recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*")
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= lit("")).count() == 0)
+  }
+
+  @Test def testSparkPartitonByWithNonpartitionedKeyGenerator() {
+    // Empty string column
+    var writer = getDataFrameWriter(classOf[NonpartitionedKeyGenerator].getName)
+    writer.partitionBy("")
+      .save(basePath)
+
+    var recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*")
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= lit("")).count() == 0)
+
+    // Non-existent column
+    writer = getDataFrameWriter(classOf[NonpartitionedKeyGenerator].getName)
+    writer.partitionBy("abc")
+      .save(basePath)
+
+    recordsReadDF = spark.read.format("org.apache.hudi")
+      .load(basePath + "/*")
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!= lit("")).count() == 0)
+  }
 }
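
Taken together, these changes let a Hudi writer pick up partition columns directly from DataFrameWriter.partitionBy. A minimal usage sketch mirroring the tests above (assumes an existing `inputDF` with driver/rider/current_ts columns and a `basePath`; the usual required write options such as record key, precombine field and table name are omitted here):

    import org.apache.hudi.DataSourceWriteOptions
    import org.apache.hudi.keygen.CustomKeyGenerator
    import org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator.Config
    import org.apache.spark.sql.SaveMode

    // partitionBy(...) is translated into PARTITIONPATH_FIELD_OPT_KEY, so mixed
    // "field" and "field:TYPE" specs work with CustomKeyGenerator.
    inputDF.write.format("hudi")
      .option(DataSourceWriteOptions.KEYGENERATOR_CLASS_OPT_KEY, classOf[CustomKeyGenerator].getName)
      .option(Config.TIMESTAMP_TYPE_FIELD_PROP, "EPOCHMILLISECONDS")
      .option(Config.TIMESTAMP_OUTPUT_DATE_FORMAT_PROP, "yyyyMMdd")
      .partitionBy("driver", "rider:SIMPLE", "current_ts:TIMESTAMP")
      .mode(SaveMode.Overwrite)
      .save(basePath)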