@@ -20,19 +20,18 @@ package org.apache.spark.sql.execution.datasources.csv
 import java.nio.charset.{Charset, StandardCharsets}
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce._
 
 import org.apache.spark.TaskContext
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Dataset, Encoders, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
-import org.apache.spark.sql.functions.{length, trim}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -59,63 +58,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     val csvOptions = new CSVOptions(options)
     val paths = files.map(_.getPath.toString)
     val lines: Dataset[String] = readText(sparkSession, csvOptions, paths)
-    val firstLine: String = findFirstLine(csvOptions, lines)
-    val firstRow = new CsvReader(csvOptions).parseLine(firstLine)
     val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-    val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)
-
-    val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer(
-      lines,
-      firstLine = if (csvOptions.headerFlag) firstLine else null,
-      params = csvOptions)
-    val schema = if (csvOptions.inferSchemaFlag) {
-      CSVInferSchema.infer(parsedRdd, header, csvOptions)
-    } else {
-      // By default fields are assumed to be StringType
-      val schemaFields = header.map { fieldName =>
-        StructField(fieldName, StringType, nullable = true)
-      }
-      StructType(schemaFields)
-    }
-    Some(schema)
-  }
-
-  /**
-   * Generates a header from the given row which is null-safe and duplicate-safe.
-   */
-  private def makeSafeHeader(
-      row: Array[String],
-      options: CSVOptions,
-      caseSensitive: Boolean): Array[String] = {
-    if (options.headerFlag) {
-      val duplicates = {
-        val headerNames = row.filter(_ != null)
-          .map(name => if (caseSensitive) name else name.toLowerCase)
-        headerNames.diff(headerNames.distinct).distinct
-      }
-
-      row.zipWithIndex.map { case (value, index) =>
-        if (value == null || value.isEmpty || value == options.nullValue) {
-          // When there are empty strings or the values set in `nullValue`, put the
-          // index as the suffix.
-          s"_c$index"
-        } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
-          // When there are case-insensitive duplicates, put the index as the suffix.
-          s"$value$index"
-        } else if (duplicates.contains(value)) {
-          // When there are duplicates, put the index as the suffix.
-          s"$value$index"
-        } else {
-          value
-        }
-      }
-    } else {
-      row.zipWithIndex.map { case (_, index) =>
-        // Uses default column names, "_c#" where # is its position of fields
-        // when header option is disabled.
-        s"_c$index"
-      }
-    }
-  }
+    Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions))
   }
 
   override def prepareWrite(
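Note: the single CSVInferSchema.infer call above now owns header handling and type inference end-to-end, replacing the removed findFirstLine/makeSafeHeader/univocityTokenizer steps. For reference, a minimal usage sketch of the path this serves, via the public reader API (the session setup and file path are illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("csv-infer").getOrCreate()
    // header=true: the first row supplies column names; inferSchema=true: types are
    // inferred by scanning the data instead of defaulting every field to StringType.
    val df = spark.read
      .option("header", "true")
      .option("inferSchema", "true")
      .csv("/tmp/people.csv")  // illustrative path
    df.printSchema()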
@@ -142,14 +86,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
     val csvOptions = new CSVOptions(options)
-    val commentPrefix = csvOptions.comment.toString
-    val headers = requiredSchema.fields.map(_.name)
 
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
 
     (file: PartitionedFile) => {
-      val lineIterator = {
+      val lines = {
         val conf = broadcastedHadoopConf.value.value
         val linesReader = new HadoopFileLinesReader(file, conf)
         Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
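The unchanged context collapsed between this hunk and the next decodes each raw Hadoop Text record into a String; roughly (surrounding context, not part of this diff):

        linesReader.map { line =>
          // decode the record's bytes with the user-configured charset
          new String(line.getBytes, 0, line.getLength, csvOptions.charset)
        }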
@@ -158,36 +99,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
         }
       }
 
-      CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
-
-      val csvParser = new CsvReader(csvOptions)
-      val tokenizedIterator = lineIterator.filter { line =>
-        line.trim.nonEmpty && !line.startsWith(commentPrefix)
-      }.map { line =>
-        csvParser.parseLine(line)
-      }
-      val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
-      var numMalformedRecords = 0
-      tokenizedIterator.flatMap { recordTokens =>
-        val row = parser(recordTokens, numMalformedRecords)
-        if (row.isEmpty) {
-          numMalformedRecords += 1
-        }
-        row
+      val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) {
+        UnivocityParser.dropHeaderLine(lines, csvOptions)
+      } else {
+        lines
       }
-    }
-  }
 
-  /**
-   * Returns the first line of the first non-empty file in path
-   */
-  private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = {
-    import lines.sqlContext.implicits._
-    val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
-    if (options.isCommentSet) {
-      nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first()
-    } else {
-      nonEmptyLines.first()
+      val linesFiltered = UnivocityParser.filterCommentAndEmpty(linesWithoutHeader, csvOptions)
+      val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions)
+      linesFiltered.flatMap(parser.parse)
     }
   }
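The key behavioral point in the new read path: a header row can only live in the partition that contains the start of the file (file.start == 0), so only that partition drops it. A minimal standalone sketch of the same idea (hypothetical helper, not the Spark code):

    // Simplified per-partition line handling: drop the header only in the
    // first split, then strip comment lines and blank lines.
    def prepareLines(
        lines: Iterator[String],
        splitStart: Long,       // byte offset of this split within the file
        headerFlag: Boolean,
        comment: Char): Iterator[String] = {
      val withoutHeader =
        if (headerFlag && splitStart == 0) lines.drop(1) else lines
      withoutHeader.filter { line =>
        line.trim.nonEmpty && !line.startsWith(comment.toString)
      }
    }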

@@ -228,3 +148,35 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     schema.foreach(field => verifyType(field.dataType))
   }
 }

@HyukjinKwon (Member, Author) commented on Jan 5, 2017:
These below just came from CSVRelation.

+private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
+  override def newInstance(
+      path: String,
+      dataSchema: StructType,
+      context: TaskAttemptContext): OutputWriter = {
+    new CsvOutputWriter(path, dataSchema, context, params)
+  }
+
+  override def getFileExtension(context: TaskAttemptContext): String = {
+    ".csv" + CodecStreams.getCompressionExtension(context)
+  }
+}
+
+private[csv] class CsvOutputWriter(
+    path: String,
+    dataSchema: StructType,
+    context: TaskAttemptContext,
+    params: CSVOptions) extends OutputWriter with Logging {
+  private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
+  private val gen = new UnivocityGenerator(dataSchema, writer, params)
+  private var printHeader = params.headerFlag
+
+  override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
+
+  override protected[sql] def writeInternal(row: InternalRow): Unit = {
+    gen.write(row, printHeader)
+    printHeader = false
+  }
+
+  override def close(): Unit = gen.close()
+}
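From the user side, these moved writer classes are exercised through the public DataFrame writer; a small sketch (output path illustrative, assuming a df: org.apache.spark.sql.DataFrame in scope):

    // header=true makes each part file begin with a header row (printHeader above);
    // compression adds the codec suffix returned by getFileExtension (e.g. ".csv.gz").
    df.write
      .option("header", "true")
      .option("compression", "gzip")
      .mode("overwrite")
      .csv("/tmp/csv-out")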