
Commit 54268b4

HyukjinKwon authored and cloud-fan committed
[SPARK-16101][SQL] Refactoring CSV write path to be consistent with JSON data source
## What changes were proposed in this pull request?

This PR refactors the CSV write path to be consistent with the JSON data source, so that the corresponding classes take the same arguments and expose the same methods.

- `UnivocityGenerator` and `JacksonGenerator`

``` scala
private[csv] class UnivocityGenerator(
    schema: StructType,
    writer: Writer,
    options: CSVOptions = new CSVOptions(Map.empty[String, String])) {
  ...
  def write ...
  def close ...
  def flush ...
```

``` scala
private[sql] class JacksonGenerator(
    schema: StructType,
    writer: Writer,
    options: JSONOptions = new JSONOptions(Map.empty[String, String])) {
  ...
  def write ...
  def close ...
  def flush ...
```

- This PR also groups the classes together in the same way as the JSON data source.

- `CsvFileFormat`

``` scala
CsvFileFormat
CsvOutputWriter
```

- `JsonFileFormat`

``` scala
JsonFileFormat
JsonOutputWriter
```

## How was this patch tested?

Existing tests should cover this.

Author: hyukjinkwon <[email protected]>

Closes #16496 from HyukjinKwon/SPARK-16101-write.
1 parent ea31f92 commit 54268b4
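
For context, a minimal user-level sketch (the object name, sample data, and output paths are illustrative, not part of this commit) showing that the refactoring only touches the internal write path; the public DataFrame write API for the two sources stays unchanged and symmetric:

```scala
import org.apache.spark.sql.SparkSession

object WritePathDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("write-path-demo").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("id", "name")

    // After this change both sources go through the same generator/output-writer
    // structure internally, while the user-facing API stays identical.
    df.write.mode("overwrite").option("header", "true").csv("/tmp/demo-csv")
    df.write.mode("overwrite").json("/tmp/demo-json")

    spark.stop()
  }
}
```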

5 files changed (+135, -115 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala

Lines changed: 30 additions & 3 deletions
```diff
@@ -20,14 +20,15 @@ package org.apache.spark.sql.execution.datasources.csv
 import java.nio.charset.{Charset, StandardCharsets}
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce._
 
 import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.{Dataset, Encoders, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
@@ -130,7 +131,18 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       CompressionCodecs.setCodecConfiguration(conf, codec)
     }
 
-    new CSVOutputWriterFactory(csvOptions)
+    new OutputWriterFactory {
+      override def newInstance(
+          path: String,
+          dataSchema: StructType,
+          context: TaskAttemptContext): OutputWriter = {
+        new CsvOutputWriter(path, dataSchema, context, csvOptions)
+      }
+
+      override def getFileExtension(context: TaskAttemptContext): String = {
+        ".csv" + CodecStreams.getCompressionExtension(context)
+      }
+    }
   }
 
   override def buildReader(
@@ -228,3 +240,18 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     schema.foreach(field => verifyType(field.dataType))
   }
 }
+
+private[csv] class CsvOutputWriter(
+    path: String,
+    dataSchema: StructType,
+    context: TaskAttemptContext,
+    params: CSVOptions) extends OutputWriter with Logging {
+
+  private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
+
+  private val gen = new UnivocityGenerator(dataSchema, writer, params)
+
+  override def write(row: InternalRow): Unit = gen.write(row)
+
+  override def close(): Unit = gen.close()
+}
```
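
One practical note on `getFileExtension` above: when a compression codec is configured, `CodecStreams.getCompressionExtension` appends the codec suffix after `.csv`, so gzip output ends up in part files named like `part-*.csv.gz`. A small illustrative sketch (the object name and output path are assumptions, not part of this commit):

```scala
import org.apache.spark.sql.SparkSession

object CsvCompressionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("csv-extension").getOrCreate()
    import spark.implicits._

    // The part files written here carry both the ".csv" extension and the
    // compression suffix (e.g. ".csv.gz" for gzip).
    Seq((1, "a"), (2, "b")).toDF("id", "name")
      .write
      .mode("overwrite")
      .option("compression", "gzip")
      .csv("/tmp/csv-gzip-sketch")

    spark.stop()
  }
}
```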

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala

Lines changed: 16 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv
 import java.nio.charset.StandardCharsets
 import java.util.Locale
 
+import com.univocity.parsers.csv.CsvWriterSettings
 import org.apache.commons.lang3.time.FastDateFormat
 
 import org.apache.spark.internal.Logging
@@ -126,6 +127,21 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive
   val inputBufferSize = 128
 
   val isCommentSet = this.comment != '\u0000'
+
+  def asWriterSettings: CsvWriterSettings = {
+    val writerSettings = new CsvWriterSettings()
+    val format = writerSettings.getFormat
+    format.setDelimiter(delimiter)
+    format.setQuote(quote)
+    format.setQuoteEscape(escape)
+    format.setComment(comment)
+    writerSettings.setNullValue(nullValue)
+    writerSettings.setEmptyValue(nullValue)
+    writerSettings.setSkipEmptyLines(true)
+    writerSettings.setQuoteAllFields(quoteAll)
+    writerSettings.setQuoteEscapingEnabled(escapeQuotes)
+    writerSettings
+  }
 }
 
 object CSVOptions {
```
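
For readers unfamiliar with the Univocity API, here is a standalone sketch of the kind of writer that `asWriterSettings` configures; the delimiter, quote, escape, and null-value literals below are illustrative stand-ins for the corresponding `CSVOptions` fields, not values taken from this commit:

```scala
import java.io.StringWriter

import com.univocity.parsers.csv.{CsvWriter, CsvWriterSettings}

object WriterSettingsSketch {
  def main(args: Array[String]): Unit = {
    // Mirrors the shape of CSVOptions.asWriterSettings with hard-coded example values.
    val settings = new CsvWriterSettings()
    val format = settings.getFormat
    format.setDelimiter(',')
    format.setQuote('"')
    format.setQuoteEscape('\\')
    settings.setNullValue("")
    settings.setEmptyValue("")
    settings.setSkipEmptyLines(true)
    settings.setQuoteAllFields(false)
    settings.setQuoteEscapingEnabled(false)

    // Drive a CsvWriter with those settings, similar to what UnivocityGenerator does.
    val out = new StringWriter()
    val writer = new CsvWriter(out, settings)
    writer.writeHeaders("id", "name")
    writer.writeRow("1", "a")
    writer.writeRow("2", "b")
    writer.close()

    print(out) // id,name / 1,a / 2,b
  }
}
```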

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala

Lines changed: 0 additions & 40 deletions
```diff
@@ -58,43 +58,3 @@ private[csv] class CsvReader(params: CSVOptions) {
    */
   def parseLine(line: String): Array[String] = parser.parseLine(line)
 }
-
-/**
- * Converts a sequence of string to CSV string
- *
- * @param params Parameters object for configuration
- * @param headers headers for columns
- */
-private[csv] class LineCsvWriter(
-    params: CSVOptions,
-    headers: Seq[String],
-    output: OutputStream) extends Logging {
-  private val writerSettings = new CsvWriterSettings
-  private val format = writerSettings.getFormat
-
-  format.setDelimiter(params.delimiter)
-  format.setQuote(params.quote)
-  format.setQuoteEscape(params.escape)
-  format.setComment(params.comment)
-
-  writerSettings.setNullValue(params.nullValue)
-  writerSettings.setEmptyValue(params.nullValue)
-  writerSettings.setSkipEmptyLines(true)
-  writerSettings.setQuoteAllFields(params.quoteAll)
-  writerSettings.setHeaders(headers: _*)
-  writerSettings.setQuoteEscapingEnabled(params.escapeQuotes)
-
-  private val writer = new CsvWriter(output, StandardCharsets.UTF_8, writerSettings)
-
-  def writeRow(row: Seq[String], includeHeader: Boolean): Unit = {
-    if (includeHeader) {
-      writer.writeHeaders()
-    }
-
-    writer.writeRow(row: _*)
-  }
-
-  def close(): Unit = {
-    writer.close()
-  }
-}
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala

Lines changed: 0 additions & 72 deletions
```diff
@@ -159,75 +159,3 @@ object CSVRelation extends Logging {
     }
   }
 }
-
-private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
-  override def newInstance(
-      path: String,
-      dataSchema: StructType,
-      context: TaskAttemptContext): OutputWriter = {
-    new CsvOutputWriter(path, dataSchema, context, params)
-  }
-
-  override def getFileExtension(context: TaskAttemptContext): String = {
-    ".csv" + CodecStreams.getCompressionExtension(context)
-  }
-}
-
-private[csv] class CsvOutputWriter(
-    path: String,
-    dataSchema: StructType,
-    context: TaskAttemptContext,
-    params: CSVOptions) extends OutputWriter with Logging {
-
-  // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
-  // When the value is null, this converter should not be called.
-  private type ValueConverter = (InternalRow, Int) => String
-
-  // `ValueConverter`s for all values in the fields of the schema
-  private val valueConverters: Array[ValueConverter] =
-    dataSchema.map(_.dataType).map(makeConverter).toArray
-
-  private var printHeader: Boolean = params.headerFlag
-  private val writer = CodecStreams.createOutputStream(context, new Path(path))
-  private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq, writer)
-
-  private def rowToString(row: InternalRow): Seq[String] = {
-    var i = 0
-    val values = new Array[String](row.numFields)
-    while (i < row.numFields) {
-      if (!row.isNullAt(i)) {
-        values(i) = valueConverters(i).apply(row, i)
-      } else {
-        values(i) = params.nullValue
-      }
-      i += 1
-    }
-    values
-  }
-
-  private def makeConverter(dataType: DataType): ValueConverter = dataType match {
-    case DateType =>
-      (row: InternalRow, ordinal: Int) =>
-        params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
-
-    case TimestampType =>
-      (row: InternalRow, ordinal: Int) =>
-        params.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
-
-    case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
-
-    case dt: DataType =>
-      (row: InternalRow, ordinal: Int) =>
-        row.get(ordinal, dt).toString
-  }
-
-  override def write(row: InternalRow): Unit = {
-    csvWriter.writeRow(rowToString(row), printHeader)
-    printHeader = false
-  }
-
-  override def close(): Unit = {
-    csvWriter.close()
-    writer.close()
-  }
-}
```
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala (new file)

Lines changed: 89 additions & 0 deletions
```diff
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.io.Writer
+
+import com.univocity.parsers.csv.CsvWriter
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+private[csv] class UnivocityGenerator(
+    schema: StructType,
+    writer: Writer,
+    options: CSVOptions = new CSVOptions(Map.empty[String, String])) {
+  private val writerSettings = options.asWriterSettings
+  writerSettings.setHeaders(schema.fieldNames: _*)
+  private val gen = new CsvWriter(writer, writerSettings)
+  private var printHeader = options.headerFlag
+
+  // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
+  // When the value is null, this converter should not be called.
+  private type ValueConverter = (InternalRow, Int) => String
+
+  // `ValueConverter`s for all values in the fields of the schema
+  private val valueConverters: Array[ValueConverter] =
+    schema.map(_.dataType).map(makeConverter).toArray
+
+  private def makeConverter(dataType: DataType): ValueConverter = dataType match {
+    case DateType =>
+      (row: InternalRow, ordinal: Int) =>
+        options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
+
+    case TimestampType =>
+      (row: InternalRow, ordinal: Int) =>
+        options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
+
+    case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
+
+    case dt: DataType =>
+      (row: InternalRow, ordinal: Int) =>
+        row.get(ordinal, dt).toString
+  }
+
+  private def convertRow(row: InternalRow): Seq[String] = {
+    var i = 0
+    val values = new Array[String](row.numFields)
+    while (i < row.numFields) {
+      if (!row.isNullAt(i)) {
+        values(i) = valueConverters(i).apply(row, i)
+      } else {
+        values(i) = options.nullValue
+      }
+      i += 1
+    }
+    values
+  }
+
+  /**
+   * Writes a single InternalRow to CSV using Univocity.
+   */
+  def write(row: InternalRow): Unit = {
+    if (printHeader) {
+      gen.writeHeaders()
+    }
+    gen.writeRow(convertRow(row): _*)
+    printHeader = false
+  }
+
+  def close(): Unit = gen.close()
+
+  def flush(): Unit = gen.flush()
+}
```
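
To make the new generator's contract concrete, here is a hypothetical driver (not part of this commit); since `UnivocityGenerator` and `CSVOptions` are `private[csv]`, the sketch assumes it compiles inside the same package, e.g. as a test:

```scala
package org.apache.spark.sql.execution.datasources.csv

import java.io.StringWriter

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

object UnivocityGeneratorSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))
    val out = new StringWriter()
    val gen = new UnivocityGenerator(schema, out, new CSVOptions(Map("header" -> "true")))

    // The header is emitted lazily before the first row because headerFlag is set;
    // each write() converts one InternalRow into a single CSV record.
    gen.write(InternalRow(1, UTF8String.fromString("a")))
    gen.write(InternalRow(2, UTF8String.fromString("b")))
    gen.flush()
    gen.close()

    print(out) // id,name then 1,a then 2,b, one record per line
  }
}
```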
