Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,11 @@ object TypeCoercion {
case _ => None
}

/** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */
def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = {
findTightestCommonType(left, right).orElse((left, right) match {
case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
case _ => None
})
/** Promotes all the way to StringType. */
private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match {
case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
case _ => None
}

/**
Expand All @@ -117,49 +115,67 @@ object TypeCoercion {
* loss of precision when widening decimal and double, and promotion to string.
*/
private[analysis] def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
(t1, t2) match {
case (t1: DecimalType, t2: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(t1, t2))
case (t: IntegralType, d: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (d: DecimalType, t: IntegralType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
Some(DoubleType)
case _ =>
findTightestCommonTypeToString(t1, t2)
}
findTightestCommonType(t1, t2)
.orElse(findWiderTypeForDecimal(t1, t2))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea we changed the order, but looks like it won't change the result. findWiderTypeForDecimal will always return a result for decimal type and numeric type, and if findTightestCommonType can return a result, findWiderTypeForDecimal will return the same result. So it doesn't matter if we run findTightestCommonType before or after it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Integer will be promoted to a wider Decimal anyway.

.orElse(stringPromotion(t1, t2))
.orElse((t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
case _ => None
})
}

private def findWiderCommonType(types: Seq[DataType]) = {
private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findWiderTypeForTwo(d, c)
case None => None
})
}

/**
* Similar to [[findWiderCommonType]] that can handle decimal types, but can't promote to
* Similar to [[findWiderTypeForTwo]] that can handle decimal types, but can't promote to
* string. If the wider decimal type exceeds system limitation, this rule will truncate
* the decimal type before return it.
*/
def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findTightestCommonType(d, c).orElse((d, c) match {
case (t1: DecimalType, t2: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(t1, t2))
case (t: IntegralType, d: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (d: DecimalType, t: IntegralType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
Some(DoubleType)
private[analysis] def findWiderTypeWithoutStringPromotionForTwo(
t1: DataType,
t2: DataType): Option[DataType] = {
findTightestCommonType(t1, t2)
.orElse(findWiderTypeForDecimal(t1, t2))
.orElse((t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
findWiderTypeWithoutStringPromotionForTwo(et1, et2)
.map(ArrayType(_, containsNull1 || containsNull2))
case _ => None
})
}

def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c)
case None => None
})
}

/**
* Finds a wider type when one or both types are decimals. If the wider decimal type exceeds
* system limitation, this rule will truncate the decimal type. If a decimal and other fractional
* types are compared, returns a double type.
*/
private def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = {
(dt1, dt2) match {
case (t1: DecimalType, t2: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(t1, t2))
case (t: IntegralType, d: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (d: DecimalType, t: IntegralType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
Some(DoubleType)
case _ => None
}
}

private def haveSameType(exprs: Seq[Expression]): Boolean =
exprs.map(_.dataType).distinct.length == 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class TypeCoercionSuite extends PlanTest {
// | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType |
// | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X |
// +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
// Note: ArrayType*, MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable
// Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable.
// Note: ArrayType* is castable when the element type is castable according to the table.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we now support implicit cast of ArrayType via https://issues.apache.org/jira/browse/SPARK-18624.

// scalastyle:on line.size.limit

private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
Expand Down Expand Up @@ -125,6 +126,20 @@ class TypeCoercionSuite extends PlanTest {
}
}

private def checkWidenType(
widenFunc: (DataType, DataType) => Option[DataType],
t1: DataType,
t2: DataType,
expected: Option[DataType]): Unit = {
var found = widenFunc(t1, t2)
assert(found == expected,
s"Expected $expected as wider common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
found = widenFunc(t2, t1)
assert(found == expected,
s"Expected $expected as wider common type for $t2 and $t1, found $found")
}

test("implicit type cast - ByteType") {
val checkedType = ByteType
checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
Expand Down Expand Up @@ -308,15 +323,8 @@ class TypeCoercionSuite extends PlanTest {
}

test("tightest common bound for types") {
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
var found = TypeCoercion.findTightestCommonType(t1, t2)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
found = TypeCoercion.findTightestCommonType(t2, t1)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
}
def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit =
checkWidenType(TypeCoercion.findTightestCommonType, t1, t2, expected)

// Null
widenTest(NullType, NullType, Some(NullType))
Expand Down Expand Up @@ -355,7 +363,6 @@ class TypeCoercionSuite extends PlanTest {
widenTest(DecimalType(2, 1), DoubleType, None)
widenTest(DecimalType(2, 1), IntegerType, None)
widenTest(DoubleType, DecimalType(2, 1), None)
widenTest(IntegerType, DecimalType(2, 1), None)

// StringType
widenTest(NullType, StringType, Some(StringType))
Expand All @@ -379,6 +386,60 @@ class TypeCoercionSuite extends PlanTest {
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}

test("wider common type for decimal and array") {
def widenTestWithStringPromotion(
t1: DataType,
t2: DataType,
expected: Option[DataType]): Unit = {
checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected)
}

def widenTestWithoutStringPromotion(
t1: DataType,
t2: DataType,
expected: Option[DataType]): Unit = {
checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected)
}

// Decimal
widenTestWithStringPromotion(
DecimalType(2, 1), DecimalType(3, 2), Some(DecimalType(3, 2)))
widenTestWithStringPromotion(
DecimalType(2, 1), DoubleType, Some(DoubleType))
widenTestWithStringPromotion(
DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1)))
widenTestWithStringPromotion(
DecimalType(2, 1), LongType, Some(DecimalType(21, 1)))

// ArrayType
widenTestWithStringPromotion(
ArrayType(ShortType, containsNull = true),
ArrayType(DoubleType, containsNull = false),
Some(ArrayType(DoubleType, containsNull = true)))
widenTestWithStringPromotion(
ArrayType(TimestampType, containsNull = false),
ArrayType(StringType, containsNull = true),
Some(ArrayType(StringType, containsNull = true)))
widenTestWithStringPromotion(
ArrayType(ArrayType(IntegerType), containsNull = false),
ArrayType(ArrayType(LongType), containsNull = false),
Some(ArrayType(ArrayType(LongType), containsNull = false)))

// Without string promotion
widenTestWithoutStringPromotion(IntegerType, StringType, None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just test int or long, not both

Copy link
Member Author

@HyukjinKwon HyukjinKwon Feb 12, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LongType test is removed.

widenTestWithoutStringPromotion(StringType, TimestampType, None)
widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None)
widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None)

// String promotion
widenTestWithStringPromotion(IntegerType, StringType, Some(StringType))
widenTestWithStringPromotion(StringType, TimestampType, Some(StringType))
widenTestWithStringPromotion(
ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType)))
widenTestWithStringPromotion(
ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType)))
}

private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
ruleTest(Seq(rule), initial, transformed)
}
Expand Down