diff --git a/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala b/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala
index 8cad2f3d508f9..c7d5c8782db7b 100644
--- a/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala
+++ b/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala
@@ -221,7 +221,8 @@ private[hudi] object HoodieSparkSqlWriter {
                 mode: SaveMode,
                 parameters: Map[String, String],
                 df: DataFrame,
-                hoodieTableConfigOpt: Option[HoodieTableConfig] = Option.empty): Boolean = {
+                hoodieTableConfigOpt: Option[HoodieTableConfig] = Option.empty,
+                hoodieWriteClient: Option[SparkRDDWriteClient[HoodieRecordPayload[Nothing]]] = Option.empty): Boolean = {
 
     val sparkContext = sqlContext.sparkContext
     val path = parameters.getOrElse("path", throw new HoodieException("'path' must be set."))
@@ -263,8 +264,13 @@ private[hudi] object HoodieSparkSqlWriter {
     }
 
     val jsc = new JavaSparkContext(sqlContext.sparkContext)
-    val writeClient = DataSourceUtils.createHoodieClient(jsc, schema, path, tableName, mapAsJavaMap(parameters))
-    writeClient.bootstrap(org.apache.hudi.common.util.Option.empty())
+    val writeClient = hoodieWriteClient.getOrElse(DataSourceUtils.createHoodieClient(jsc,
+      schema, path, tableName, mapAsJavaMap(parameters)))
+    try {
+      writeClient.bootstrap(org.apache.hudi.common.util.Option.empty())
+    } finally {
+      writeClient.close()
+    }
     val metaSyncSuccess = metaSync(parameters, basePath, jsc.hadoopConfiguration)
     metaSyncSuccess
   }
diff --git a/hudi-spark/src/test/scala/org/apache/hudi/functional/HoodieSparkSqlWriterSuite.scala b/hudi-spark/src/test/scala/org/apache/hudi/functional/HoodieSparkSqlWriterSuite.scala
index e1659048f6b68..41a45b2b1f951 100644
--- a/hudi-spark/src/test/scala/org/apache/hudi/functional/HoodieSparkSqlWriterSuite.scala
+++ b/hudi-spark/src/test/scala/org/apache/hudi/functional/HoodieSparkSqlWriterSuite.scala
@@ -17,17 +17,18 @@
 
 package org.apache.hudi.functional
 
+import java.time.Instant
 import java.util
-import java.util.{Date, UUID}
+import java.util.{Collections, Date, UUID}
 
 import org.apache.commons.io.FileUtils
 import org.apache.hudi.DataSourceWriteOptions._
-import org.apache.hudi.client.SparkRDDWriteClient
+import org.apache.hudi.client.{SparkRDDWriteClient, TestBootstrap}
 import org.apache.hudi.common.model.{HoodieRecord, HoodieRecordPayload}
 import org.apache.hudi.common.testutils.HoodieTestDataGenerator
-import org.apache.hudi.config.HoodieWriteConfig
+import org.apache.hudi.config.{HoodieBootstrapConfig, HoodieWriteConfig}
 import org.apache.hudi.exception.HoodieException
-import org.apache.hudi.keygen.SimpleKeyGenerator
+import org.apache.hudi.keygen.{NonpartitionedKeyGenerator, SimpleKeyGenerator}
 import org.apache.hudi.testutils.DataSourceTestUtils
 import org.apache.hudi.{AvroConversionUtils, DataSourceUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, HoodieWriterUtils}
 import org.apache.spark.SparkContext
@@ -341,6 +342,61 @@ class HoodieSparkSqlWriterSuite extends FunSuite with Matchers {
       }
     })
 
+  List(DataSourceWriteOptions.COW_TABLE_TYPE_OPT_VAL, DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL)
+    .foreach(tableType => {
+      test("test HoodieSparkSqlWriter functionality with datasource bootstrap for " + tableType) {
+        initSparkContext("test_bootstrap_datasource")
+        val path = java.nio.file.Files.createTempDirectory("hoodie_test_path")
+        val srcPath = java.nio.file.Files.createTempDirectory("hoodie_bootstrap_source_path")
+
+        try {
+
+          val hoodieFooTableName = "hoodie_foo_tbl"
+
+          val sourceDF = TestBootstrap.generateTestRawTripDataset(Instant.now.toEpochMilli, 0, 100, Collections.emptyList(), sc,
+            spark.sqlContext)
+
+          // Write source data non-partitioned
+          sourceDF.write
+            .format("parquet")
+            .mode(SaveMode.Overwrite)
+            .save(srcPath.toAbsolutePath.toString)
+
+          val fooTableModifier = Map("path" -> path.toAbsolutePath.toString,
+            HoodieBootstrapConfig.BOOTSTRAP_BASE_PATH_PROP -> srcPath.toAbsolutePath.toString,
+            HoodieWriteConfig.TABLE_NAME -> hoodieFooTableName,
+            DataSourceWriteOptions.TABLE_TYPE_OPT_KEY -> tableType,
+            HoodieBootstrapConfig.BOOTSTRAP_PARALLELISM -> "4",
+            DataSourceWriteOptions.OPERATION_OPT_KEY -> DataSourceWriteOptions.BOOTSTRAP_OPERATION_OPT_VAL,
+            DataSourceWriteOptions.RECORDKEY_FIELD_OPT_KEY -> "_row_key",
+            DataSourceWriteOptions.PARTITIONPATH_FIELD_OPT_KEY -> "partition",
+            HoodieBootstrapConfig.BOOTSTRAP_KEYGEN_CLASS -> classOf[NonpartitionedKeyGenerator].getCanonicalName)
+          val fooTableParams = HoodieWriterUtils.parametersWithWriteDefaults(fooTableModifier)
+
+          val client = spy(DataSourceUtils.createHoodieClient(
+            new JavaSparkContext(sc),
+            null,
+            path.toAbsolutePath.toString,
+            hoodieFooTableName,
+            mapAsJavaMap(fooTableParams)).asInstanceOf[SparkRDDWriteClient[HoodieRecordPayload[Nothing]]])
+
+          HoodieSparkSqlWriter.bootstrap(sqlContext, SaveMode.Append, fooTableParams, spark.emptyDataFrame, Option.empty,
+            Option(client))
+
+          // Verify that HoodieWriteClient is closed correctly
+          verify(client, times(1)).close()
+
+          // fetch all records from parquet files generated from write to hudi
+          val actualDf = sqlContext.read.parquet(path.toAbsolutePath.toString)
+          assert(actualDf.count == 100)
+        } finally {
+          spark.stop()
+          FileUtils.deleteDirectory(path.toFile)
+          FileUtils.deleteDirectory(srcPath.toFile)
+        }
+      }
+    })
+
   case class Test(uuid: String, ts: Long)
 
   import scala.collection.JavaConverters
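
For context (not part of the patch): a minimal sketch of how the bootstrap path touched above is normally reached from a Spark application through the datasource writer, using the same option constants referenced in the test. The table name, source path, and target path are placeholders; callers on this path leave the new hoodieWriteClient parameter at its Option.empty default, so bootstrap builds its own client and, with this change, closes it in the finally block.

import org.apache.hudi.DataSourceWriteOptions
import org.apache.hudi.config.{HoodieBootstrapConfig, HoodieWriteConfig}
import org.apache.hudi.keygen.NonpartitionedKeyGenerator
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().appName("bootstrap-sketch").master("local[*]").getOrCreate()

// Bootstrap takes no input rows; an empty DataFrame is enough to trigger the operation.
spark.emptyDataFrame.write
  .format("hudi")
  .option(HoodieWriteConfig.TABLE_NAME, "hoodie_foo_tbl") // placeholder table name
  .option(DataSourceWriteOptions.OPERATION_OPT_KEY, DataSourceWriteOptions.BOOTSTRAP_OPERATION_OPT_VAL)
  .option(DataSourceWriteOptions.RECORDKEY_FIELD_OPT_KEY, "_row_key")
  .option(HoodieBootstrapConfig.BOOTSTRAP_BASE_PATH_PROP, "/data/existing_parquet_table") // placeholder source path
  .option(HoodieBootstrapConfig.BOOTSTRAP_KEYGEN_CLASS, classOf[NonpartitionedKeyGenerator].getCanonicalName)
  .mode(SaveMode.Append)
  .save("/data/hoodie_foo_tbl") // placeholder target base path

The injected-client overload added in this patch exists mainly so the test above can wrap the client in a Mockito spy and verify close() is called exactly once.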