Skip to content

Commit 3214e0a

Browse files
committed
add test case
1 parent b4985a2 commit 3214e0a

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -361,19 +361,21 @@ trait HiveTypeCoercion {
361361
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
362362
)
363363

364-
// Cast is no need for logical operator
365-
case LessThan(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
364+
// Cast is not needed for binary comparison
365+
case LessThan(e1 @ DecimalType.Expression(p1, s1),
366+
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
366367
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
367368

368369
case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
369-
e2 @ DecimalType.Expression(p2, s2)) =>
370+
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
370371
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
371372

372-
case GreaterThan(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
373+
case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
374+
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
373375
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
374376

375377
case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
376-
e2 @ DecimalType.Expression(p2, s2)) =>
378+
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
377379
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
378380

379381
// Promote integers inside a binary expression with fixed-precision decimals to decimals,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
4949
assert(analyzer(plan).schema.fields(0).dataType === expectedType)
5050
}
5151

52+
private def checkComparison(expression: Expression, expectedType: DataType): Unit = {
53+
val plan = Project(Seq(Alias(expression, "c")()), relation)
54+
val comparison = analyzer(plan).expressions(0).children(0).asInstanceOf[BinaryComparison]
55+
assert(comparison.left.dataType === expectedType)
56+
assert(comparison.right.dataType === expectedType)
57+
}
58+
5259
test("basic operations") {
5360
checkType(Add(d1, d2), DecimalType(6, 2))
5461
checkType(Subtract(d1, d2), DecimalType(6, 2))
@@ -65,6 +72,14 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
6572
checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
6673
}
6774

75+
test("Comparison operations") {
76+
checkComparison(LessThan(i, d1), DecimalType.Unlimited)
77+
checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited)
78+
checkComparison(GreaterThan(d2, u), DecimalType.Unlimited)
79+
checkComparison(GreaterThanOrEqual(d1, f), DoubleType)
80+
checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
81+
}
82+
6883
test("bringing in primitive types") {
6984
checkType(Add(d1, i), DecimalType(12, 1))
7085
checkType(Add(d1, f), DoubleType)

0 commit comments

Comments
 (0)