Commits (22)

c567dcc  Add a test (MaxGekk, Nov 8, 2018)
2b41eba  Fix decimal parsing (MaxGekk, Nov 8, 2018)
cf438ae  Add locale option (MaxGekk, Nov 8, 2018)
f9438c4  Updating the migration guide (MaxGekk, Nov 8, 2018)
3125c23  Fix imports (MaxGekk, Nov 8, 2018)
64a97a2  Merge remote-tracking branch 'origin/master' into decimal-parsing-locale (MaxGekk, Nov 9, 2018)
2f76352  Renaming decimalParser to decimalFormat (MaxGekk, Nov 11, 2018)
3dfce18  Moving the test to UnivocityParserSuite (MaxGekk, Nov 11, 2018)
bdca7c4  Support the SQL config spark.sql.legacy.decimalParsing.enabled (MaxGekk, Nov 12, 2018)
8c5593e  Updating the migration guide. (MaxGekk, Nov 12, 2018)
18470b0  Refactoring (MaxGekk, Nov 12, 2018)
c28b79f  Removing internal (MaxGekk, Nov 12, 2018)
1723da2  Test refactoring (MaxGekk, Nov 12, 2018)
6cdafa5  Added a test for inferring the decimal type (MaxGekk, Nov 13, 2018)
14b5109  Inferring decimals from CSV (MaxGekk, Nov 14, 2018)
bab8fb2  Renaming df to decimalFormat (MaxGekk, Nov 22, 2018)
5236336  Merge remote-tracking branch 'origin/master' into decimal-parsing-locale (MaxGekk, Nov 23, 2018)
0d1a4f0  Merge branch 'master' into decimal-parsing-locale (MaxGekk, Nov 27, 2018)
8b1456c  Merge remote-tracking branch 'origin/master' into decimal-parsing-locale (MaxGekk, Nov 28, 2018)
0859624  Removing SQL config and special handling of Locale.US (MaxGekk, Nov 28, 2018)
e989b77  Merge remote-tracking branch 'fork/decimal-parsing-locale' into decim… (MaxGekk, Nov 28, 2018)
521bd45  Merge remote-tracking branch 'origin/master' into decimal-parsing-locale (MaxGekk, Nov 29, 2018)
CSVExprUtils.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.catalyst.csv
 
+import java.math.BigDecimal
+import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
+import java.util.Locale
+
 object CSVExprUtils {
   /**
    * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`).
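The body of ExprUtils.getDecimalParser is not visible in the hunks shown here; the imports above only hint at how it works. A minimal sketch, assuming the helper wraps java.text.DecimalFormat configured for the given locale (the method body and error message below are illustrative, not the PR's exact code):

import java.math.BigDecimal
import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
import java.util.Locale

// Returns a function that parses strings like "1.000.000,01" (de-DE)
// into java.math.BigDecimal using locale-specific separators.
def getDecimalParser(locale: Locale): String => BigDecimal = {
  val decimalFormat = new DecimalFormat("", new DecimalFormatSymbols(locale))
  // Make parse() produce BigDecimal instead of Double to avoid losing precision.
  decimalFormat.setParseBigDecimal(true)
  (s: String) => {
    val pos = new ParsePosition(0)
    val result = decimalFormat.parse(s, pos).asInstanceOf[BigDecimal]
    // Reject strings that were parsed only partially, e.g. "12abc".
    if (pos.getIndex != s.length || pos.getErrorIndex != -1) {
      throw new IllegalArgumentException(s"Cannot parse '$s' as a decimal")
    }
    result
  }
}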
CSVInferSchema.scala
@@ -17,16 +17,19 @@
 
 package org.apache.spark.sql.catalyst.csv
 
-import java.math.BigDecimal
-
 import scala.util.control.Exception.allCatch
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
-object CSVInferSchema {
+class CSVInferSchema(options: CSVOptions) extends Serializable {
+
+  private val decimalParser = {
+    ExprUtils.getDecimalParser(options.locale)
+  }
 
   /**
    * Similar to the JSON schema inference
@@ -36,14 +39,13 @@ object CSVInferSchema {
    */
   def infer(
       tokenRDD: RDD[Array[String]],
-      header: Array[String],
-      options: CSVOptions): StructType = {
+      header: Array[String]): StructType = {
     val fields = if (options.inferSchemaFlag) {
       val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
       val rootTypes: Array[DataType] =
-        tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes)
+        tokenRDD.aggregate(startType)(inferRowType, mergeRowTypes)
 
-      toStructFields(rootTypes, header, options)
+      toStructFields(rootTypes, header)
     } else {
       // By default fields are assumed to be StringType
       header.map(fieldName => StructField(fieldName, StringType, nullable = true))
@@ -54,8 +56,7 @@
 
   def toStructFields(
       fieldTypes: Array[DataType],
-      header: Array[String],
-      options: CSVOptions): Array[StructField] = {
+      header: Array[String]): Array[StructField] = {
     header.zip(fieldTypes).map { case (thisHeader, rootType) =>
       val dType = rootType match {
         case _: NullType => StringType
@@ -65,11 +66,10 @@
     }
   }
 
-  def inferRowType(options: CSVOptions)
-      (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
+  def inferRowType(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
     var i = 0
     while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
-      rowSoFar(i) = inferField(rowSoFar(i), next(i), options)
+      rowSoFar(i) = inferField(rowSoFar(i), next(i))
       i+=1
     }
     rowSoFar
@@ -85,51 +85,51 @@
    * Infer type of string field. Given known type Double, and a string "1", there is no
    * point checking if it is an Int, as the final type must be Double or higher.
    */
-  def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = {
+  def inferField(typeSoFar: DataType, field: String): DataType = {
     if (field == null || field.isEmpty || field == options.nullValue) {
       typeSoFar
     } else {
       typeSoFar match {
-        case NullType => tryParseInteger(field, options)
-        case IntegerType => tryParseInteger(field, options)
-        case LongType => tryParseLong(field, options)
+        case NullType => tryParseInteger(field)
+        case IntegerType => tryParseInteger(field)
+        case LongType => tryParseLong(field)
         case _: DecimalType =>
           // DecimalTypes have different precisions and scales, so we try to find the common type.
-          compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType)
-        case DoubleType => tryParseDouble(field, options)
-        case TimestampType => tryParseTimestamp(field, options)
-        case BooleanType => tryParseBoolean(field, options)
+          compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType)
+        case DoubleType => tryParseDouble(field)
+        case TimestampType => tryParseTimestamp(field)
+        case BooleanType => tryParseBoolean(field)
         case StringType => StringType
         case other: DataType =>
          throw new UnsupportedOperationException(s"Unexpected data type $other")
       }
     }
   }
 
-  private def isInfOrNan(field: String, options: CSVOptions): Boolean = {
+  private def isInfOrNan(field: String): Boolean = {
     field == options.nanValue || field == options.negativeInf || field == options.positiveInf
   }
 
-  private def tryParseInteger(field: String, options: CSVOptions): DataType = {
+  private def tryParseInteger(field: String): DataType = {
     if ((allCatch opt field.toInt).isDefined) {
       IntegerType
     } else {
-      tryParseLong(field, options)
+      tryParseLong(field)
    }
   }
 
-  private def tryParseLong(field: String, options: CSVOptions): DataType = {
+  private def tryParseLong(field: String): DataType = {
     if ((allCatch opt field.toLong).isDefined) {
       LongType
     } else {
-      tryParseDecimal(field, options)
+      tryParseDecimal(field)
     }
   }
 
-  private def tryParseDecimal(field: String, options: CSVOptions): DataType = {
+  private def tryParseDecimal(field: String): DataType = {
     val decimalTry = allCatch opt {
-      // `BigDecimal` conversion can fail when the `field` is not a form of number.
-      val bigDecimal = new BigDecimal(field)
+      // The conversion can fail when the `field` is not a form of number.
+      val bigDecimal = decimalParser(field)
       // Because many other formats do not support decimal, it reduces the cases for
       // decimals by disallowing values having scale (eg. `1.1`).
       if (bigDecimal.scale <= 0) {
@@ -138,33 +138,33 @@
         //  2. scale is bigger than precision.
         DecimalType(bigDecimal.precision, bigDecimal.scale)
       } else {
-        tryParseDouble(field, options)
+        tryParseDouble(field)
       }
     }
-    decimalTry.getOrElse(tryParseDouble(field, options))
+    decimalTry.getOrElse(tryParseDouble(field))
   }
 
-  private def tryParseDouble(field: String, options: CSVOptions): DataType = {
-    if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) {
+  private def tryParseDouble(field: String): DataType = {
+    if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field)) {
       DoubleType
     } else {
-      tryParseTimestamp(field, options)
+      tryParseTimestamp(field)
     }
   }
 
-  private def tryParseTimestamp(field: String, options: CSVOptions): DataType = {
+  private def tryParseTimestamp(field: String): DataType = {
     // This case infers a custom `dataFormat` is set.
     if ((allCatch opt options.timestampFormat.parse(field)).isDefined) {
       TimestampType
     } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
       // We keep this for backwards compatibility.
       TimestampType
     } else {
-      tryParseBoolean(field, options)
+      tryParseBoolean(field)
     }
   }
 
-  private def tryParseBoolean(field: String, options: CSVOptions): DataType = {
+  private def tryParseBoolean(field: String): DataType = {
     if ((allCatch opt field.toBoolean).isDefined) {
       BooleanType
     } else {
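With CSVInferSchema now a class parameterized by CSVOptions, locale-aware inference can be exercised directly. A hedged usage sketch: the CSVOptions auxiliary constructor (parameters, columnPruning, defaultTimeZoneId) and the sample values are assumptions, and the expected results rely on a parser behaving like the sketch shown earlier:

import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
import org.apache.spark.sql.types._

val options = new CSVOptions(Map("locale" -> "de-DE"), false, "UTC")
val inferSchema = new CSVInferSchema(options)

// "1.000.000" uses '.' as the de-DE grouping separator: the locale-aware parser
// yields BigDecimal 1000000 (scale 0), so inference keeps a decimal type.
inferSchema.inferField(NullType, "1.000.000")  // expected: DecimalType(7, 0)

// "3,14" parses to a BigDecimal with scale > 0, so tryParseDecimal falls through
// to tryParseDouble; "3,14".toDouble fails, and inference ends at StringType.
inferSchema.inferField(NullType, "3,14")       // expected: StringType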
UnivocityParser.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.csv
 
 import java.io.InputStream
-import java.math.BigDecimal
 
 import scala.util.Try
 import scala.util.control.NonFatal
@@ -27,7 +26,7 @@ import com.univocity.parsers.csv.CsvParser
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow}
 import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -104,6 +103,8 @@ class UnivocityParser(
     requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
   }
 
+  private val decimalParser = ExprUtils.getDecimalParser(options.locale)
+
   /**
    * Create a converter which converts the string value to a value according to a desired type.
    * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`).
@@ -149,8 +150,7 @@
 
     case dt: DecimalType => (d: String) =>
       nullSafeDatum(d, name, nullable, options) { datum =>
-        val value = new BigDecimal(datum.replaceAll(",", ""))
-        Decimal(value, dt.precision, dt.scale)
+        Decimal(decimalParser(datum), dt.precision, dt.scale)
       }
 
     case _: TimestampType => (d: String) =>
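This converter is what the DataFrameReader CSV path ultimately runs, so decimals in locale-specific formats can be loaded once the locale option (added in commit cf438ae) is set. An end-to-end sketch, assuming a SparkSession named spark; the file path, column name, and sample value are made up:

// /tmp/amounts.csv (de-DE formatting):
//   amount
//   1.000.000,11
val df = spark.read
  .option("header", "true")
  .option("locale", "de-DE")
  .schema("amount DECIMAL(10, 2)")
  .csv("/tmp/amounts.csv")
df.show()  // expected value: 1000000.11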
csvExpressions.scala
@@ -180,8 +180,9 @@ case class SchemaOfCsv(
 
     val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
     val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
-    val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row)
-    val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions))
+    val inferSchema = new CSVInferSchema(parsedOptions)
+    val fieldTypes = inferSchema.inferRowType(startType, row)
+    val st = StructType(inferSchema.toStructFields(fieldTypes, header))
     UTF8String.fromString(st.catalogString)
   }
 
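For reference, the rewritten SchemaOfCsv.eval body can be mirrored on a hand-built row to see the per-row inference in isolation. A sketch under the same CSVOptions constructor assumption as above; the sample tokens are made up:

import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
import org.apache.spark.sql.types._

val parsedOptions = new CSVOptions(Map("locale" -> "de-DE"), false, "UTC")
val row: Array[String] = Array("1.000.000", "3,14")

// Same steps as SchemaOfCsv.eval after parsing a single row.
val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val inferSchema = new CSVInferSchema(parsedOptions)
val fieldTypes = inferSchema.inferRowType(startType, row)
val st = StructType(inferSchema.toStructFields(fieldTypes, header))
// st.catalogString: e.g. struct<_c0:decimal(7,0),_c1:string>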