Skip to content

Commit 3b266ba

Browse files
committed
[SPARK-8362] Add unit tests for +, -, *, /.
1 parent cb7ada1 commit 3b266ba

File tree

2 files changed

+85
-63
lines changed

2 files changed

+85
-63
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.sql.catalyst
2120
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2221
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2322
import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -52,8 +51,8 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
5251
private lazy val numeric = TypeUtils.getNumeric(dataType)
5352

5453
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
55-
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
56-
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
54+
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
55+
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
5756
}
5857

5958
protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
@@ -144,8 +143,8 @@ abstract class BinaryArithmetic extends BinaryExpression {
144143
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
145144
// byte and short are casted into int when add, minus, times or divide
146145
case ByteType | ShortType =>
147-
defineCodeGen(ctx, ev, (eval1, eval2) =>
148-
s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
146+
defineCodeGen(ctx, ev,
147+
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
149148
case _ =>
150149
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
151150
}
@@ -205,7 +204,7 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
205204

206205
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
207206
override def symbol: String = "/"
208-
override def decimalMethod: String = "$divide"
207+
override def decimalMethod: String = "$div"
209208

210209
override def nullable: Boolean = true
211210

@@ -257,15 +256,16 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
257256
if (${eval1.isNull} || ${eval2.isNull} || $test) {
258257
${ev.isNull} = true;
259258
} else {
260-
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
259+
${ev.primitive} =
260+
(${ctx.javaType(left.dataType)})(${eval1.primitive}$method(${eval2.primitive}));
261261
}
262262
"""
263263
}
264264
}
265265

266266
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
267267
override def symbol: String = "%"
268-
override def decimalMethod: String = "reminder"
268+
override def decimalMethod: String = "remainder"
269269

270270
override def nullable: Boolean = true
271271

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala

Lines changed: 77 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,77 +17,99 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.scalatest.Matchers._
21-
2220
import org.apache.spark.SparkFunSuite
2321
import org.apache.spark.sql.catalyst.dsl.expressions._
2422
import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType}
2523

2624

2725
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
2826

29-
test("arithmetic") {
30-
val row = create_row(1, 2, 3, null)
31-
val c1 = 'a.int.at(0)
32-
val c2 = 'a.int.at(1)
33-
val c3 = 'a.int.at(2)
34-
val c4 = 'a.int.at(3)
35-
36-
checkEvaluation(UnaryMinus(c1), -1, row)
37-
checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100)
38-
39-
checkEvaluation(Add(c1, c4), null, row)
40-
checkEvaluation(Add(c1, c2), 3, row)
41-
checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row)
42-
checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row)
43-
checkEvaluation(
44-
Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row)
45-
46-
checkEvaluation(-c1, -1, row)
47-
checkEvaluation(c1 + c2, 3, row)
48-
checkEvaluation(c1 - c2, -1, row)
49-
checkEvaluation(c1 * c2, 2, row)
50-
checkEvaluation(c1 / c2, 0, row)
51-
checkEvaluation(c1 % c2, 1, row)
27+
/**
28+
* Runs through the testFunc for all numeric data types.
29+
*
30+
* @param testFunc a test function that accepts a conversion function to convert an integer
31+
* into another data type.
32+
*/
33+
private def testNumericDataTypes(testFunc: (Int => Any) => Unit): Unit = {
34+
testFunc(_.toByte)
35+
testFunc(_.toShort)
36+
testFunc(identity)
37+
testFunc(_.toLong)
38+
testFunc(_.toFloat)
39+
testFunc(_.toDouble)
40+
testFunc(Decimal(_))
41+
}
42+
43+
test("+ (Add)") {
44+
testNumericDataTypes { convert =>
45+
val left = Literal(convert(1))
46+
val right = Literal(convert(2))
47+
checkEvaluation(Add(left, right), convert(3))
48+
checkEvaluation(Add(Literal.create(null, left.dataType), right), null)
49+
checkEvaluation(Add(left, Literal.create(null, right.dataType)), null)
50+
}
51+
}
52+
53+
test("- (UnaryMinus)") {
54+
testNumericDataTypes { convert =>
55+
val input = Literal(convert(1))
56+
val dataType = input.dataType
57+
checkEvaluation(UnaryMinus(input), convert(-1))
58+
checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
59+
}
60+
}
61+
62+
test("- (Minus)") {
63+
testNumericDataTypes { convert =>
64+
val left = Literal(convert(1))
65+
val right = Literal(convert(2))
66+
checkEvaluation(Subtract(left, right), convert(-1))
67+
checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null)
68+
checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null)
69+
}
70+
}
71+
72+
test("* (Multiply)") {
73+
testNumericDataTypes { convert =>
74+
val left = Literal(convert(1))
75+
val right = Literal(convert(2))
76+
checkEvaluation(Multiply(left, right), convert(2))
77+
checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null)
78+
checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null)
79+
}
80+
}
81+
82+
test("/ (Divide) basic") {
83+
testNumericDataTypes { convert =>
84+
val left = Literal(convert(2))
85+
val right = Literal(convert(1))
86+
val dataType = left.dataType
87+
checkEvaluation(Divide(left, right), convert(2))
88+
checkEvaluation(Divide(Literal.create(null, dataType), right), null)
89+
checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null)
90+
checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero
91+
}
5292
}
5393

54-
test("fractional arithmetic") {
55-
val row = create_row(1.1, 2.0, 3.1, null)
56-
val c1 = 'a.double.at(0)
57-
val c2 = 'a.double.at(1)
58-
val c3 = 'a.double.at(2)
59-
val c4 = 'a.double.at(3)
60-
61-
checkEvaluation(UnaryMinus(c1), -1.1, row)
62-
checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0)
63-
checkEvaluation(Add(c1, c4), null, row)
64-
checkEvaluation(Add(c1, c2), 3.1, row)
65-
checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row)
66-
checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row)
67-
checkEvaluation(
68-
Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row)
69-
70-
checkEvaluation(-c1, -1.1, row)
71-
checkEvaluation(c1 + c2, 3.1, row)
72-
checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row)
73-
checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row)
74-
checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row)
75-
checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row)
94+
test("/ (Divide) for integral type") {
95+
checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
96+
checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
97+
checkEvaluation(Divide(Literal(1), Literal(2)), 0)
98+
checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong)
99+
}
100+
101+
test("/ (Divide) for floating point") {
102+
checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f)
103+
checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
104+
checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5))
76105
}
77106

78107
test("Abs") {
79-
def testAbs(convert: (Int) => Any): Unit = {
108+
testNumericDataTypes { convert =>
80109
checkEvaluation(Abs(Literal(convert(0))), convert(0))
81110
checkEvaluation(Abs(Literal(convert(1))), convert(1))
82111
checkEvaluation(Abs(Literal(convert(-1))), convert(1))
83112
}
84-
testAbs(_.toByte)
85-
testAbs(_.toShort)
86-
testAbs(identity)
87-
testAbs(_.toLong)
88-
testAbs(_.toFloat)
89-
testAbs(_.toDouble)
90-
testAbs(Decimal(_))
91113
}
92114

93115
test("Divide") {

0 commit comments

Comments
 (0)