diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py
index db36f630981..51dabdcf01c 100644
--- a/integration_tests/src/main/python/orc_write_test.py
+++ b/integration_tests/src/main/python/orc_write_test.py
@@ -19,6 +19,7 @@
 from datetime import date, datetime, timezone
 from data_gen import *
 from marks import *
+from pyspark.sql.functions import lit
 from pyspark.sql.types import *
 
 pytestmark = pytest.mark.nightly_resource_consuming_test
@@ -90,6 +91,31 @@ def test_part_write_round_trip(spark_tmp_path, orc_gen):
             data_path,
             conf = {'spark.rapids.sql.format.orc.write.enabled': True})
 
+
+@ignore_order(local=True)
+@pytest.mark.parametrize('orc_gen', [int_gen], ids=idfn)
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+def test_dynamic_partition_write_round_trip(spark_tmp_path, orc_gen, orc_impl):
+    gen_list = [('_c0', orc_gen)]
+    data_path = spark_tmp_path + '/ORC_DATA'
+    def do_writes(spark, path):
+        df = gen_df(spark, gen_list).withColumn("my_partition", lit("PART"))
+        # The first write finds no existing partitions, so it skips the
+        # dynamic partition overwrite code path.
+        df.write.mode("overwrite").partitionBy("my_partition").orc(path)
+        # The second write actually triggers dynamic partition overwrite.
+        df.write.mode("overwrite").partitionBy("my_partition").orc(path)
+    assert_gpu_and_cpu_writes_are_equal_collect(
+        lambda spark, path: do_writes(spark, path),
+        lambda spark, path: spark.read.orc(path),
+        data_path,
+        conf={
+            'spark.sql.orc.impl': orc_impl,
+            'spark.rapids.sql.format.orc.write.enabled': True,
+            'spark.sql.sources.partitionOverwriteMode': 'DYNAMIC'
+        })
+
+
 orc_write_compress_options = ['none', 'uncompressed', 'snappy']
 # zstd is available in spark 3.2.0 and later.
 if not is_before_spark_320() and not is_spark_cdh():
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuInsertIntoHadoopFsRelationCommand.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuInsertIntoHadoopFsRelationCommand.scala
index d2124a9f4ad..2b7974fd1a6 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuInsertIntoHadoopFsRelationCommand.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuInsertIntoHadoopFsRelationCommand.scala
@@ -64,6 +64,32 @@ case class GpuInsertIntoHadoopFsRelationCommand(
     val fs = outputPath.getFileSystem(hadoopConf)
     val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
 
+    val parameters = CaseInsensitiveMap(options)
+
+    val partitionOverwriteMode = parameters.get("partitionOverwriteMode")
+      // scalastyle:off caselocale
+      .map(mode => PartitionOverwriteMode.withName(mode.toUpperCase))
+      // scalastyle:on caselocale
+      .getOrElse(sparkSession.sessionState.conf.partitionOverwriteMode)
+
+
+    val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
+    // This config only makes sense when we are overwriting a partitioned dataset with dynamic
+    // partition columns.
+    val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite &&
+      staticPartitions.size < partitionColumns.length
+
+    val jobId = java.util.UUID.randomUUID().toString
+
+    // For dynamic partition overwrite, FileOutputCommitter's output path is staging path, files
+    // will be renamed from staging path to final output path during commit job
+    val committerOutputPath = if (dynamicPartitionOverwrite) {
+      FileCommitProtocol.getStagingDir(outputPath.toString, jobId)
+        .makeQualified(fs.getUri, fs.getWorkingDirectory)
+    } else {
+      qualifiedOutputPath
+    }
+
     val partitionsTrackedByCatalog = sparkSession.sessionState.conf.manageFilesourcePartitions &&
       catalogTable.isDefined &&
       catalogTable.get.partitionColumnNames.nonEmpty &&
@@ -83,22 +109,9 @@ case class GpuInsertIntoHadoopFsRelationCommand(
         fs, catalogTable.get, qualifiedOutputPath, matchingPartitions)
     }
 
-    val parameters = CaseInsensitiveMap(options)
-
-    val partitionOverwriteMode = parameters.get("partitionOverwriteMode")
-      // scalastyle:off caselocale
-      .map(mode => PartitionOverwriteMode.withName(mode.toUpperCase))
-      // scalastyle:on caselocale
-      .getOrElse(sparkSession.sessionState.conf.partitionOverwriteMode)
-    val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
-    // This config only makes sense when we are overwriting a partitioned dataset with dynamic
-    // partition columns.
-    val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite &&
-      staticPartitions.size < partitionColumns.length
-
     val committer = FileCommitProtocol.instantiate(
       sparkSession.sessionState.conf.fileCommitProtocolClass,
-      jobId = java.util.UUID.randomUUID().toString,
+      jobId = jobId,
       outputPath = outputPath.toString,
       dynamicPartitionOverwrite = dynamicPartitionOverwrite)
 
@@ -160,7 +173,7 @@ case class GpuInsertIntoHadoopFsRelationCommand(
       fileFormat = fileFormat,
       committer = committer,
       outputSpec = FileFormatWriter.OutputSpec(
-        qualifiedOutputPath.toString, customPartitionLocations, outputColumns),
+        committerOutputPath.toString, customPartitionLocations, outputColumns),
       hadoopConf = hadoopConf,
      partitionColumns = partitionColumns,
      bucketSpec = bucketSpec,
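
For context, a minimal usage sketch (not part of the patch) of how a user would exercise the dynamic partition overwrite path this change enables. The config keys mirror those used in the new test; the session setup, column name, and output path are illustrative assumptions only.

# Sketch: enable dynamic partition overwrite for GPU ORC writes (assumed session,
# column name, and path; only the two config keys come from the test above).
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit

spark = (SparkSession.builder
         .appName("dynamic-partition-overwrite-sketch")
         # Replace only the partitions written by the job; leave others intact.
         .config("spark.sql.sources.partitionOverwriteMode", "DYNAMIC")
         .config("spark.rapids.sql.format.orc.write.enabled", "true")
         .getOrCreate())

df = spark.range(100).withColumn("my_partition", lit("PART"))
# The first write creates the partition; overwriting the same partition again is
# what goes through the dynamic partition overwrite logic added in
# GpuInsertIntoHadoopFsRelationCommand (staging path plus rename at job commit).
df.write.mode("overwrite").partitionBy("my_partition").orc("/tmp/orc_dpo_example")
df.write.mode("overwrite").partitionBy("my_partition").orc("/tmp/orc_dpo_example")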