Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import javax.annotation.Nullable

import scala.annotation.tailrec
import scala.collection.mutable
import scala.math._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -89,10 +90,10 @@ object TypeCoercion {
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)

case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
Some(t2)
case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
Some(t1)
case (t1: IntegralType, t2: DecimalType) =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan Would we need a guard here to be safe ? like
case (t1: IntegralType, t2: DecimalType) if findWiderDecimalType(DecimalType.forType(t1), t2)).isdefined =>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. We don't need to handle integral and decimal again after it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan ok... thanks !!

findWiderDecimalType(DecimalType.forType(t1), t2)
case (t1: DecimalType, t2: IntegralType) =>
findWiderDecimalType(t1, DecimalType.forType(t2))

// Promote numeric types to the highest of the two
case (t1: NumericType, t2: NumericType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we handle 2 decimals as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan Yeah.. Do you think it may conflict with DecimalConversion rule in anyway ? Let me run the tests first and see how it goes ..

Expand All @@ -106,6 +107,22 @@ object TypeCoercion {
case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType)
}

/**
* Finds a wider decimal type between the two supplied decimal types without
* any loss of precision.
*/
def findWiderDecimalType(d1: DecimalType, d2: DecimalType): Option[DecimalType] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reference for this implementation? I'm worried about corner cases like negative scale.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan Actually the bounded version is in DecimalPrecision::widerDecimalType. Thats the function i looked at as reference.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. Then can we add some more tests with negative scale?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan Added tests with -ve scale. Thanks !!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to rename it to findTightestDecimalType, and add document to say what's the difference between this and findWiderTypeForDecimal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan Sure.. will do.

val scale = max(d1.scale, d2.scale)
val range = max(d1.precision - d1.scale, d2.precision - d2.scale)

// Check the resultant decimal type does not exceed the allowable limits.
if (range + scale <= DecimalType.MAX_PRECISION && scale <= DecimalType.MAX_SCALE) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need scale <= DecimalType.MAX_SCALE? DecimalType.scale has been already validated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu OK.. i will remove this check.

Some(DecimalType(range + scale, scale))
} else {
None
}
}

/** Promotes all the way to StringType. */
private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match {
case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,29 @@ class TypeCoercionSuite extends AnalysisTest {
// No up-casting for fixed-precision decimal (this is handled by arithmetic rules)
widenTest(DecimalType(2, 1), DecimalType(3, 2), None)
widenTest(DecimalType(2, 1), DoubleType, None)
widenTest(DecimalType(2, 1), IntegerType, None)
widenTest(DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have one and only one positive and negative test case for each integral type(byte, short, int, long), and another positive and negative test case for negative scale with int type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

widenTest(DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1)))
widenTest(DecimalType(2, 2), IntegerType, Some(DecimalType(12, 2)))
widenTest(DecimalType(4, 2), IntegerType, Some(DecimalType(12, 2)))
widenTest(DecimalType(10, 2), IntegerType, Some(DecimalType(12, 2)))
widenTest(DecimalType(38, 18), IntegerType, Some(DecimalType(38, 18)))
widenTest(DecimalType(38, 28), IntegerType, Some(DecimalType(38, 28)))
widenTest(DecimalType(38, 29), IntegerType, None)
widenTest(DecimalType(38, 38), IntegerType, None)
widenTest(DecimalType(3, -2), IntegerType, Some(DecimalType(10, 0)))
widenTest(DecimalType(11, -2), IntegerType, Some(DecimalType(13, 0)))
widenTest(DecimalType(36, -2), IntegerType, Some(DecimalType(38, 0)))
widenTest(DecimalType(37, -2), IntegerType, None)
widenTest(DecimalType(1, -38), IntegerType, None)
widenTest(DecimalType(2, 1), LongType, Some(DecimalType(21, 1)))
widenTest(DecimalType(2, 1), LongType, Some(DecimalType(21, 1)))
widenTest(DecimalType(2, 2), LongType, Some(DecimalType(22, 2)))
widenTest(DecimalType(4, 2), LongType, Some(DecimalType(22, 2)))
widenTest(DecimalType(10, 2), LongType, Some(DecimalType(22, 2)))
widenTest(DecimalType(38, 18), LongType, Some(DecimalType(38, 18)))
widenTest(DecimalType(38, 18), LongType, Some(DecimalType(38, 18)))
widenTest(DecimalType(38, 19), LongType, None)
widenTest(DecimalType(38, 38), LongType, None)
widenTest(DoubleType, DecimalType(2, 1), None)

// StringType
Expand Down