diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index b0feaeb84e9f..03faf1ba29fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -20,19 +20,18 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.{Charset, StandardCharsets} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ import org.apache.spark.TaskContext -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Encoders, SparkSession} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Dataset, Encoders, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat -import org.apache.spark.sql.functions.{length, trim} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -59,63 +58,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { val csvOptions = new CSVOptions(options) val paths = files.map(_.getPath.toString) val lines: Dataset[String] = readText(sparkSession, csvOptions, paths) - val firstLine: String = findFirstLine(csvOptions, lines) - val firstRow = new CsvReader(csvOptions).parseLine(firstLine) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, csvOptions, caseSensitive) - - val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer( - lines, - firstLine = if (csvOptions.headerFlag) firstLine else null, - params = csvOptions) - val schema = if (csvOptions.inferSchemaFlag) { - CSVInferSchema.infer(parsedRdd, header, csvOptions) - } else { - // By default fields are assumed to be StringType - val schemaFields = header.map { fieldName => - StructField(fieldName, StringType, nullable = true) - } - StructType(schemaFields) - } - Some(schema) - } - - /** - * Generates a header from the given row which is null-safe and duplicate-safe. - */ - private def makeSafeHeader( - row: Array[String], - options: CSVOptions, - caseSensitive: Boolean): Array[String] = { - if (options.headerFlag) { - val duplicates = { - val headerNames = row.filter(_ != null) - .map(name => if (caseSensitive) name else name.toLowerCase) - headerNames.diff(headerNames.distinct).distinct - } - - row.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == options.nullValue) { - // When there are empty strings or the values set in `nullValue`, put the - // index as the suffix. - s"_c$index" - } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { - // When there are case-insensitive duplicates, put the index as the suffix. - s"$value$index" - } else if (duplicates.contains(value)) { - // When there are duplicates, put the index as the suffix. 
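For context, the duplicate-handling rule removed here (and re-added in CSVInferSchema further down in this patch) can be summarized by the following standalone sketch; it is an illustration only and skips the nullValue check of the real method:

  // Null or empty names become "_c<index>"; duplicated names get the column index appended.
  def dedupHeader(row: Array[String], caseSensitive: Boolean): Array[String] = {
    val names = row.filter(_ != null).map(n => if (caseSensitive) n else n.toLowerCase)
    val duplicates = names.diff(names.distinct).distinct
    row.zipWithIndex.map { case (value, index) =>
      if (value == null || value.isEmpty) s"_c$index"
      else if (duplicates.contains(if (caseSensitive) value else value.toLowerCase)) s"$value$index"
      else value
    }
  }
  // dedupHeader(Array("a", "A", null, "a"), caseSensitive = false)  // Array("a0", "A1", "_c2", "a3")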
- s"$value$index" - } else { - value - } - } - } else { - row.zipWithIndex.map { case (_, index) => - // Uses default column names, "_c#" where # is its position of fields - // when header option is disabled. - s"_c$index" - } - } + Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions)) } override def prepareWrite( @@ -142,14 +86,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { val csvOptions = new CSVOptions(options) - val commentPrefix = csvOptions.comment.toString - val headers = requiredSchema.fields.map(_.name) - val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) (file: PartitionedFile) => { - val lineIterator = { + val lines = { val conf = broadcastedHadoopConf.value.value val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) @@ -158,36 +99,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } } - CSVRelation.dropHeaderLine(file, lineIterator, csvOptions) - - val csvParser = new CsvReader(csvOptions) - val tokenizedIterator = lineIterator.filter { line => - line.trim.nonEmpty && !line.startsWith(commentPrefix) - }.map { line => - csvParser.parseLine(line) - } - val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions) - var numMalformedRecords = 0 - tokenizedIterator.flatMap { recordTokens => - val row = parser(recordTokens, numMalformedRecords) - if (row.isEmpty) { - numMalformedRecords += 1 - } - row + val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) { + UnivocityParser.dropHeaderLine(lines, csvOptions) + } else { + lines } - } - } - /** - * Returns the first line of the first non-empty file in path - */ - private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = { - import lines.sqlContext.implicits._ - val nonEmptyLines = lines.filter(length(trim($"value")) > 0) - if (options.isCommentSet) { - nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first() - } else { - nonEmptyLines.first() + val linesFiltered = UnivocityParser.filterCommentAndEmpty(linesWithoutHeader, csvOptions) + val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions) + linesFiltered.flatMap(parser.parse) } } @@ -228,3 +148,35 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { schema.foreach(field => verifyType(field.dataType)) } } + +private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new CsvOutputWriter(path, dataSchema, context, params) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".csv" + CodecStreams.getCompressionExtension(context) + } +} + +private[csv] class CsvOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext, + params: CSVOptions) extends OutputWriter with Logging { + private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val gen = new UnivocityGenerator(dataSchema, writer, params) + private var printHeader = params.headerFlag + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit 
= { + gen.write(row, printHeader) + printHeader = false + } + + override def close(): Unit = gen.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index adc92fe5a31e..1d9a27a100e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -18,17 +18,16 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.text.NumberFormat -import java.util.Locale import scala.util.control.Exception._ -import scala.util.Try -import org.apache.spark.rdd.RDD +import com.univocity.parsers.csv.CsvParser + +import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.csv.UnivocityParser._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String private[csv] object CSVInferSchema { @@ -39,22 +38,43 @@ private[csv] object CSVInferSchema { * 3. Replace any null types with string type */ def infer( - tokenRdd: RDD[Array[String]], - header: Array[String], + csv: Dataset[String], + caseSensitive: Boolean, options: CSVOptions): StructType = { - val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) - val rootTypes: Array[DataType] = - tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes) - - val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => - val dType = rootType match { - case _: NullType => StringType - case other => other + import csv.sqlContext.implicits._ + + val filtered = csv.mapPartitions(filterCommentAndEmpty(_, options)) + val firstLine: String = filtered.first() + val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine) + val header = makeSafeHeader(firstRow, caseSensitive, options) + + val fields = if (options.inferSchemaFlag) { + val tokenRdd = filtered.mapPartitions { iter => + val parser = new CsvParser(options.asParserSettings) + if (options.headerFlag) { + iter.filterNot(_ == firstLine).map(parser.parseLine) + } else { + iter.map(parser.parseLine) + } + }.rdd + + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val rootTypes: Array[DataType] = + tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes) + + header.zip(rootTypes).map { case (thisHeader, rootType) => + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) } - StructField(thisHeader, dType, nullable = true) + } else { + // By default fields are assumed to be StringType + header.map(fieldName => StructField(fieldName, StringType, nullable = true)) } - StructType(structFields) + StructType(fields) } private def inferRowType(options: CSVOptions) @@ -214,127 +234,47 @@ private[csv] object CSVInferSchema { case _ => None } -} - -private[csv] object CSVTypeCast { - // A `ValueConverter` is responsible for converting the given value to a desired type. - private type ValueConverter = String => Any /** - * Create converters which cast each given string datum to each specified type in given schema. - * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`). 
- * - * For string types, this is simply the datum. - * For other types, this is converted into the value according to the type. - * For other nullable types, returns null if it is null or equals to the value specified - * in `nullValue` option. - * - * @param schema schema that contains data types to cast the given value into. - * @param options CSV options. + * Generates a header from the given row which is null-safe and duplicate-safe. */ - def makeConverters( - schema: StructType, - options: CSVOptions = CSVOptions()): Array[ValueConverter] = { - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - } - - /** - * Create a converter which converts the string value to a value according to a desired type. - */ - def makeConverter( - name: String, - dataType: DataType, - nullable: Boolean = true, - options: CSVOptions = CSVOptions()): ValueConverter = dataType match { - case _: ByteType => (d: String) => - nullSafeDatum(d, name, nullable, options)(_.toByte) - - case _: ShortType => (d: String) => - nullSafeDatum(d, name, nullable, options)(_.toShort) - - case _: IntegerType => (d: String) => - nullSafeDatum(d, name, nullable, options)(_.toInt) - - case _: LongType => (d: String) => - nullSafeDatum(d, name, nullable, options)(_.toLong) - - case _: FloatType => (d: String) => - nullSafeDatum(d, name, nullable, options) { - case options.nanValue => Float.NaN - case options.negativeInf => Float.NegativeInfinity - case options.positiveInf => Float.PositiveInfinity - case datum => - Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue()) - } - - case _: DoubleType => (d: String) => - nullSafeDatum(d, name, nullable, options) { - case options.nanValue => Double.NaN - case options.negativeInf => Double.NegativeInfinity - case options.positiveInf => Double.PositiveInfinity - case datum => - Try(datum.toDouble) - .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue()) + private def makeSafeHeader( + row: Array[String], + caseSensitive: Boolean, + options: CSVOptions): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + .map(name => if (caseSensitive) name else name.toLowerCase) + headerNames.diff(headerNames.distinct).distinct } - case _: BooleanType => (d: String) => - nullSafeDatum(d, name, nullable, options)(_.toBoolean) - - case dt: DecimalType => (d: String) => - nullSafeDatum(d, name, nullable, options) { datum => - val value = new BigDecimal(datum.replaceAll(",", "")) - Decimal(value, dt.precision, dt.scale) + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. + s"_c$index" + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } } - - case _: TimestampType => (d: String) => - nullSafeDatum(d, name, nullable, options) { datum => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - Try(options.timestampFormat.parse(datum).getTime * 1000L) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. 
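As an aside on the Float/Double branches above (carried over verbatim into UnivocityParser below): parsing first tries the plain Scala conversion and then falls back to a US-locale NumberFormat, so grouping separators are tolerated. A minimal sketch, not the patched code itself:

  import java.text.NumberFormat
  import java.util.Locale
  import scala.util.Try

  // Plain toDouble first; NumberFormat handles strings such as "1,000.5".
  def parseDouble(datum: String): Double =
    Try(datum.toDouble)
      .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue())

  // parseDouble("2.5")      // 2.5
  // parseDouble("1,000.5")  // 1000.5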
- DateTimeUtils.stringToTime(datum).getTime * 1000L - } - } - - case _: DateType => (d: String) => - nullSafeDatum(d, name, nullable, options) { datum => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681.x - Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) - } - } - - case _: StringType => (d: String) => - nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_)) - - case udt: UserDefinedType[_] => (datum: String) => - makeConverter(name, udt.sqlType, nullable, options) - - case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}") - } - - private def nullSafeDatum( - datum: String, - name: String, - nullable: Boolean, - options: CSVOptions)(converter: ValueConverter): Any = { - if (datum == options.nullValue || datum == null) { - if (!nullable) { - throw new RuntimeException(s"null value found but field $name is not nullable.") - } - null } else { - converter.apply(datum) + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. + s"_c$index" + } } } +} +private[csv] object CSVTypeCast { /** * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 21e50307b5ab..140ce23958dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.StandardCharsets import java.util.Locale +import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling} import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging @@ -126,6 +127,39 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive val inputBufferSize = 128 val isCommentSet = this.comment != '\u0000' + + def asWriterSettings: CsvWriterSettings = { + val writerSettings = new CsvWriterSettings() + val format = writerSettings.getFormat + format.setDelimiter(delimiter) + format.setQuote(quote) + format.setQuoteEscape(escape) + format.setComment(comment) + writerSettings.setNullValue(nullValue) + writerSettings.setEmptyValue(nullValue) + writerSettings.setSkipEmptyLines(true) + writerSettings.setQuoteAllFields(quoteAll) + writerSettings.setQuoteEscapingEnabled(escapeQuotes) + writerSettings + } + + def asParserSettings: CsvParserSettings = { + val settings = new CsvParserSettings() + val format = settings.getFormat + format.setDelimiter(delimiter) + format.setQuote(quote) + format.setQuoteEscape(escape) + format.setComment(comment) + settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlag) + settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlag) + settings.setReadInputOnSeparateThread(false) + settings.setInputBufferSize(inputBufferSize) + settings.setMaxColumns(maxColumns) + settings.setNullValue(nullValue) + 
settings.setMaxCharsPerColumn(maxCharsPerColumn) + settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) + settings + } } object CSVOptions { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala deleted file mode 100644 index 6239508ec942..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.csv - -import java.io.{CharArrayWriter, OutputStream, StringReader} -import java.nio.charset.StandardCharsets - -import com.univocity.parsers.csv._ - -import org.apache.spark.internal.Logging - -/** - * Read and parse CSV-like input - * - * @param params Parameters object - */ -private[csv] class CsvReader(params: CSVOptions) { - - private val parser: CsvParser = { - val settings = new CsvParserSettings() - val format = settings.getFormat - format.setDelimiter(params.delimiter) - format.setQuote(params.quote) - format.setQuoteEscape(params.escape) - format.setComment(params.comment) - settings.setIgnoreLeadingWhitespaces(params.ignoreLeadingWhiteSpaceFlag) - settings.setIgnoreTrailingWhitespaces(params.ignoreTrailingWhiteSpaceFlag) - settings.setReadInputOnSeparateThread(false) - settings.setInputBufferSize(params.inputBufferSize) - settings.setMaxColumns(params.maxColumns) - settings.setNullValue(params.nullValue) - settings.setMaxCharsPerColumn(params.maxCharsPerColumn) - settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) - - new CsvParser(settings) - } - - /** - * parse a line - * - * @param line a String with no newline at the end - * @return array of strings where each string is a field in the CSV record - */ - def parseLine(line: String): Array[String] = parser.parseLine(line) -} - -/** - * Converts a sequence of string to CSV string - * - * @param params Parameters object for configuration - * @param headers headers for columns - */ -private[csv] class LineCsvWriter( - params: CSVOptions, - headers: Seq[String], - output: OutputStream) extends Logging { - private val writerSettings = new CsvWriterSettings - private val format = writerSettings.getFormat - - format.setDelimiter(params.delimiter) - format.setQuote(params.quote) - format.setQuoteEscape(params.escape) - format.setComment(params.comment) - - writerSettings.setNullValue(params.nullValue) - writerSettings.setEmptyValue(params.nullValue) - writerSettings.setSkipEmptyLines(true) - writerSettings.setQuoteAllFields(params.quoteAll) - writerSettings.setHeaders(headers: _*) - 
writerSettings.setQuoteEscapingEnabled(params.escapeQuotes) - - private val writer = new CsvWriter(output, StandardCharsets.UTF_8, writerSettings) - - def writeRow(row: Seq[String], includeHeader: Boolean): Unit = { - if (includeHeader) { - writer.writeHeaders() - } - - writer.writeRow(row: _*) - } - - def close(): Unit = { - writer.close() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala deleted file mode 100644 index 23c07eb630d3..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.csv - -import scala.util.control.NonFatal - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.TaskAttemptContext - -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter, OutputWriterFactory, PartitionedFile} -import org.apache.spark.sql.types._ - -object CSVRelation extends Logging { - - def univocityTokenizer( - file: Dataset[String], - firstLine: String, - params: CSVOptions): RDD[Array[String]] = { - // If header is set, make sure firstLine is materialized before sending to executors. - val commentPrefix = params.comment.toString - file.rdd.mapPartitions { iter => - val parser = new CsvReader(params) - val filteredIter = iter.filter { line => - line.trim.nonEmpty && !line.startsWith(commentPrefix) - } - if (params.headerFlag) { - filteredIter.filterNot(_ == firstLine).map { item => - parser.parseLine(item) - } - } else { - filteredIter.map { item => - parser.parseLine(item) - } - } - } - } - - /** - * Returns a function that parses a single CSV record (in the form of an array of strings in which - * each element represents a column) and turns it into either one resulting row or no row (if the - * the record is malformed). - * - * The 2nd argument in the returned function represents the total number of malformed rows - * observed so far. - */ - // This is pretty convoluted and we should probably rewrite the entire CSV parsing soon. 
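For reference, the CsvReader/LineCsvWriter wrappers deleted above are replaced by configuring univocity directly through CSVOptions.asParserSettings and asWriterSettings. A minimal usage sketch (the option values are made up for illustration):

  import com.univocity.parsers.csv.CsvParser

  val options = new CSVOptions(Map("delimiter" -> ";", "comment" -> "#"))
  val parser = new CsvParser(options.asParserSettings)  // previously: new CsvReader(options)
  parser.parseLine("1;foo;2.5")                         // Array("1", "foo", "2.5")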
- def csvParser( - schema: StructType, - requiredColumns: Array[String], - params: CSVOptions): (Array[String], Int) => Option[InternalRow] = { - val requiredFields = StructType(requiredColumns.map(schema(_))).fields - val safeRequiredFields = if (params.dropMalformed) { - // If `dropMalformed` is enabled, then it needs to parse all the values - // so that we can decide which row is malformed. - requiredFields ++ schema.filterNot(requiredFields.contains(_)) - } else { - requiredFields - } - val safeRequiredIndices = new Array[Int](safeRequiredFields.length) - schema.zipWithIndex.filter { case (field, _) => - safeRequiredFields.contains(field) - }.foreach { case (field, index) => - safeRequiredIndices(safeRequiredFields.indexOf(field)) = index - } - val requiredSize = requiredFields.length - val row = new GenericInternalRow(requiredSize) - val converters = CSVTypeCast.makeConverters(schema, params) - - (tokens: Array[String], numMalformedRows) => { - if (params.dropMalformed && schema.length != tokens.length) { - if (numMalformedRows < params.maxMalformedLogPerPartition) { - logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") - } - if (numMalformedRows == params.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${params.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. Malformed records from now on will not be logged.") - } - None - } else if (params.failFast && schema.length != tokens.length) { - throw new RuntimeException(s"Malformed line in FAILFAST mode: " + - s"${tokens.mkString(params.delimiter.toString)}") - } else { - val indexSafeTokens = if (params.permissive && schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) - } else if (params.permissive && schema.length < tokens.length) { - tokens.take(schema.length) - } else { - tokens - } - try { - var index: Int = 0 - var subIndex: Int = 0 - while (subIndex < safeRequiredIndices.length) { - index = safeRequiredIndices(subIndex) - // It anyway needs to try to parse since it decides if this row is malformed - // or not after trying to cast in `DROPMALFORMED` mode even if the casted - // value is not stored in the row. - val value = converters(index).apply(indexSafeTokens(index)) - if (subIndex < requiredSize) { - row(subIndex) = value - } - subIndex += 1 - } - Some(row) - } catch { - case NonFatal(e) if params.dropMalformed => - if (numMalformedRows < params.maxMalformedLogPerPartition) { - logWarning("Parse exception. " + - s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") - } - if (numMalformedRows == params.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${params.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. Malformed records from now on will not be logged.") - } - None - } - } - } - } - - // Skips the header line of each file if the `header` option is set to true. - def dropHeaderLine( - file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = { - // TODO What if the first partitioned file consists of only comments and empty lines? 
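Note on the malformed-record handling in the csvParser removed above (the same rule lives on in UnivocityParser.withParseMode below): in PERMISSIVE mode, rows shorter than the schema are padded with nulls and longer rows are truncated. A sketch of just that length fix-up, using a hypothetical helper name:

  def fitToSchema(tokens: Array[String], schemaLen: Int): Array[String] = {
    if (schemaLen > tokens.length) tokens ++ new Array[String](schemaLen - tokens.length)
    else if (schemaLen < tokens.length) tokens.take(schemaLen)
    else tokens
  }
  // fitToSchema(Array("a"), 3)            // Array("a", null, null)
  // fitToSchema(Array("a", "b", "c"), 2)  // Array("a", "b")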
- if (csvOptions.headerFlag && file.start == 0) { - val nonEmptyLines = if (csvOptions.isCommentSet) { - val commentPrefix = csvOptions.comment.toString - lines.dropWhile { line => - line.trim.isEmpty || line.trim.startsWith(commentPrefix) - } - } else { - lines.dropWhile(_.trim.isEmpty) - } - - if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) - } - } -} - -private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new CsvOutputWriter(path, dataSchema, context, params) - } - - override def getFileExtension(context: TaskAttemptContext): String = { - ".csv" + CodecStreams.getCompressionExtension(context) - } -} - -private[csv] class CsvOutputWriter( - path: String, - dataSchema: StructType, - context: TaskAttemptContext, - params: CSVOptions) extends OutputWriter with Logging { - - // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`. - // When the value is null, this converter should not be called. - private type ValueConverter = (InternalRow, Int) => String - - // `ValueConverter`s for all values in the fields of the schema - private val valueConverters: Array[ValueConverter] = - dataSchema.map(_.dataType).map(makeConverter).toArray - - private var printHeader: Boolean = params.headerFlag - private val writer = CodecStreams.createOutputStream(context, new Path(path)) - private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq, writer) - - private def rowToString(row: InternalRow): Seq[String] = { - var i = 0 - val values = new Array[String](row.numFields) - while (i < row.numFields) { - if (!row.isNullAt(i)) { - values(i) = valueConverters(i).apply(row, i) - } else { - values(i) = params.nullValue - } - i += 1 - } - values - } - - private def makeConverter(dataType: DataType): ValueConverter = dataType match { - case DateType => - (row: InternalRow, ordinal: Int) => - params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) - - case TimestampType => - (row: InternalRow, ordinal: Int) => - params.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) - - case udt: UserDefinedType[_] => makeConverter(udt.sqlType) - - case dt: DataType => - (row: InternalRow, ordinal: Int) => - row.get(ordinal, dt).toString - } - - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { - csvWriter.writeRow(rowToString(row), printHeader) - printHeader = false - } - - override def close(): Unit = { - csvWriter.close() - writer.close() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala new file mode 100644 index 000000000000..ff93d4589f51 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.Writer + +import com.univocity.parsers.csv.CsvWriter + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +private[csv] class UnivocityGenerator( + schema: StructType, + writer: Writer, + options: CSVOptions = new CSVOptions(Map.empty[String, String])) { + private val writerSettings = options.asWriterSettings + writerSettings.setHeaders(schema.fieldNames: _*) + private val gen = new CsvWriter(writer, writerSettings) + + // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`. + // When the value is null, this converter should not be called. + private type ValueConverter = (InternalRow, Int) => String + + // `ValueConverter`s for all values in the fields of the schema + private val valueConverters: Array[ValueConverter] = + schema.map(_.dataType).map(makeConverter).toArray + + private def makeConverter(dataType: DataType): ValueConverter = dataType match { + case DateType => + (row: InternalRow, ordinal: Int) => + options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) + + case udt: UserDefinedType[_] => makeConverter(udt.sqlType) + + case dt: DataType => + (row: InternalRow, ordinal: Int) => + row.get(ordinal, dt).toString + } + + private def convertRow(row: InternalRow): Seq[String] = { + var i = 0 + val values = new Array[String](row.numFields) + while (i < row.numFields) { + if (!row.isNullAt(i)) { + values(i) = valueConverters(i).apply(row, i) + } else { + values(i) = options.nullValue + } + i += 1 + } + values + } + + /** + * Writes a single InternalRow to CSV using Univocity + * + * @param row The row to convert + */ + def write(row: InternalRow, writeHeader: Boolean): Unit = { + if (writeHeader) { + gen.writeHeaders() + } + gen.writeRow(convertRow(row): _*) + } + + def close(): Unit = gen.close() + + def flush(): Unit = gen.flush() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala new file mode 100644 index 000000000000..5256d38929ad --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.math.BigDecimal +import java.text.NumberFormat +import java.util.Locale + +import scala.util.Try +import scala.util.control.NonFatal + +import com.univocity.parsers.csv.CsvParser + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +private[csv] class UnivocityParser( + schema: StructType, + requiredSchema: StructType, + options: CSVOptions) extends Logging { + def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) + + val valueConverters = makeConverters(schema, options) + val parser = new CsvParser(options.asParserSettings) + + // A `ValueConverter` is responsible for converting the given value to a desired type. + private type ValueConverter = String => Any + + var numMalformedRecords = 0 + val row = new GenericInternalRow(requiredSchema.length) + val indexArr: Array[Int] = { + val fields = if (options.dropMalformed) { + // If `dropMalformed` is enabled, then it needs to parse all the values + // so that we can decide which row is malformed. + requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) + } else { + requiredSchema + } + fields.filter(schema.contains).map(schema.indexOf).toArray + } + + /** + * Create converters which cast each given string datum to each specified type in given schema. + * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`). + * + * For string types, this is simply the datum. + * For other types, this is converted into the value according to the type. + * For other nullable types, returns null if it is null or equals to the value specified + * in `nullValue` option. + * + * @param schema schema that contains data types to cast the given value into. + * @param options CSV options. + */ + private def makeConverters( + schema: StructType, + options: CSVOptions = CSVOptions()): Array[ValueConverter] = { + schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + } + + /** + * Create a converter which converts the string value to a value according to a desired type. 
+ */ + def makeConverter( + name: String, + dataType: DataType, + nullable: Boolean = true, + options: CSVOptions = CSVOptions()): ValueConverter = dataType match { + case _: ByteType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toByte) + + case _: ShortType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toShort) + + case _: IntegerType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toInt) + + case _: LongType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toLong) + + case _: FloatType => (d: String) => + nullSafeDatum(d, name, nullable, options) { + case options.nanValue => Float.NaN + case options.negativeInf => Float.NegativeInfinity + case options.positiveInf => Float.PositiveInfinity + case datum => + Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue()) + } + + case _: DoubleType => (d: String) => + nullSafeDatum(d, name, nullable, options) { + case options.nanValue => Double.NaN + case options.negativeInf => Double.NegativeInfinity + case options.positiveInf => Double.PositiveInfinity + case datum => + Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue()) + } + + case _: BooleanType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toBoolean) + + case dt: DecimalType => (d: String) => + nullSafeDatum(d, name, nullable, options) { datum => + val value = new BigDecimal(datum.replaceAll(",", "")) + Decimal(value, dt.precision, dt.scale) + } + + case _: TimestampType => (d: String) => + nullSafeDatum(d, name, nullable, options) { datum => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + Try(options.timestampFormat.parse(datum).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.stringToTime(datum).getTime * 1000L + } + } + + case _: DateType => (d: String) => + nullSafeDatum(d, name, nullable, options) { datum => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681.x + Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + } + } + + case _: StringType => (d: String) => + nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_)) + + case udt: UserDefinedType[_] => (datum: String) => + makeConverter(name, udt.sqlType, nullable, options) + + case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}") + } + + private def nullSafeDatum( + datum: String, + name: String, + nullable: Boolean, + options: CSVOptions)(converter: ValueConverter): Any = { + if (datum == options.nullValue || datum == null) { + if (!nullable) { + throw new RuntimeException(s"null value found but field $name is not nullable.") + } + null + } else { + converter.apply(datum) + } + } + + /** + * Parses a single CSV record (in the form of an array of strings in which + * each element represents a column) and turns it into either one resulting row or no row (if the + * the record is malformed). 
+   */
+  def parse(input: String): Option[InternalRow] = {
+    withParseMode(parser.parseLine(input)) { tokens =>
+      var i: Int = 0
+      while (i < indexArr.length) {
+        val pos = indexArr(i)
+        // The value still needs to be converted even when it is not stored in the row,
+        // because in `DROPMALFORMED` mode a row is only judged malformed after the cast
+        // has been attempted.
+        val value = valueConverters(pos).apply(tokens(pos))
+        if (i < requiredSchema.length) {
+          row(i) = value
+        }
+        i += 1
+      }
+      row
+    }
+  }
+
+  private def withParseMode(
+      tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = {
+    if (options.dropMalformed && schema.length != tokens.length) {
+      if (numMalformedRecords < options.maxMalformedLogPerPartition) {
+        logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
+      }
+      if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) {
+        logWarning(
+          s"More than ${options.maxMalformedLogPerPartition} malformed records have been " +
+            "found on this partition. Malformed records from now on will not be logged.")
+      }
+      numMalformedRecords += 1
+      None
+    } else if (options.failFast && schema.length != tokens.length) {
+      throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
+        s"${tokens.mkString(options.delimiter.toString)}")
+    } else {
+      val checkedTokens = if (options.permissive && schema.length > tokens.length) {
+        tokens ++ new Array[String](schema.length - tokens.length)
+      } else if (options.permissive && schema.length < tokens.length) {
+        tokens.take(schema.length)
+      } else {
+        tokens
+      }
+
+      try {
+        Some(convert(checkedTokens))
+      } catch {
+        case NonFatal(e) if options.dropMalformed =>
+          if (numMalformedRecords < options.maxMalformedLogPerPartition) {
+            logWarning("Parse exception. " +
+              s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
+          }
+          if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) {
+            logWarning(
+              s"More than ${options.maxMalformedLogPerPartition} malformed records have been " +
+                "found on this partition. Malformed records from now on will not be logged.")
+          }
+          numMalformedRecords += 1
+          None
+      }
+    }
+  }
+}
+
+private[csv] object UnivocityParser {
+  /**
+   * Filters out the rows that CSV parsing should ignore: empty lines and lines starting
+   * with the `comment` character.
+   */
+  def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+    val commentPrefix = options.comment.toString
+    iter.filter { line =>
+      line.trim.nonEmpty && !line.startsWith(commentPrefix)
+    }
+  }
+
+  /**
+   * Drops the header line so that only data remains. This drops the first line in an iterator.
+ */ + def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + val nonEmptyLines = if (options.isCommentSet) { + val commentPrefix = options.comment.toString + iter.dropWhile { line => + line.trim.isEmpty || line.trim.startsWith(commentPrefix) + } + } else { + iter.dropWhile(_.trim.isEmpty) + } + + if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) + iter + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index ffd3d260bcb4..add3d3f6d3a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -27,6 +27,9 @@ import org.apache.spark.unsafe.types.UTF8String class CSVTypeCastSuite extends SparkFunSuite { + private val parser = + new UnivocityParser(StructType(Seq.empty), new CSVOptions(Map.empty[String, String])) + private def assertNull(v: Any) = assert(v == null) test("Can parse decimal type values") { @@ -36,7 +39,7 @@ class CSVTypeCastSuite extends SparkFunSuite { stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => val decimalValue = new BigDecimal(decimalVal.toString) - assert(CSVTypeCast.makeConverter("_1", decimalType).apply(strVal) === + assert(parser.makeConverter("_1", decimalType).apply(strVal) === Decimal(decimalValue, decimalType.precision, decimalType.scale)) } } @@ -73,19 +76,19 @@ class CSVTypeCastSuite extends SparkFunSuite { types.foreach { t => // Tests that a custom nullValue. val converter = - CSVTypeCast.makeConverter("_1", t, nullable = true, CSVOptions("nullValue", "-")) + parser.makeConverter("_1", t, nullable = true, CSVOptions("nullValue", "-")) assertNull(converter.apply("-")) assertNull(converter.apply(null)) // Tests that the default nullValue is empty string. - assertNull(CSVTypeCast.makeConverter("_1", t, nullable = true).apply("")) + assertNull(parser.makeConverter("_1", t, nullable = true).apply("")) } // Not nullable field with nullValue option. types.foreach { t => // Casts a null to not nullable field should throw an exception. val converter = - CSVTypeCast.makeConverter("_1", t, nullable = false, CSVOptions("nullValue", "-")) + parser.makeConverter("_1", t, nullable = false, CSVOptions("nullValue", "-")) var message = intercept[RuntimeException] { converter.apply("-") }.getMessage @@ -100,32 +103,32 @@ class CSVTypeCastSuite extends SparkFunSuite { // null. 
Seq(true, false).foreach { b => val converter = - CSVTypeCast.makeConverter("_1", StringType, nullable = b, CSVOptions("nullValue", "null")) + parser.makeConverter("_1", StringType, nullable = b, CSVOptions("nullValue", "null")) assert(converter.apply("") == UTF8String.fromString("")) } } test("Throws exception for empty string with non null type") { val exception = intercept[RuntimeException]{ - CSVTypeCast.makeConverter("_1", IntegerType, nullable = false, CSVOptions()).apply("") + parser.makeConverter("_1", IntegerType, nullable = false, CSVOptions()).apply("") } assert(exception.getMessage.contains("null value found but field _1 is not nullable.")) } test("Types are cast correctly") { - assert(CSVTypeCast.makeConverter("_1", ByteType).apply("10") == 10) - assert(CSVTypeCast.makeConverter("_1", ShortType).apply("10") == 10) - assert(CSVTypeCast.makeConverter("_1", IntegerType).apply("10") == 10) - assert(CSVTypeCast.makeConverter("_1", LongType).apply("10") == 10) - assert(CSVTypeCast.makeConverter("_1", FloatType).apply("1.00") == 1.0) - assert(CSVTypeCast.makeConverter("_1", DoubleType).apply("1.00") == 1.0) - assert(CSVTypeCast.makeConverter("_1", BooleanType).apply("true") == true) + assert(parser.makeConverter("_1", ByteType).apply("10") == 10) + assert(parser.makeConverter("_1", ShortType).apply("10") == 10) + assert(parser.makeConverter("_1", IntegerType).apply("10") == 10) + assert(parser.makeConverter("_1", LongType).apply("10") == 10) + assert(parser.makeConverter("_1", FloatType).apply("1.00") == 1.0) + assert(parser.makeConverter("_1", DoubleType).apply("1.00") == 1.0) + assert(parser.makeConverter("_1", BooleanType).apply("true") == true) val timestampsOptions = CSVOptions("timestampFormat", "dd/MM/yyyy hh:mm") val customTimestamp = "31/01/2015 00:00" val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime val castedTimestamp = - CSVTypeCast.makeConverter("_1", TimestampType, nullable = true, timestampsOptions) + parser.makeConverter("_1", TimestampType, nullable = true, timestampsOptions) .apply(customTimestamp) assert(castedTimestamp == expectedTime * 1000L) @@ -133,14 +136,14 @@ class CSVTypeCastSuite extends SparkFunSuite { val dateOptions = CSVOptions("dateFormat", "dd/MM/yyyy") val expectedDate = dateOptions.dateFormat.parse(customDate).getTime val castedDate = - CSVTypeCast.makeConverter("_1", DateType, nullable = true, dateOptions) + parser.makeConverter("_1", DateType, nullable = true, dateOptions) .apply(customTimestamp) assert(castedDate == DateTimeUtils.millisToDays(expectedDate)) val timestamp = "2015-01-01 00:00:00" - assert(CSVTypeCast.makeConverter("_1", TimestampType).apply(timestamp) == + assert(parser.makeConverter("_1", TimestampType).apply(timestamp) == DateTimeUtils.stringToTime(timestamp).getTime * 1000L) - assert(CSVTypeCast.makeConverter("_1", DateType).apply("2015-01-01") == + assert(parser.makeConverter("_1", DateType).apply("2015-01-01") == DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) } @@ -149,15 +152,15 @@ class CSVTypeCastSuite extends SparkFunSuite { try { Locale.setDefault(new Locale("fr", "FR")) // Would parse as 1.0 in fr-FR - assert(CSVTypeCast.makeConverter("_1", FloatType).apply("1,00") == 100.0) - assert(CSVTypeCast.makeConverter("_1", DoubleType).apply("1,00") == 100.0) + assert(parser.makeConverter("_1", FloatType).apply("1,00") == 100.0) + assert(parser.makeConverter("_1", DoubleType).apply("1,00") == 100.0) } finally { Locale.setDefault(originalLocale) } } test("Float NaN 
values are parsed correctly") { - val floatVal: Float = CSVTypeCast.makeConverter( + val floatVal: Float = parser.makeConverter( "_1", FloatType, nullable = true, CSVOptions("nanValue", "nn") ).apply("nn").asInstanceOf[Float] @@ -167,7 +170,7 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Double NaN values are parsed correctly") { - val doubleVal: Double = CSVTypeCast.makeConverter( + val doubleVal: Double = parser.makeConverter( "_1", DoubleType, nullable = true, CSVOptions("nanValue", "-") ).apply("-").asInstanceOf[Double] @@ -175,13 +178,13 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Float infinite values can be parsed") { - val floatVal1 = CSVTypeCast.makeConverter( + val floatVal1 = parser.makeConverter( "_1", FloatType, nullable = true, CSVOptions("negativeInf", "max") ).apply("max").asInstanceOf[Float] assert(floatVal1 == Float.NegativeInfinity) - val floatVal2 = CSVTypeCast.makeConverter( + val floatVal2 = parser.makeConverter( "_1", FloatType, nullable = true, CSVOptions("positiveInf", "max") ).apply("max").asInstanceOf[Float] @@ -189,13 +192,13 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Double infinite values can be parsed") { - val doubleVal1 = CSVTypeCast.makeConverter( + val doubleVal1 = parser.makeConverter( "_1", DoubleType, nullable = true, CSVOptions("negativeInf", "max") ).apply("max").asInstanceOf[Double] assert(doubleVal1 == Double.NegativeInfinity) - val doubleVal2 = CSVTypeCast.makeConverter( + val doubleVal2 = parser.makeConverter( "_1", DoubleType, nullable = true, CSVOptions("positiveInf", "max") ).apply("max").asInstanceOf[Double]
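As a usage note on the relocated converters exercised by this suite, a minimal sketch mirroring the test setup above (illustrative only, since UnivocityParser is private[csv]):

  import org.apache.spark.sql.types._

  val parser = new UnivocityParser(StructType(Seq.empty), new CSVOptions(Map.empty[String, String]))
  val toInt = parser.makeConverter("_1", IntegerType)
  toInt("10")  // 10

  val dateOptions = CSVOptions("dateFormat", "dd/MM/yyyy")
  val toDate = parser.makeConverter("_1", DateType, nullable = true, dateOptions)
  toDate("31/01/2015")  // days since epoch, as Int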