diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
index b15a7d470a6cf..3574ee4bc11d2 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
@@ -198,6 +198,31 @@ class TestCOWDataSource extends HoodieClientTestBase {
       .mode(SaveMode.Append)
       .save(basePath)
 
+    val records2 = recordsToStrings(dataGen.generateInserts("002", 5)).toList
+    val inputDF2 = spark.read.json(spark.sparkContext.parallelize(records2, 2))
+    inputDF2.write.format("org.apache.hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.OPERATION_OPT_KEY, DataSourceWriteOptions.INSERT_OVERWRITE_OPERATION_OPT_VAL)
+      .mode(SaveMode.Append)
+      .save(basePath)
+
+    val metaClient = new HoodieTableMetaClient(spark.sparkContext.hadoopConfiguration, basePath, true)
+    val commits = metaClient.getActiveTimeline.filterCompletedInstants().getInstants.toArray
+      .map(instant => (instant.asInstanceOf[HoodieInstant]).getAction)
+    assertEquals(2, commits.size)
+    assertEquals("commit", commits(0))
+    assertEquals("replacecommit", commits(1))
+  }
+
+  @Test def testOverWriteTableModeUseReplaceAction(): Unit = {
+    val records1 = recordsToStrings(dataGen.generateInserts("001", 5)).toList
+    val inputDF1 = spark.read.json(spark.sparkContext.parallelize(records1, 2))
+    inputDF1.write.format("org.apache.hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.OPERATION_OPT_KEY, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
+      .mode(SaveMode.Append)
+      .save(basePath)
+
     val records2 = recordsToStrings(dataGen.generateInserts("002", 5)).toList
     val inputDF2 = spark.read.json(spark.sparkContext.parallelize(records2, 2))
     inputDF2.write.format("org.apache.hudi")
@@ -207,7 +232,7 @@ class TestCOWDataSource extends HoodieClientTestBase {
       .mode(SaveMode.Append)
       .save(basePath)
     val metaClient = new HoodieTableMetaClient(spark.sparkContext.hadoopConfiguration, basePath, true)
-    val commits = metaClient.getActiveTimeline.filterCompletedInstants().getInstants.toArray 
+    val commits = metaClient.getActiveTimeline.filterCompletedInstants().getInstants.toArray
       .map(instant => (instant.asInstanceOf[HoodieInstant]).getAction)
     assertEquals(2, commits.size)
     assertEquals("commit", commits(0))
@@ -224,7 +249,62 @@ class TestCOWDataSource extends HoodieClientTestBase {
       .mode(SaveMode.Append)
       .save(basePath)
 
-    // step2: Write 7 more rectestOverWriteModeUseReplaceActionords using SaveMode.Overwrite for partition2 DEFAULT_SECOND_PARTITION_PATH
+    // step2: Write 7 records to hoodie table for partition2 DEFAULT_SECOND_PARTITION_PATH
+    val records2 = recordsToStrings(dataGen.generateInsertsForPartition("002", 7, HoodieTestDataGenerator.DEFAULT_SECOND_PARTITION_PATH)).toList
+    val inputDF2 = spark.read.json(spark.sparkContext.parallelize(records2, 2))
+    inputDF2.write.format("org.apache.hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.OPERATION_OPT_KEY, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
+      .mode(SaveMode.Append)
+      .save(basePath)
+
+    // step3: Write 6 records to hoodie table for partition1 DEFAULT_FIRST_PARTITION_PATH using INSERT_OVERWRITE_OPERATION_OPT_VAL
+    val records3 = recordsToStrings(dataGen.generateInsertsForPartition("001", 6, HoodieTestDataGenerator.DEFAULT_FIRST_PARTITION_PATH)).toList
+    val inputDF3 = spark.read.json(spark.sparkContext.parallelize(records3, 2))
+    inputDF3.write.format("org.apache.hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.OPERATION_OPT_KEY, DataSourceWriteOptions.INSERT_OVERWRITE_OPERATION_OPT_VAL)
+      .mode(SaveMode.Append)
+      .save(basePath)
+
+    val allRecords = spark.read.format("org.apache.hudi").load(basePath + "/*/*/*")
+    allRecords.registerTempTable("tmpTable")
+
+    spark.sql(String.format("select count(*) from tmpTable")).show()
+
+    // step4: Query the rows count from hoodie table for partition1 DEFAULT_FIRST_PARTITION_PATH
+    val recordCountForParititon1 = spark.sql(String.format("select count(*) from tmpTable where partition = '%s'", HoodieTestDataGenerator.DEFAULT_FIRST_PARTITION_PATH)).collect()
+    assertEquals("6", recordCountForParititon1(0).get(0).toString)
+
+    // step5: Query the rows count from hoodie table for partition2 DEFAULT_SECOND_PARTITION_PATH
+    val recordCountForParititon2 = spark.sql(String.format("select count(*) from tmpTable where partition = '%s'", HoodieTestDataGenerator.DEFAULT_SECOND_PARTITION_PATH)).collect()
+    assertEquals("7", recordCountForParititon2(0).get(0).toString)
+
+    // step6: Query the rows count from hoodie table for partition2 DEFAULT_SECOND_PARTITION_PATH using spark.collect and then filter mode
+    val recordsForPartitionColumn = spark.sql(String.format("select partition from tmpTable")).collect()
+    val filterSecondPartitionCount = recordsForPartitionColumn.filter(row => row.get(0).equals(HoodieTestDataGenerator.DEFAULT_SECOND_PARTITION_PATH)).size
+    assertEquals(7, filterSecondPartitionCount)
+
+    val metaClient = new HoodieTableMetaClient(spark.sparkContext.hadoopConfiguration, basePath, true)
+    val commits = metaClient.getActiveTimeline.filterCompletedInstants().getInstants.toArray
+      .map(instant => instant.asInstanceOf[HoodieInstant].getAction)
+    assertEquals(3, commits.size)
+    assertEquals("commit", commits(0))
+    assertEquals("commit", commits(1))
+    assertEquals("replacecommit", commits(2))
+  }
+
+  @Test def testOverWriteTableModeUseReplaceActionOnDisJointPartitions(): Unit = {
+    // step1: Write 5 records to hoodie table for partition1 DEFAULT_FIRST_PARTITION_PATH
+    val records1 = recordsToStrings(dataGen.generateInsertsForPartition("001", 5, HoodieTestDataGenerator.DEFAULT_FIRST_PARTITION_PATH)).toList
+    val inputDF1 = spark.read.json(spark.sparkContext.parallelize(records1, 2))
+    inputDF1.write.format("org.apache.hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.OPERATION_OPT_KEY, DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
+      .mode(SaveMode.Append)
+      .save(basePath)
+
+    // step2: Write 7 more records using SaveMode.Overwrite for partition2 DEFAULT_SECOND_PARTITION_PATH
     val records2 = recordsToStrings(dataGen.generateInsertsForPartition("002", 7, HoodieTestDataGenerator.DEFAULT_SECOND_PARTITION_PATH)).toList
     val inputDF2 = spark.read.json(spark.sparkContext.parallelize(records2, 2))
     inputDF2.write.format("org.apache.hudi")
@@ -233,30 +313,30 @@ class TestCOWDataSource extends HoodieClientTestBase {
       .mode(SaveMode.Overwrite)
       .save(basePath)
 
-    val allRecords = spark.read.format("org.apache.hudi").load(basePath + "/*/*/*") 
+    val allRecords = spark.read.format("org.apache.hudi").load(basePath + "/*/*/*")
     allRecords.registerTempTable("tmpTable")
 
     spark.sql(String.format("select count(*) from tmpTable")).show()
 
     // step3: Query the rows count from hoodie table for partition1 DEFAULT_FIRST_PARTITION_PATH
-    val recordCountForParititon1 = spark.sql(String.format("select count(*) from tmpTable where partition = '%s'", HoodieTestDataGenerator.DEFAULT_FIRST_PARTITION_PATH)).collect() 
+    val recordCountForParititon1 = spark.sql(String.format("select count(*) from tmpTable where partition = '%s'", HoodieTestDataGenerator.DEFAULT_FIRST_PARTITION_PATH)).collect()
     assertEquals("0", recordCountForParititon1(0).get(0).toString)
 
-    // step4: Query the rows count from hoodie table for partition1 DEFAULT_SECOND_PARTITION_PATH
+    // step4: Query the rows count from hoodie table for partition2 DEFAULT_SECOND_PARTITION_PATH
     val recordCountForParititon2 = spark.sql(String.format("select count(*) from tmpTable where partition = '%s'", HoodieTestDataGenerator.DEFAULT_SECOND_PARTITION_PATH)).collect()
     assertEquals("7", recordCountForParititon2(0).get(0).toString)
 
     // step5: Query the rows count from hoodie table
     val recordCount = spark.sql(String.format("select count(*) from tmpTable")).collect()
-    assertEquals("7", recordCountForParititon2(0).get(0).toString)
+    assertEquals("7", recordCount(0).get(0).toString)
 
-    // step6: Query the rows count from hoodie table for partition1 DEFAULT_SECOND_PARTITION_PATH using spark.collect and then filter mode
+    // step6: Query the rows count from hoodie table for partition2 DEFAULT_SECOND_PARTITION_PATH using spark.collect and then filter mode
     val recordsForPartitionColumn = spark.sql(String.format("select partition from tmpTable")).collect()
-    val filterSecondPartitionCount = recordsForPartitionColumn.filter(row => row.get(0).equals(HoodieTestDataGenerator.DEFAULT_SECOND_PARTITION_PATH)).size 
-    assertEquals(7,filterSecondPartitionCount)
+    val filterSecondPartitionCount = recordsForPartitionColumn.filter(row => row.get(0).equals(HoodieTestDataGenerator.DEFAULT_SECOND_PARTITION_PATH)).size
+    assertEquals(7, filterSecondPartitionCount)
 
     val metaClient = new HoodieTableMetaClient(spark.sparkContext.hadoopConfiguration, basePath, true)
-    val commits = metaClient.getActiveTimeline.filterCompletedInstants().getInstants.toArray 
+    val commits = metaClient.getActiveTimeline.filterCompletedInstants().getInstants.toArray
       .map(instant => instant.asInstanceOf[HoodieInstant].getAction)
     assertEquals(2, commits.size)
     assertEquals("commit", commits(0))