Merged
26 changes: 26 additions & 0 deletions integration_tests/src/main/python/orc_write_test.py
@@ -19,6 +19,7 @@
from datetime import date, datetime, timezone
from data_gen import *
from marks import *
from pyspark.sql.functions import lit
from pyspark.sql.types import *

pytestmark = pytest.mark.nightly_resource_consuming_test
@@ -90,6 +91,31 @@ def test_part_write_round_trip(spark_tmp_path, orc_gen):
data_path,
conf = {'spark.rapids.sql.format.orc.write.enabled': True})


@ignore_order(local=True)
@pytest.mark.parametrize('orc_gen', [int_gen], ids=idfn)
@pytest.mark.parametrize('orc_impl', ["native", "hive"])
def test_dynamic_partition_write_round_trip(spark_tmp_path, orc_gen, orc_impl):
    gen_list = [('_c0', orc_gen)]
    data_path = spark_tmp_path + '/ORC_DATA'
    def do_writes(spark, path):
        df = gen_df(spark, gen_list).withColumn("my_partition", lit("PART"))
        # The first write finds no existing partitions, so it skips the
        # dynamic partition overwrite code path.
        df.write.mode("overwrite").partitionBy("my_partition").orc(path)
        # The second write actually triggers dynamic partition overwrite.
        df.write.mode("overwrite").partitionBy("my_partition").orc(path)
    assert_gpu_and_cpu_writes_are_equal_collect(
        lambda spark, path: do_writes(spark, path),
        lambda spark, path: spark.read.orc(path),
        data_path,
        conf={
            'spark.sql.orc.impl': orc_impl,
            'spark.rapids.sql.format.orc.write.enabled': True,
            'spark.sql.sources.partitionOverwriteMode': 'DYNAMIC'
        })


orc_write_compress_options = ['none', 'uncompressed', 'snappy']
# zstd is available in spark 3.2.0 and later.
if not is_before_spark_320() and not is_spark_cdh():
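The new test writes the same partitioned data twice because the first overwrite finds an empty target and never exercises the dynamic-overwrite machinery; only the second overwrite does. As a rough standalone illustration of the behavior being tested (not part of the patch), the Scala sketch below does the equivalent through the DataFrame API; the local master, temp path, and data are made up for the example:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("dynamic-partition-overwrite-sketch")
  .config("spark.sql.sources.partitionOverwriteMode", "DYNAMIC")
  .getOrCreate()
import spark.implicits._

val path = "/tmp/ORC_DATA"
val df = Seq(1, 2, 3).toDF("_c0").withColumn("my_partition", lit("PART"))

// First overwrite: the target has no partitions yet, so there is nothing
// for dynamic overwrite to replace.
df.write.mode("overwrite").partitionBy("my_partition").orc(path)

// Second overwrite: my_partition=PART already exists, so only that
// partition is rewritten; any other partitions under the path would be
// left untouched.
df.write.mode("overwrite").partitionBy("my_partition").orc(path)

With the default STATIC overwrite mode, the second write would instead clear every existing partition under the output path before writing; that difference is what the test pins down for both the native and hive ORC implementations.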
@@ -64,6 +64,32 @@ case class GpuInsertIntoHadoopFsRelationCommand(
val fs = outputPath.getFileSystem(hadoopConf)
val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)

val parameters = CaseInsensitiveMap(options)

val partitionOverwriteMode = parameters.get("partitionOverwriteMode")
// scalastyle:off caselocale
.map(mode => PartitionOverwriteMode.withName(mode.toUpperCase))
// scalastyle:on caselocale
.getOrElse(sparkSession.sessionState.conf.partitionOverwriteMode)

val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
// This config only makes sense when we are overwriting a partitioned dataset with dynamic
// partition columns.
val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite &&
staticPartitions.size < partitionColumns.length

val jobId = java.util.UUID.randomUUID().toString

// For dynamic partition overwrite, FileOutputCommitter's output path is staging path, files
// will be renamed from staging path to final output path during commit job
val committerOutputPath = if (dynamicPartitionOverwrite) {

Collaborator:
nit, might be nice to pull the comment over:
// For dynamic partition overwrite, FileOutputCommitter's output path is staging path, files
// will be renamed from staging path to final output path during commit job

Collaborator Author:
nice, will do

Collaborator Author:
done. d8f2a1e

FileCommitProtocol.getStagingDir(outputPath.toString, jobId)
.makeQualified(fs.getUri, fs.getWorkingDirectory)
} else {
qualifiedOutputPath
}

val partitionsTrackedByCatalog = sparkSession.sessionState.conf.manageFilesourcePartitions &&
catalogTable.isDefined &&
catalogTable.get.partitionColumnNames.nonEmpty &&
@@ -83,22 +109,9 @@ case class GpuInsertIntoHadoopFsRelationCommand(
fs, catalogTable.get, qualifiedOutputPath, matchingPartitions)
}

val parameters = CaseInsensitiveMap(options)

val partitionOverwriteMode = parameters.get("partitionOverwriteMode")
// scalastyle:off caselocale
.map(mode => PartitionOverwriteMode.withName(mode.toUpperCase))
// scalastyle:on caselocale
.getOrElse(sparkSession.sessionState.conf.partitionOverwriteMode)
val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
// This config only makes sense when we are overwriting a partitioned dataset with dynamic
// partition columns.
val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite &&
staticPartitions.size < partitionColumns.length

val committer = FileCommitProtocol.instantiate(
sparkSession.sessionState.conf.fileCommitProtocolClass,
jobId = java.util.UUID.randomUUID().toString,
jobId = jobId,
outputPath = outputPath.toString,
dynamicPartitionOverwrite = dynamicPartitionOverwrite)

@@ -160,7 +173,7 @@ case class GpuInsertIntoHadoopFsRelationCommand(
fileFormat = fileFormat,
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(
qualifiedOutputPath.toString, customPartitionLocations, outputColumns),
committerOutputPath.toString, customPartitionLocations, outputColumns),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
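Taken together, the Scala change computes the job id up front and, when dynamic partition overwrite is in effect, points the committer's output at a staging directory rather than the final table path; the commit protocol then renames files into their final partition directories at job commit. A condensed sketch of that decision is below; the helper name chooseCommitterOutputPath is invented for illustration, and the fs.makeQualified step from the real code is omitted:

import java.util.UUID

import org.apache.hadoop.fs.Path
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode

// Illustrative helper, not part of the patch.
def chooseCommitterOutputPath(
    outputPath: Path,
    mode: SaveMode,
    partitionOverwriteMode: PartitionOverwriteMode.Value,
    staticPartitionCount: Int,
    partitionColumnCount: Int,
    jobId: String = UUID.randomUUID().toString): Path = {
  // Dynamic overwrite only applies when overwriting a partitioned dataset
  // that still has at least one non-static partition column.
  val dynamicPartitionOverwrite =
    partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC &&
      mode == SaveMode.Overwrite &&
      staticPartitionCount < partitionColumnCount
  if (dynamicPartitionOverwrite) {
    // Write under a hidden staging directory inside the table path; files
    // are renamed into their final partition directories during commitJob.
    FileCommitProtocol.getStagingDir(outputPath.toString, jobId)
  } else {
    outputPath
  }
}

Passing the same jobId to both getStagingDir and FileCommitProtocol.instantiate keeps the writer and the committer agreeing on which staging directory is in use, which is why the patch hoists the UUID generation instead of generating it inline at the instantiate call.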