diff --git a/docs/sql-data-sources-csv.md b/docs/sql-data-sources-csv.md
index 2fe8f77ff667..a6b093f7e413 100644
--- a/docs/sql-data-sources-csv.md
+++ b/docs/sql-data-sources-csv.md
@@ -248,5 +248,11 @@ Data source options of CSV can be set via:
     <td>Compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (<code>none</code>, <code>bzip2</code>, <code>gzip</code>, <code>lz4</code>, <code>snappy</code> and <code>deflate</code>).</td>
     <td>write</td>
   </tr>
+  <tr>
+    <td><code>inferDateType</code></td>
+    <td>false</td>
+    <td>During schema inference, infers string columns matching <code>dateFormat</code> as <code>DateType</code>. If this is not set, it uses the default value, <code>false</code>.</td>
+    <td>read</td>
+  </tr>
 </table>
 Other generic options can be found in Generic File Source Options.
diff --git a/docs/sql-data-sources-json.md b/docs/sql-data-sources-json.md
index 041512918e61..cb4d69c64845 100644
--- a/docs/sql-data-sources-json.md
+++ b/docs/sql-data-sources-json.md
@@ -155,6 +155,12 @@ Data source options of JSON can be set via:
     <td>Allows leading zeros in numbers (e.g. 00012). If None is set, it uses the default value, <code>false</code>.</td>
     <td>read</td>
   </tr>
+  <tr>
+    <td><code>inferDateType</code></td>
+    <td>false</td>
+    <td>During schema inference, infers string values matching <code>dateFormat</code> as <code>DateType</code>. If this is not set, it uses the default value, <code>false</code>.</td>
+    <td>read</td>
+  </tr>
   <tr>
     <td><code>allowBackslashEscapingAnyCharacter</code></td>
     <td>None</td>
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index f9e37341dcd6..f639a563ec0b 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -353,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
             samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None,
             pathGlobFilter=None, recursiveFileLookup=None, modifiedBefore=None, modifiedAfter=None,
-            unescapedQuoteHandling=None):
+            unescapedQuoteHandling=None, inferDateType=None):
         r"""Loads a CSV file and returns the result as a :class:`DataFrame`.
 
         This function will go through the input once to determine the input schema if
@@ -403,7 +403,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep,
             pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup,
             modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter,
-            unescapedQuoteHandling=unescapedQuoteHandling)
+            unescapedQuoteHandling=unescapedQuoteHandling, inferDateType=inferDateType)
         if isinstance(path, str):
             path = [path]
         if type(path) == list:
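Aside: a minimal usage sketch of the new reader option, not part of the patch. The file path and column name are hypothetical; it assumes a CSV file whose "birthday" column holds dd-MM-yyyy strings. With inferDateType enabled and a matching dateFormat, schema inference should resolve that column to DateType rather than StringType:

    // Hypothetical input: /tmp/people.csv with a header and a dd-MM-yyyy "birthday" column.
    val people = spark.read
      .option("header", "true")
      .option("inferSchema", "true")
      .option("dateFormat", "dd-MM-yyyy")
      .option("inferDateType", "true") // the option introduced by this patch
      .csv("/tmp/people.csv")
    people.printSchema() // birthday should now be inferred as date, not string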
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 08c8934fbf03..9d2f3d2cdc1c 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -636,7 +636,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
             columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
             enforceSchema=None, emptyValue=None, locale=None, lineSep=None,
-            pathGlobFilter=None, recursiveFileLookup=None, unescapedQuoteHandling=None):
+            pathGlobFilter=None, recursiveFileLookup=None, unescapedQuoteHandling=None,
+            inferDateType=None):
         r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
 
         This function will go through the input once to determine the input schema if
@@ -686,7 +687,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema,
             emptyValue=emptyValue, locale=locale, lineSep=lineSep,
             pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup,
-            unescapedQuoteHandling=unescapedQuoteHandling)
+            unescapedQuoteHandling=unescapedQuoteHandling, inferDateType=inferDateType)
         if isinstance(path, str):
             return self._df(self._jreader.csv(path))
         else:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
index ef3d038cf289..946ff1674d16 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
@@ -24,8 +24,8 @@ import scala.util.control.Exception.allCatch
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.util.{DateFormatter, LegacyFastDateFormatter, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
-import org.apache.spark.sql.catalyst.util.TimestampFormatter
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -38,6 +38,12 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
     legacyFormat = FAST_DATE_FORMAT,
     isParsing = true)
 
+  private val dateFormatter = DateFormatter(
+    options.dateFormat,
+    options.locale,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = true)
+
   private val decimalParser = if (options.locale == Locale.US) {
     // Special handling the default locale for backward compatibility
     s: String => new java.math.BigDecimal(s)
@@ -109,6 +115,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
       case LongType => tryParseLong(field)
       case _: DecimalType => tryParseDecimal(field)
       case DoubleType => tryParseDouble(field)
+      case DateType => tryParseDateFormat(field)
       case TimestampType => tryParseTimestamp(field)
       case BooleanType => tryParseBoolean(field)
       case StringType => StringType
@@ -160,6 +167,16 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
   private def tryParseDouble(field: String): DataType = {
     if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field)) {
       DoubleType
+    } else {
+      tryParseDateFormat(field)
+    }
+  }
+
+  private def tryParseDateFormat(field: String): DataType = {
+    if (options.inferDateType
+      && !dateFormatter.isInstanceOf[LegacyFastDateFormatter]
+      && (allCatch opt dateFormatter.parse(field)).isDefined) {
+      DateType
     } else {
       tryParseTimestamp(field)
     }
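Aside: the important design point in CSVInferSchema is where tryParseDateFormat sits in the fallback chain: a field that fails to parse as a double is tried as a date before timestamps, booleans, and finally StringType. A self-contained sketch of that ordering in plain Scala (type names are plain strings here; the real code returns Catalyst DataTypes and also tries timestamps and booleans between the date and string steps):

    import java.time.LocalDate
    import java.time.format.DateTimeFormatter
    import scala.util.Try

    def inferField(field: String, dateFormat: String, inferDateType: Boolean): String = {
      val asDouble = Try(field.toDouble).toOption
      // Only attempted when the option is enabled, mirroring options.inferDateType above.
      lazy val asDate =
        if (inferDateType) Try(LocalDate.parse(field, DateTimeFormatter.ofPattern(dateFormat))).toOption
        else None
      asDouble.map(_ => "DoubleType")
        .orElse(asDate.map(_ => "DateType"))
        .getOrElse("StringType")
    }

    inferField("3.14", "dd-MM-yyyy", inferDateType = true)        // "DoubleType"
    inferField("21-10-2021", "dd-MM-yyyy", inferDateType = true)  // "DateType"
    inferField("21-10-2021", "dd-MM-yyyy", inferDateType = false) // "StringType"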
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
index 9aa4bf43898a..f5307d294127 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
@@ -206,6 +206,12 @@ class CSVOptions(
     sep
   }
 
+  /**
+   * Option to infer DateType columns during schema inference.
+   */
+  val inferDateType =
+    parameters.get("inferDateType").map(_.toBoolean).getOrElse(false)
+
   val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
     lineSep.getBytes(charset)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 47be83a41d61..dd43e9eb2a4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -68,6 +68,8 @@ private[sql] class JSONOptions(
     parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
   val allowBackslashEscapingAnyCharacter =
     parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
+  val inferDateType =
+    parameters.get("inferDateType").map(_.toBoolean).getOrElse(false)
   private val allowUnquotedControlChars =
     parameters.get("allowUnquotedControlChars").map(_.toBoolean).getOrElse(false)
   val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
index eebb9a404257..8cc2c3900dd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -45,6 +45,12 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {
     legacyFormat = FAST_DATE_FORMAT,
     isParsing = true)
 
+  private val dateFormatter = DateFormatter(
+    options.dateFormat,
+    options.locale,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = true)
+
   /**
    * Infer the type of a collection of json records in three stages:
    *   1. Infer the type of each record
@@ -127,6 +133,10 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {
           }
           if (options.prefersDecimal && decimalTry.isDefined) {
             decimalTry.get
+          } else if (options.inferDateType &&
+              !dateFormatter.isInstanceOf[LegacyFastDateFormatter] &&
+              (allCatch opt dateFormatter.parse(field)).isDefined) {
+            DateType
           } else if (options.inferTimestamp &&
               (allCatch opt timestampFormatter.parse(field)).isDefined) {
             TimestampType
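Aside: both inference paths guard on !dateFormatter.isInstanceOf[LegacyFastDateFormatter], so date inference is skipped entirely under the legacy time-parser policy. The likely motivation is that legacy java.text-style parsing is lenient and accepts malformed values, which would misclassify string columns as dates. An illustration of that leniency with plain JDK classes (not Spark's formatter wrappers):

    import java.text.SimpleDateFormat
    import java.time.LocalDate
    import java.time.format.DateTimeFormatter
    import scala.util.Try

    // Legacy-style parsing is lenient: the out-of-range month and day roll over
    // instead of failing, so "2021-99-99" still yields a java.util.Date.
    val legacy = new SimpleDateFormat("yyyy-MM-dd")
    println(legacy.parse("2021-99-99"))

    // java.time parsing, as used by the non-legacy DateFormatter, rejects it.
    val strict = DateTimeFormatter.ofPattern("yyyy-MM-dd")
    println(Try(LocalDate.parse("2021-99-99", strict)).isFailure) // true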
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
index d268f8c2e721..198461268ede 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
@@ -192,4 +192,27 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {
     Seq("en-US").foreach(checkDecimalInfer(_, StringType))
     Seq("ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0)))
   }
+
+  test("SPARK-34953: DateType should be inferred when a user-defined format is provided") {
+    Seq(true, false).foreach { inferDateType =>
+      val options = new CSVOptions(Map("dateFormat" -> "dd-MM-yyyy",
+        "inferSchema" -> "true", "inferDateType" -> inferDateType.toString), false, "UTC")
+      val inferSchema = new CSVInferSchema(options)
+
+      val inferredDataType = if (inferDateType) DateType else StringType
+      assert(inferSchema.inferField(NullType, "21-10-2021") == inferredDataType)
+      assert(inferSchema.inferField(NullType, "03.31.2021") == StringType)
+    }
+
+    // The default dateFormat, used when dateFormat is not set in the options
+    Seq(true, false).foreach { inferDateType =>
+      val options = new CSVOptions(Map("inferSchema" -> "true",
+        "inferDateType" -> inferDateType.toString), false, "UTC")
+      val inferSchema = new CSVInferSchema(options)
+
+      val inferredDataType = if (inferDateType) DateType else StringType
+      assert(inferSchema.inferField(NullType, "2021-10-05") == inferredDataType)
+      assert(inferSchema.inferField(NullType, "03.31.2021") == StringType)
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
index 8290b38e3393..7d6dee5cbb47 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
@@ -112,4 +112,14 @@ class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper {
     checkType(Map("inferTimestamp" -> "true"), json, TimestampType)
     checkType(Map("inferTimestamp" -> "false"), json, StringType)
   }
+
+  test("SPARK-34953: allow inferring DateType when a matching dateFormat is provided") {
+    val json = """{"a": "29-01-2020"}"""
+    Seq(true, false).foreach { inferDateType =>
+      checkType(Map("dateFormat" -> "dd-MM-yyyy", "inferDateType" -> inferDateType.toString),
+        json, dt = if (inferDateType) DateType else StringType)
+      checkType(Map("dateFormat" -> "yyyy.MM.dd", "inferDateType" -> inferDateType.toString),
+        json, StringType)
+    }
+  }
 }
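Aside: an end-to-end sketch mirroring the JSON test fixture above, written as a hypothetical spark-shell session (assumes a SparkSession named spark is in scope):

    import org.apache.spark.sql.Encoders

    // One-record JSON dataset matching the test input {"a": "29-01-2020"}.
    val data = spark.createDataset(Seq("""{"a": "29-01-2020"}"""))(Encoders.STRING)

    val df = spark.read
      .option("dateFormat", "dd-MM-yyyy")
      .option("inferDateType", "true")
      .json(data)

    df.printSchema() // root
                     //  |-- a: date (nullable = true)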