Skip to content

Commit cd5d93c

Browse files
MaxGekkHyukjinKwon
authored andcommitted
[SPARK-24854][SQL] Gathering all Avro options into the AvroOptions class
## What changes were proposed in this pull request? In the PR, I propose to put all `Avro` options in new class `AvroOptions` in the same way as for other datasources `JSON` and `CSV`. ## How was this patch tested? It was tested by `AvroSuite` Author: Maxim Gekk <[email protected]> Closes #21810 from MaxGekk/avro-options.
1 parent 753f115 commit cd5d93c

File tree

3 files changed

+58
-9
lines changed

3 files changed

+58
-9
lines changed

external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
5858
options: Map[String, String],
5959
files: Seq[FileStatus]): Option[StructType] = {
6060
val conf = spark.sparkContext.hadoopConfiguration
61+
val parsedOptions = new AvroOptions(options)
6162

6263
// Schema evolution is not supported yet. Here we only pick a single random sample file to
6364
// figure out the schema of the whole dataset.
@@ -76,7 +77,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
7677
}
7778

7879
// User can specify an optional avro json schema.
79-
val avroSchema = options.get(AvroFileFormat.AvroSchema)
80+
val avroSchema = parsedOptions.schema
8081
.map(new Schema.Parser().parse)
8182
.getOrElse {
8283
val in = new FsInput(sampleFile.getPath, conf)
@@ -114,10 +115,9 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
114115
job: Job,
115116
options: Map[String, String],
116117
dataSchema: StructType): OutputWriterFactory = {
117-
val recordName = options.getOrElse("recordName", "topLevelRecord")
118-
val recordNamespace = options.getOrElse("recordNamespace", "")
118+
val parsedOptions = new AvroOptions(options)
119119
val outputAvroSchema = SchemaConverters.toAvroType(
120-
dataSchema, nullable = false, recordName, recordNamespace)
120+
dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace)
121121

122122
AvroJob.setOutputKeySchema(job, outputAvroSchema)
123123
val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
@@ -160,11 +160,12 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
160160

161161
val broadcastedConf =
162162
spark.sparkContext.broadcast(new AvroFileFormat.SerializableConfiguration(hadoopConf))
163+
val parsedOptions = new AvroOptions(options)
163164

164165
(file: PartitionedFile) => {
165166
val log = LoggerFactory.getLogger(classOf[AvroFileFormat])
166167
val conf = broadcastedConf.value.value
167-
val userProvidedSchema = options.get(AvroFileFormat.AvroSchema).map(new Schema.Parser().parse)
168+
val userProvidedSchema = parsedOptions.schema.map(new Schema.Parser().parse)
168169

169170
// TODO Removes this check once `FileFormat` gets a general file filtering interface method.
170171
// Doing input file filtering is improper because we may generate empty tasks that process no
@@ -235,8 +236,6 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
235236
private[avro] object AvroFileFormat {
236237
val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension"
237238

238-
val AvroSchema = "avroSchema"
239-
240239
class SerializableConfiguration(@transient var value: Configuration)
241240
extends Serializable with KryoSerializable {
242241
@transient private[avro] lazy val log = LoggerFactory.getLogger(getClass)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.avro
19+
20+
import org.apache.spark.internal.Logging
21+
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
22+
23+
/**
24+
* Options for Avro Reader and Writer stored in case insensitive manner.
25+
*/
26+
class AvroOptions(@transient val parameters: CaseInsensitiveMap[String])
27+
extends Logging with Serializable {
28+
29+
def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
30+
31+
/**
32+
* Optional schema provided by an user in JSON format.
33+
*/
34+
val schema: Option[String] = parameters.get("avroSchema")
35+
36+
/**
37+
* Top level record name in write result, which is required in Avro spec.
38+
* See https://avro.apache.org/docs/1.8.2/spec.html#schema_record .
39+
* Default value is "topLevelRecord"
40+
*/
41+
val recordName: String = parameters.getOrElse("recordName", "topLevelRecord")
42+
43+
/**
44+
* Record namespace in write result. Default value is "".
45+
* See Avro spec for details: https://avro.apache.org/docs/1.8.2/spec.html#schema_record .
46+
*/
47+
val recordNamespace: String = parameters.getOrElse("recordNamespace", "")
48+
}

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
578578
""".stripMargin
579579
val result = spark
580580
.read
581-
.option(AvroFileFormat.AvroSchema, avroSchema)
581+
.option("avroSchema", avroSchema)
582582
.avro(testAvro)
583583
.collect()
584584
val expected = spark.read.avro(testAvro).select("string").collect()
@@ -598,7 +598,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
598598
| }]
599599
|}
600600
""".stripMargin
601-
val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema)
601+
val result = spark
602+
.read
603+
.option("avroSchema", avroSchema)
602604
.avro(testAvro).select("missingField").first
603605
assert(result === Row("foo"))
604606
}

0 commit comments

Comments
 (0)