diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index a585cbed2551..3596ff105fd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -70,7 +70,7 @@ private[csv] object CSVInferSchema {
 
   def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
     first.zipAll(second, NullType, NullType).map { case (a, b) =>
-      findTightestCommonType(a, b).getOrElse(NullType)
+      compatibleType(a, b).getOrElse(NullType)
     }
   }
 
@@ -88,7 +88,7 @@ private[csv] object CSVInferSchema {
         case LongType => tryParseLong(field, options)
         case _: DecimalType =>
           // DecimalTypes have different precisions and scales, so we try to find the common type.
-          findTightestCommonType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType)
+          compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType)
         case DoubleType => tryParseDouble(field, options)
         case TimestampType => tryParseTimestamp(field, options)
         case BooleanType => tryParseBoolean(field, options)
@@ -172,35 +172,27 @@ private[csv] object CSVInferSchema {
     StringType
   }
 
-  private val numericPrecedence: IndexedSeq[DataType] = TypeCoercion.numericPrecedence
+  /**
+   * Returns the common data type given two input data types so that the return type
+   * is compatible with both input data types.
+   */
+  private def compatibleType(t1: DataType, t2: DataType): Option[DataType] = {
+    TypeCoercion.findTightestCommonType(t1, t2).orElse(findCompatibleTypeForCSV(t1, t2))
+  }
 
   /**
-   * Copied from internal Spark api
-   * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion]]
+   * The following pattern matching represents additional type promotion rules that
+   * are CSV specific.
    */
-  val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
-    case (t1, t2) if t1 == t2 => Some(t1)
-    case (NullType, t1) => Some(t1)
-    case (t1, NullType) => Some(t1)
+  private val findCompatibleTypeForCSV: (DataType, DataType) => Option[DataType] = {
     case (StringType, t2) => Some(StringType)
     case (t1, StringType) => Some(StringType)
 
-    // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
-    case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
-      val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
-      Some(numericPrecedence(index))
-
-    // These two cases below deal with when `DecimalType` is larger than `IntegralType`.
-    case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
-      Some(t2)
-    case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
-      Some(t1)
-
     // These two cases below deal with when `IntegralType` is larger than `DecimalType`.
     case (t1: IntegralType, t2: DecimalType) =>
-      findTightestCommonType(DecimalType.forType(t1), t2)
+      compatibleType(DecimalType.forType(t1), t2)
     case (t1: DecimalType, t2: IntegralType) =>
-      findTightestCommonType(t1, DecimalType.forType(t2))
+      compatibleType(t1, DecimalType.forType(t2))
 
     // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
     // in most case, also have better precision.
@@ -216,7 +208,6 @@ private[csv] object CSVInferSchema {
       } else {
         Some(DecimalType(range + scale, scale))
       }
-
     case _ => None
   }
 }
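
The behavioral effect of this refactoring is that type resolution now happens in two tiers: the shared `TypeCoercion.findTightestCommonType` rules run first, and only when they return `None` do the CSV-specific promotions (string absorption and integral-to-decimal widening) apply. The standalone sketch below models that resolution order; `DType`, `tightestCommonType`, and `csvFallback` are hypothetical simplified stand-ins for illustration, not Spark's actual `DataType` hierarchy or internal APIs:

```scala
// Standalone sketch of the two-tier resolution this patch introduces.
// `DType`, `tightestCommonType`, and `csvFallback` are hypothetical,
// simplified stand-ins for Spark's DataType hierarchy,
// TypeCoercion.findTightestCommonType, and findCompatibleTypeForCSV.
sealed trait DType
case object DNull extends DType
case object DInt extends DType
case object DLong extends DType
case object DDouble extends DType
case object DString extends DType
final case class DDecimal(precision: Int, scale: Int) extends DType

object TypeResolutionSketch {
  private val numericPrecedence: IndexedSeq[DType] = IndexedSeq(DInt, DLong, DDouble)

  // Stand-in for the shared tier: equality, null promotion, numeric widening.
  private def tightestCommonType(t1: DType, t2: DType): Option[DType] = (t1, t2) match {
    case (a, b) if a == b => Some(a)
    case (DNull, t) => Some(t)
    case (t, DNull) => Some(t)
    case (a, b) if Seq(a, b).forall(numericPrecedence.contains) =>
      Some(numericPrecedence(numericPrecedence.lastIndexWhere(t => t == a || t == b)))
    case _ => None
  }

  // Stand-in for the CSV-specific fallback: anything merged with a string
  // becomes a string, and integrals are widened to decimals before merging.
  private def csvFallback(t1: DType, t2: DType): Option[DType] = (t1, t2) match {
    case (DString, _) | (_, DString) => Some(DString)
    case (DInt, d: DDecimal) => compatibleType(DDecimal(10, 0), d) // Int fits in Decimal(10, 0)
    case (d: DDecimal, DInt) => compatibleType(d, DDecimal(10, 0))
    case (DDecimal(p1, s1), DDecimal(p2, s2)) =>
      // Mirrors the existing decimal merge: widest scale plus widest integral
      // range, falling back to double past the max supported precision of 38.
      val scale = math.max(s1, s2)
      val range = math.max(p1 - s1, p2 - s2)
      if (range + scale > 38) Some(DDouble) else Some(DDecimal(range + scale, scale))
    case _ => None
  }

  // Mirrors the new compatibleType: shared rules first, CSV rules as fallback.
  def compatibleType(t1: DType, t2: DType): Option[DType] =
    tightestCommonType(t1, t2).orElse(csvFallback(t1, t2))

  def main(args: Array[String]): Unit = {
    println(compatibleType(DInt, DLong))           // Some(DLong), via the shared tier
    println(compatibleType(DInt, DDecimal(20, 2))) // Some(DDecimal(20,2)), via the fallback
    println(compatibleType(DLong, DString))        // Some(DString), via the fallback
  }
}
```

A note on the design: keeping the shared rules in one place (`TypeCoercion`) and expressing only the CSV-specific deltas locally means future fixes to the common coercion logic apply to CSV schema inference automatically, which appears to be the motivation for dropping the copied `findTightestCommonType`.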