diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index c6b3c13be514..1df1c8b4af2e 100755
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -117,11 +117,9 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
       dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace)
 
     AvroJob.setOutputKeySchema(job, outputAvroSchema)
-    val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
-    val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
     val COMPRESS_KEY = "mapred.output.compress"
 
-    spark.conf.get(AVRO_COMPRESSION_CODEC, "snappy") match {
+    parsedOptions.compression match {
       case "uncompressed" =>
         log.info("writing uncompressed Avro records")
         job.getConfiguration.setBoolean(COMPRESS_KEY, false)
@@ -132,8 +130,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
         job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.SNAPPY_CODEC)
 
       case "deflate" =>
-        val deflateLevel = spark.conf.get(
-          AVRO_DEFLATE_LEVEL, Deflater.DEFAULT_COMPRESSION.toString).toInt
+        val deflateLevel = spark.sessionState.conf.avroDeflateLevel
         log.info(s"compressing Avro output using deflate (level=$deflateLevel)")
         job.getConfiguration.setBoolean(COMPRESS_KEY, true)
         job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.DEFLATE_CODEC)
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
index cd9a911a14bf..0f59007e7f72 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Options for Avro Reader and Writer stored in case insensitive manner.
@@ -68,4 +69,14 @@ class AvroOptions(
       .map(_.toBoolean)
       .getOrElse(!ignoreFilesWithoutExtension)
   }
+
+  /**
+   * The `compression` option allows to specify a compression codec used in write.
+   * Currently supported codecs are `uncompressed`, `snappy` and `deflate`.
+   * If the option is not set, the `spark.sql.avro.compression.codec` config is taken into
+   * account. If the former one is not set too, the `snappy` codec is used by default.
+   */
+  val compression: String = {
+    parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec)
+  }
 }
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 865a14509485..cbf33ea96807 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -27,13 +27,14 @@ import scala.collection.JavaConverters._
 
 import org.apache.avro.Schema
 import org.apache.avro.Schema.{Field, Type}
-import org.apache.avro.file.DataFileWriter
-import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
+import org.apache.avro.file.{DataFileReader, DataFileWriter}
+import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.commons.io.FileUtils
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._
@@ -364,21 +365,19 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
-  test("write with compression") {
+  test("write with compression - sql configs") {
     withTempPath { dir =>
-      val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
-      val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
       val uncompressDir = s"$dir/uncompress"
       val deflateDir = s"$dir/deflate"
       val snappyDir = s"$dir/snappy"
 
       val df = spark.read.format("avro").load(testAvro)
-      spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed")
+      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "uncompressed")
       df.write.format("avro").save(uncompressDir)
-      spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate")
-      spark.conf.set(AVRO_DEFLATE_LEVEL, "9")
+      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "deflate")
+      spark.conf.set(SQLConf.AVRO_DEFLATE_LEVEL.key, "9")
       df.write.format("avro").save(deflateDir)
-      spark.conf.set(AVRO_COMPRESSION_CODEC, "snappy")
+      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "snappy")
       df.write.format("avro").save(snappyDir)
 
       val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir))
@@ -904,4 +903,31 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       assert(count == 8)
     }
   }
+
+  test("SPARK-24881: write with compression - avro options") {
+    def getCodec(dir: String): Option[String] = {
+      val files = new File(dir)
+        .listFiles()
+        .filter(_.isFile)
+        .filter(_.getName.endsWith("avro"))
+      files.map { file =>
+        val reader = new DataFileReader(file, new GenericDatumReader[Any]())
+        val r = reader.getMetaString("avro.codec")
+        r
+      }.map(v => if (v == "null") "uncompressed" else v).headOption
+    }
+    def checkCodec(df: DataFrame, dir: String, codec: String): Unit = {
+      val subdir = s"$dir/$codec"
+      df.write.option("compression", codec).format("avro").save(subdir)
+      assert(getCodec(subdir) == Some(codec))
+    }
+    withTempPath { dir =>
+      val path = dir.toString
+      val df = spark.read.format("avro").load(testAvro)
+
+      checkCodec(df, path, "uncompressed")
+      checkCodec(df, path, "deflate")
+      checkCodec(df, path, "snappy")
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 53423e03b6b2..a269e218c4ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.internal
 import java.util.{Locale, NoSuchElementException, Properties, TimeZone}
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicReference
+import java.util.zip.Deflater
 
 import scala.collection.JavaConverters._
 import scala.collection.immutable
@@ -1434,6 +1435,20 @@ object SQLConf {
       "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.")
     .intConf
     .createWithDefault(20)
+
+  val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
+    .doc("Compression codec used in writing of AVRO files. Default codec is snappy.")
+    .stringConf
+    .checkValues(Set("uncompressed", "deflate", "snappy"))
+    .createWithDefault("snappy")
+
+  val AVRO_DEFLATE_LEVEL = buildConf("spark.sql.avro.deflate.level")
+    .doc("Compression level for the deflate codec used in writing of AVRO files. " +
+      "Valid value must be in the range of from 1 to 9 inclusive or -1. " +
+      "The default value is -1 which corresponds to 6 level in the current implementation.")
+    .intConf
+    .checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION)
+    .createWithDefault(Deflater.DEFAULT_COMPRESSION)
 }
 
 /**
@@ -1820,6 +1835,10 @@ class SQLConf extends Serializable with Logging {
 
   def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE)
 
+  def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC)
+
+  def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
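Illustrative usage of the new `compression` write option and the SQL configs added above (a minimal sketch, not part of the patch; the SparkSession setup and input/output paths are placeholders):

  // Hypothetical example: session name and paths are placeholders.
  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().appName("avro-compression-example").getOrCreate()
  val df = spark.read.format("avro").load("/path/to/input")

  // Per-write option; it takes precedence because AvroOptions falls back to
  // spark.sql.avro.compression.codec only when the option is absent.
  df.write.option("compression", "snappy").format("avro").save("/path/to/out-snappy")

  // Session-wide defaults via the new SQL configs.
  spark.conf.set("spark.sql.avro.compression.codec", "deflate")
  spark.conf.set("spark.sql.avro.deflate.level", "9")
  df.write.format("avro").save("/path/to/out-deflate")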