diff --git a/docs/sql-data-sources-csv.md b/docs/sql-data-sources-csv.md
index 2fe8f77ff667..a6b093f7e413 100644
--- a/docs/sql-data-sources-csv.md
+++ b/docs/sql-data-sources-csv.md
@@ -248,5 +248,11 @@ Data source options of CSV can be set via:
Compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). |
write |
+
+ inferDateType |
+ false |
+ Infers columns whose values match dateFormat as DateType during schema inference. It takes effect only when inferSchema is enabled. If this is not set, it uses the default value, false. |
+ read |
+
Other generic options can be found in Generic File Source Options.
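
A minimal usage sketch of the option documented above; the `SparkSession` setup and the input path are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Columns whose values all match dateFormat should now be inferred as
// DateType; without inferDateType they fall through to StringType or
// TimestampType as before.
val df = spark.read
  .option("header", "true")
  .option("inferSchema", "true")      // inferDateType only applies while inferring
  .option("dateFormat", "dd-MM-yyyy")
  .option("inferDateType", "true")
  .csv("/tmp/events.csv")             // hypothetical path

df.printSchema()
```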
diff --git a/docs/sql-data-sources-json.md b/docs/sql-data-sources-json.md
index 041512918e61..cb4d69c64845 100644
--- a/docs/sql-data-sources-json.md
+++ b/docs/sql-data-sources-json.md
@@ -155,6 +155,12 @@ Data source options of JSON can be set via:
Allows leading zeros in numbers (e.g. 00012). If None is set, it uses the default value, false. |
read |
+
+ inferDateType |
+ false |
+ Infers string fields whose values match dateFormat as DateType during schema inference. If this is not set, it uses the default value, false. |
+ read |
+
allowBackslashEscapingAnyCharacter |
None |
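
The JSON reader takes the same pair of options. A short sketch, reusing the `spark` session from the CSV example above and inlining the data so no file is needed:

```scala
import spark.implicits._

// A string field matching dateFormat; with inferDateType enabled it should
// be inferred as DateType rather than StringType.
val ds = Seq("""{"order_date": "29-01-2020"}""").toDS()

spark.read
  .option("dateFormat", "dd-MM-yyyy")
  .option("inferDateType", "true")
  .json(ds)
  .printSchema()   // expected: order_date: date (nullable = true)
```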
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index f9e37341dcd6..f639a563ec0b 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -353,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None,
pathGlobFilter=None, recursiveFileLookup=None, modifiedBefore=None, modifiedAfter=None,
- unescapedQuoteHandling=None):
+ unescapedQuoteHandling=None, inferDateType=None):
r"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -403,7 +403,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep,
pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup,
modifiedBefore=modifiedBefore, modifiedAfter=modifiedAfter,
- unescapedQuoteHandling=unescapedQuoteHandling)
+ unescapedQuoteHandling=unescapedQuoteHandling, inferDateType=inferDateType)
if isinstance(path, str):
path = [path]
if type(path) == list:
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 08c8934fbf03..9d2f3d2cdc1c 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -636,7 +636,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
enforceSchema=None, emptyValue=None, locale=None, lineSep=None,
- pathGlobFilter=None, recursiveFileLookup=None, unescapedQuoteHandling=None):
+ pathGlobFilter=None, recursiveFileLookup=None, unescapedQuoteHandling=None,
+ inferDateType=None):
r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -686,7 +687,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema,
emptyValue=emptyValue, locale=locale, lineSep=lineSep,
pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup,
- unescapedQuoteHandling=unescapedQuoteHandling)
+ unescapedQuoteHandling=unescapedQuoteHandling, inferDateType=inferDateType)
if isinstance(path, str):
return self._df(self._jreader.csv(path))
else:
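
The streaming reader only threads the option through; keep in mind that file streams infer schemas only when `spark.sql.streaming.schemaInference` is enabled. A sketch with an illustrative input directory:

```scala
// File streams require an explicit schema unless schema inference is on.
spark.conf.set("spark.sql.streaming.schemaInference", "true")

val events = spark.readStream
  .option("header", "true")
  .option("dateFormat", "dd-MM-yyyy")
  .option("inferDateType", "true")
  .csv("/tmp/incoming/")              // hypothetical directory of CSV files
```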
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
index ef3d038cf289..946ff1674d16 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
@@ -24,8 +24,8 @@ import scala.util.control.Exception.allCatch
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.util.{DateFormatter, LegacyFastDateFormatter, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
-import org.apache.spark.sql.catalyst.util.TimestampFormatter
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -38,6 +38,12 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
legacyFormat = FAST_DATE_FORMAT,
isParsing = true)
+ private val dateFormatter = DateFormatter(
+ options.dateFormat,
+ options.locale,
+ legacyFormat = FAST_DATE_FORMAT,
+ isParsing = true)
+
private val decimalParser = if (options.locale == Locale.US) {
// Special handling the default locale for backward compatibility
s: String => new java.math.BigDecimal(s)
@@ -109,6 +115,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
case LongType => tryParseLong(field)
case _: DecimalType => tryParseDecimal(field)
case DoubleType => tryParseDouble(field)
+ case DateType => tryParseDate(field)
case TimestampType => tryParseTimestamp(field)
case BooleanType => tryParseBoolean(field)
case StringType => StringType
@@ -160,6 +167,16 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
private def tryParseDouble(field: String): DataType = {
if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field)) {
DoubleType
+ } else {
+ tryParseDate(field)
+ }
+ }
+
+ private def tryParseDate(field: String): DataType = {
+ if (options.inferDateType &&
+ !dateFormatter.isInstanceOf[LegacyFastDateFormatter] &&
+ (allCatch opt dateFormatter.parse(field)).isDefined) {
+ DateType
} else {
tryParseTimestamp(field)
}
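
The new branch slots into the existing per-field fallback chain (integer → long → decimal → double → date → timestamp → boolean → string). A self-contained sketch of the same check using only `java.time`; the helper name is illustrative, not Spark API:

```scala
import java.time.LocalDate
import java.time.format.DateTimeFormatter
import scala.util.Try

// Stand-in for the tryParseDate step: succeed only when the raw field parses
// under the user-supplied dateFormat; otherwise the caller falls through to
// the timestamp check, mirroring tryParseDouble -> tryParseDate ->
// tryParseTimestamp above.
def looksLikeDate(field: String, pattern: String): Boolean =
  Try(LocalDate.parse(field, DateTimeFormatter.ofPattern(pattern))).isSuccess

looksLikeDate("21-10-2021", "dd-MM-yyyy")  // true  -> infer DateType
looksLikeDate("03.31.2021", "dd-MM-yyyy")  // false -> try TimestampType next
```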
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
index 9aa4bf43898a..f5307d294127 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
@@ -206,6 +206,12 @@ class CSVOptions(
sep
}
+ /**
+ * Whether columns whose values match `dateFormat` are inferred as `DateType`
+ * during schema inference. Disabled by default.
+ */
+ val inferDateType = getBool("inferDateType")
+
val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
lineSep.getBytes(charset)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 47be83a41d61..dd43e9eb2a4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -68,6 +68,8 @@ private[sql] class JSONOptions(
parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
val allowBackslashEscapingAnyCharacter =
parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
+ val inferDateType =
+ parameters.get("inferDateType").map(_.toBoolean).getOrElse(false)
private val allowUnquotedControlChars =
parameters.get("allowUnquotedControlChars").map(_.toBoolean).getOrElse(false)
val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
index eebb9a404257..8cc2c3900dd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -45,6 +45,12 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {
legacyFormat = FAST_DATE_FORMAT,
isParsing = true)
+ private val dateFormatter = DateFormatter(
+ options.dateFormat,
+ options.locale,
+ legacyFormat = FAST_DATE_FORMAT,
+ isParsing = true)
+
/**
* Infer the type of a collection of json records in three stages:
* 1. Infer the type of each record
@@ -127,6 +133,10 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {
}
if (options.prefersDecimal && decimalTry.isDefined) {
decimalTry.get
+ } else if (options.inferDateType &&
+ !dateFormatter.isInstanceOf[LegacyFastDateFormatter] &&
+ (allCatch opt dateFormatter.parse(field)).isDefined) {
+ DateType
} else if (options.inferTimestamp &&
(allCatch opt timestampFormatter.parse(field)).isDefined) {
TimestampType
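
One design note on the ordering above: the date branch runs after `prefersDecimal` but before `inferTimestamp`, so a value accepted by both formatters resolves to the narrower `DateType`. With the default `dateFormat` (`yyyy-MM-dd`) the expected behaviour looks like:

```scala
import spark.implicits._

val ds = Seq("""{"d": "2021-10-05"}""").toDS()

// The date check wins over timestamp inference for date-only strings.
spark.read
  .option("inferDateType", "true")
  .option("inferTimestamp", "true")
  .json(ds)
  .printSchema()   // expected: d: date, not timestamp
```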
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
index d268f8c2e721..198461268ede 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
@@ -192,4 +192,27 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {
Seq("en-US").foreach(checkDecimalInfer(_, StringType))
Seq("ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0)))
}
+
+ test("SPARK-34953: DateType should be inferred when user defined format are provided") {
+ Seq(true, false).foreach { inferDateType =>
+ val options = new CSVOptions(Map("dateFormat" -> "dd-MM-yyyy",
+ "inferSchema" -> "true", "inferDateType" -> inferDateType.toString), false, "UTC")
+ val inferSchema = new CSVInferSchema(options)
+
+ val inferredDataType = if (inferDateType) DateType else StringType
+ assert(inferSchema.inferField(NullType, "21-10-2021") == inferredDataType)
+ assert(inferSchema.inferField(NullType, "03.31.2021") == StringType)
+ }
+
+ // Default dateFormat, when the option is not set explicitly
+ Seq(true, false).foreach { inferDateType =>
+ val options = new CSVOptions(Map("inferSchema" -> "true",
+ "inferDateType" -> inferDateType.toString), false, "UTC")
+ val inferSchema = new CSVInferSchema(options)
+
+ val inferredDataType = if (inferDateType) DateType else StringType
+ assert(inferSchema.inferField(NullType, "2021-10-05") == inferredDataType)
+ assert(inferSchema.inferField(NullType, "03.31.2021") == StringType)
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
index 8290b38e3393..7d6dee5cbb47 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
@@ -112,4 +112,14 @@ class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper {
checkType(Map("inferTimestamp" -> "true"), json, TimestampType)
checkType(Map("inferTimestamp" -> "false"), json, StringType)
}
+
+ test("SPARK-34953: Allow DateType format while inferring") {
+ val json = """{"a": "29-01-2020"}"""
+ Seq(true, false).foreach { inferDateType =>
+ checkType(Map("dateFormat" -> "dd-MM-yyyy", "inferDateType" -> inferDateType.toString),
+ json, dt = if (inferDateType) DateType else StringType)
+ checkType(Map("dateFormat" -> "yyyy.MM.dd", "inferDateType" -> inferDateType.toString),
+ json, StringType)
+ }
+ }
}