diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index fff86686b550..5e9e6ff1a569 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -46,12 +47,17 @@ private[libsvm] class LibSVMOutputWriter(
     context: TaskAttemptContext)
   extends OutputWriter {
 
+  override val path: String = {
+    val compressionExtension = TextOutputWriter.getCompressionExtension(context)
+    new Path(stagingDir, fileNamePrefix + ".libsvm" + compressionExtension).toString
+  }
+
   private[this] val buffer = new Text()
 
   private val recordWriter: RecordWriter[NullWritable, Text] = {
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        new Path(stagingDir, fileNamePrefix + extension)
+        new Path(path)
       }
     }.getRecordWriter(context)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
index f4cefdab077e..fbf6e96d3f85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
@@ -42,11 +42,12 @@ abstract class OutputWriterFactory extends Serializable {
    * @param fileNamePrefix Prefix of the file name. The returned OutputWriter must make sure this
    *                       prefix is used in the actual file name. For example, if the prefix is
    *                       "part-1-2-3", then the file name must start with "part_1_2_3" but can
-   *                       end in arbitrary extension.
+   *                       end in an arbitrary extension that is deterministic given the
+   *                       configuration (i.e. the extension must not depend on any task id,
+   *                       attempt id, or partition id).
    * @param dataSchema Schema of the rows to be written. Partition columns are not included in the
    *                   schema if the relation being written is partitioned.
    * @param context The Hadoop MapReduce task context.
-   * @since 1.4.0
    */
   def newInstance(
       stagingDir: String,
@@ -62,7 +63,6 @@ abstract class OutputWriterFactory extends Serializable {
    * and not modify it (do not add subdirectories, extensions, etc.). All other
    * file-format-specific information needed to create the writer must be passed
    * through the [[OutputWriterFactory]] implementation.
-   * @since 2.0.0
    */
   def newWriter(path: String): OutputWriter = {
     throw new UnsupportedOperationException("newInstance with just path not supported")
@@ -77,19 +77,22 @@ abstract class OutputWriterFactory extends Serializable {
  * executor side. This instance is used to persist rows to this single output file.
  */
 abstract class OutputWriter {
+
+  /**
+   * The path of the file to be written out. This path should include the staging directory and
+   * the file name prefix passed to the associated [[OutputWriterFactory.newInstance]] call.
+   */
+  def path: String
+
   /**
    * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
    * tables, dynamic partition columns are not included in rows to be written.
-   *
-   * @since 1.4.0
    */
   def write(row: Row): Unit
 
   /**
    * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before
    * the task output is committed.
-   *
-   * @since 1.4.0
    */
   def close(): Unit
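
(Illustration, not part of the patch: a minimal sketch of what the new `path` contract asks of an `OutputWriter` implementation. The class name and the `.echo` extension are made up; the point is that the path combines the staging directory with the supplied prefix plus an extension that is deterministic given the configuration.)

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.OutputWriter

// Hypothetical writer: `path` starts from stagingDir + fileNamePrefix and only appends an
// extension that does not depend on any task, attempt, or partition id, so the commit
// protocol can predict the final file name.
class EchoOutputWriter(
    stagingDir: String,
    fileNamePrefix: String,
    context: TaskAttemptContext) extends OutputWriter {

  override val path: String = new Path(stagingDir, fileNamePrefix + ".echo").toString

  override def write(row: Row): Unit = ()  // no-op; a real writer would persist the row

  override def close(): Unit = ()
}
```
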
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index eefacbf05ba0..a35cfdb2c234 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
 import org.apache.spark.sql.types._
 
 object CSVRelation extends Logging {
@@ -185,6 +186,11 @@ private[csv] class CsvOutputWriter(
     context: TaskAttemptContext,
     params: CSVOptions) extends OutputWriter with Logging {
 
+  override val path: String = {
+    val compressionExtension = TextOutputWriter.getCompressionExtension(context)
+    new Path(stagingDir, fileNamePrefix + ".csv" + compressionExtension).toString
+  }
+
   // create the Generator without separator inserted between 2 records
   private[this] val text = new Text()
 
@@ -199,7 +205,7 @@ private[csv] class CsvOutputWriter(
   private val recordWriter: RecordWriter[NullWritable, Text] = {
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        new Path(stagingDir, s"$fileNamePrefix.csv$extension")
+        new Path(path)
       }
     }.getRecordWriter(context)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index cdbb2f729261..651fa78a4e92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions}
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
@@ -160,6 +161,11 @@ private[json] class JsonOutputWriter(
     context: TaskAttemptContext)
   extends OutputWriter with Logging {
 
+  override val path: String = {
+    val compressionExtension = TextOutputWriter.getCompressionExtension(context)
+    new Path(stagingDir, fileNamePrefix + ".json" + compressionExtension).toString
+  }
+
   private[this] val writer = new CharArrayWriter()
   // create the Generator without separator inserted between 2 records
   private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
@@ -168,7 +174,7 @@ private[json] class JsonOutputWriter(
   private val recordWriter: RecordWriter[NullWritable, Text] = {
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        new Path(stagingDir, s"$fileNamePrefix.json$extension")
+        new Path(path)
       }
     }.getRecordWriter(context)
   }
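
(Illustration, not part of the patch: roughly what the CSV and JSON writers now name their files. The staging directory and prefix below are made-up example values; the prefix is supplied by the caller of `newInstance`.)

```scala
import org.apache.hadoop.fs.Path

// Assumed example inputs; with gzip output compression enabled the helper yields ".gz".
val stagingDir = "/tmp/job-staging"
val fileNamePrefix = "part-r-00000-066a2b39"
val compressionExtension = ".gz"

// CsvOutputWriter.path  -> /tmp/job-staging/part-r-00000-066a2b39.csv.gz
// JsonOutputWriter.path -> /tmp/job-staging/part-r-00000-066a2b39.json.gz
new Path(stagingDir, fileNamePrefix + ".csv" + compressionExtension).toString
```
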
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 87b944ba523c..502dd0e8d4cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -121,7 +121,7 @@ class ParquetFileFormat
       sparkSession.sessionState.conf.writeLegacyParquetFormat.toString)
 
     // Sets compression scheme
-    conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec)
+    conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName)
 
     // SPARK-15719: Disables writing Parquet summary files by default.
     if (conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
index 615731889dfa..d0fd23605bea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
@@ -35,7 +35,7 @@ private[parquet] class ParquetOptions(
    * Compression codec to use. By default use the value specified in SQLConf.
    * Acceptable values are defined in [[shortParquetCompressionCodecNames]].
    */
-  val compressionCodec: String = {
+  val compressionCodecClassName: String = {
     val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase
     if (!shortParquetCompressionCodecNames.contains(codecName)) {
       val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
index 39c199784cd6..1300069c42b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.parquet.hadoop.{ParquetOutputFormat, ParquetRecordWriter}
+import org.apache.parquet.hadoop.codec.CodecConfig
 import org.apache.parquet.hadoop.util.ContextUtil
 
 import org.apache.spark.sql.Row
@@ -80,7 +81,7 @@ private[parquet] class ParquetOutputWriterFactory(
       sqlConf.writeLegacyParquetFormat.toString)
 
     // Sets compression scheme
-    conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec)
+    conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName)
     new SerializableConfiguration(conf)
   }
 
@@ -88,7 +89,7 @@ private[parquet] class ParquetOutputWriterFactory(
   * Returns a [[OutputWriter]] that writes data to the give path without using
   * [[OutputCommitter]].
   */
-  override def newWriter(path: String): OutputWriter = new OutputWriter {
+  override def newWriter(path1: String): OutputWriter = new OutputWriter {
 
    // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter
    private val hadoopTaskAttemptId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0)
@@ -98,6 +99,8 @@ private[parquet] class ParquetOutputWriterFactory(
     // Instance of ParquetRecordWriter that does not use OutputCommitter
     private val recordWriter = createNoCommitterRecordWriter(path, hadoopAttemptContext)
 
+    override def path: String = path1
+
     override def write(row: Row): Unit = {
       throw new UnsupportedOperationException("call writeInternal")
     }
@@ -140,16 +143,17 @@ private[parquet] class ParquetOutputWriter(
     context: TaskAttemptContext)
   extends OutputWriter {
 
+  override val path: String = {
+    val filename = fileNamePrefix + CodecConfig.from(context).getCodec.getExtension + ".parquet"
+    new Path(stagingDir, filename).toString
+  }
+
   private val recordWriter: RecordWriter[Void, InternalRow] = {
-    val outputFormat = {
-      new ParquetOutputFormat[InternalRow]() {
-        override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-          new Path(stagingDir, fileNamePrefix + extension)
-        }
+    new ParquetOutputFormat[InternalRow]() {
+      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+        new Path(path)
       }
-    }
-
-    outputFormat.getRecordWriter(context)
+    }.getRecordWriter(context)
   }
 
   override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
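
(Illustration, not part of the patch: how the Parquet writer's file name comes together. The extension is looked up through parquet-hadoop's `CodecConfig`, so with snappy compression configured the result is `<prefix>.snappy.parquet`; the helper name below is hypothetical.)

```scala
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.parquet.hadoop.codec.CodecConfig

// Assuming the task configuration sets the Parquet compression to SNAPPY,
// CodecConfig.from(context).getCodec.getExtension returns ".snappy".
def parquetFileName(fileNamePrefix: String, context: TaskAttemptContext): String =
  fileNamePrefix + CodecConfig.from(context).getCodec.getExtension + ".parquet"

// e.g. parquetFileName("part-r-00000-066a2b39", context) == "part-r-00000-066a2b39.snappy.parquet"
```
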
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 6cd2351c5749..d40b5725199a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution.datasources.text
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{NullWritable, Text}
+import org.apache.hadoop.io.compress.GzipCodec
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
+import org.apache.hadoop.util.ReflectionUtils
 
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
@@ -128,12 +130,17 @@ class TextOutputWriter(
     context: TaskAttemptContext)
   extends OutputWriter {
 
+  override val path: String = {
+    val compressionExtension = TextOutputWriter.getCompressionExtension(context)
+    new Path(stagingDir, fileNamePrefix + ".txt" + compressionExtension).toString
+  }
+
   private[this] val buffer = new Text()
 
   private val recordWriter: RecordWriter[NullWritable, Text] = {
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        new Path(stagingDir, s"$fileNamePrefix.txt$extension")
+        new Path(path)
       }
     }.getRecordWriter(context)
   }
@@ -150,3 +157,17 @@ class TextOutputWriter(
     recordWriter.close(context)
   }
 }
+
+
+object TextOutputWriter {
+  /** Returns the compression codec extension to be used in a file name, e.g. ".gz". */
+  def getCompressionExtension(context: TaskAttemptContext): String = {
+    // Determine the compression extension, similar to TextOutputFormat.getDefaultWorkFile
+    if (FileOutputFormat.getCompressOutput(context)) {
+      val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec])
+      ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension
+    } else {
+      ""
+    }
+  }
+}
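
(Illustration, not part of the patch: exercising the new helper. With output compression enabled through the standard `FileOutputFormat` setters, the extension is the codec's default one, e.g. ".gz" for gzip; with compression off it is the empty string.)

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.hadoop.mapreduce.{Job, TaskAttemptID}
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

import org.apache.spark.sql.execution.datasources.text.TextOutputWriter

// Turn on gzip output compression the standard Hadoop way, then build a task context.
val job = Job.getInstance(new Configuration())
FileOutputFormat.setCompressOutput(job, true)
FileOutputFormat.setOutputCompressorClass(job, classOf[GzipCodec])

val context = new TaskAttemptContextImpl(job.getConfiguration, new TaskAttemptID())
TextOutputWriter.getCompressionExtension(context)  // ".gz"
// Without the two setters above the helper returns "", and the file keeps a bare ".txt".
```
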
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 1ceacb458ae6..eba7aa386ade 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -216,9 +216,18 @@ private[orc] class OrcOutputWriter(
     context: TaskAttemptContext)
   extends OutputWriter {
 
-  private[this] val conf = context.getConfiguration
+  override val path: String = {
+    val compressionExtension: String = {
+      val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION)
+      OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "")
+    }
+    // It has the `.orc` extension at the end because (de)compression tools
+    // such as gunzip would not be able to decompress this as the compression
+    // is not applied on this whole file but on each "stream" in ORC format.
+    new Path(stagingDir, fileNamePrefix + compressionExtension + ".orc").toString
+  }
 
-  private[this] val serializer = new OrcSerializer(dataSchema, conf)
+  private[this] val serializer = new OrcSerializer(dataSchema, context.getConfiguration)
 
   // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this
   // flag to decide whether `OrcRecordWriter.close()` needs to be called.
@@ -226,20 +235,10 @@ private[orc] class OrcOutputWriter(
 
   private lazy val recordWriter: RecordWriter[NullWritable, Writable] = {
     recordWriterInstantiated = true
-
-    val compressionExtension = {
-      val name = conf.get(OrcRelation.ORC_COMPRESSION)
-      OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "")
-    }
-    // It has the `.orc` extension at the end because (de)compression tools
-    // such as gunzip would not be able to decompress this as the compression
-    // is not applied on this whole file but on each "stream" in ORC format.
-    val filename = s"$fileNamePrefix$compressionExtension.orc"
-
     new OrcOutputFormat().getRecordWriter(
-      new Path(stagingDir, filename).getFileSystem(conf),
-      conf.asInstanceOf[JobConf],
-      new Path(stagingDir, filename).toString,
+      new Path(path).getFileSystem(context.getConfiguration),
+      context.getConfiguration.asInstanceOf[JobConf],
+      path,
       Reporter.NULL
     ).asInstanceOf[RecordWriter[NullWritable, Writable]]
   }
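
(Editorial note, not part of the patch: for ORC the codec tag goes before the ".orc" suffix, e.g. a zlib-compressed file is named `<prefix>.zlib.orc`, because compression is applied to the streams inside the file rather than to the file as a whole, so a whole-file suffix such as ".orc.gz" would wrongly suggest the file can be fed to a generic decompressor. The extension value below is an assumed lookup result from `OrcRelation.extensionsForCompressionCodecNames`.)

```scala
// Assumed example: with the ORC compression option set to ZLIB, the lookup yields ".zlib".
val compressionExtension = ".zlib"
val fileName = "part-r-00000-066a2b39" + compressionExtension + ".orc"
// fileName == "part-r-00000-066a2b39.zlib.orc"; with no/unknown codec it is just "<prefix>.orc".
```
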
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
index d5044684020e..731540db17ee 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.sources
 
+import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
 import org.apache.spark.TaskContext
@@ -50,6 +51,8 @@ class CommitFailureTestSource extends SimpleTextSource {
         SimpleTextRelation.callbackCalled = true
       }
 
+      override val path: String = new Path(stagingDir, fileNamePrefix).toString
+
       override def write(row: Row): Unit = {
         if (SimpleTextRelation.failWriter) {
           sys.error("Intentional task writer failure for testing purpose.")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 9e13b217ec30..9896b9bde99c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -123,6 +123,9 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
 class SimpleTextOutputWriter(
     stagingDir: String, fileNamePrefix: String, context: TaskAttemptContext)
   extends OutputWriter {
+
+  override val path: String = new Path(stagingDir, fileNamePrefix).toString
+
   private val recordWriter: RecordWriter[NullWritable, Text] =
     new AppendingTextOutputFormat(new Path(stagingDir), fileNamePrefix).getRecordWriter(context)
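
(Closing illustration, not part of the patch: one way to observe the new naming end to end from a test. Assumes an active `spark: SparkSession` and a scratch directory `dir: java.io.File`.)

```scala
import java.io.File

// Write gzip-compressed text output; each data file should now end in ".txt.gz"
// instead of a bare ".gz" with no format extension.
spark.range(10).selectExpr("cast(id as string) as value")
  .write.option("compression", "gzip").text(dir.getCanonicalPath)

val names = new File(dir.getCanonicalPath).listFiles().map(_.getName)
assert(names.exists(n => n.startsWith("part-") && n.endsWith(".txt.gz")))
```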