Skip to content

Commit c386fb4

Browse files
rxinnemccarthy
authored andcommitted
[SPARK-8362] [SQL] Add unit tests for +, -, *, /, %
Added unit tests for all supported data types for: - Add - Subtract - Multiply - Divide - UnaryMinus - Remainder Fixed bugs caught by the unit tests. Author: Reynold Xin <[email protected]> Closes apache#6813 from rxin/SPARK-8362 and squashes the following commits: fb3fe62 [Reynold Xin] Added Remainder. 3b266ba [Reynold Xin] [SPARK-8362] Add unit tests for +, -, *, /.
1 parent 82fb170 commit c386fb4

File tree

2 files changed

+99
-105
lines changed

2 files changed

+99
-105
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.sql.catalyst
2120
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2221
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2322
import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -52,8 +51,8 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
5251
private lazy val numeric = TypeUtils.getNumeric(dataType)
5352

5453
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
55-
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
56-
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
54+
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
55+
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
5756
}
5857

5958
protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
@@ -144,8 +143,8 @@ abstract class BinaryArithmetic extends BinaryExpression {
144143
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
145144
// byte and short are casted into int when add, minus, times or divide
146145
case ByteType | ShortType =>
147-
defineCodeGen(ctx, ev, (eval1, eval2) =>
148-
s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
146+
defineCodeGen(ctx, ev,
147+
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
149148
case _ =>
150149
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
151150
}
@@ -205,7 +204,7 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
205204

206205
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
207206
override def symbol: String = "/"
208-
override def decimalMethod: String = "$divide"
207+
override def decimalMethod: String = "$div"
209208

210209
override def nullable: Boolean = true
211210

@@ -245,27 +244,24 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
245244
} else {
246245
s"${eval2.primitive} == 0"
247246
}
248-
val method = if (left.dataType.isInstanceOf[DecimalType]) {
249-
s".$decimalMethod"
250-
} else {
251-
s"$symbol"
252-
}
247+
val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol "
248+
val javaType = ctx.javaType(left.dataType)
253249
eval1.code + eval2.code +
254250
s"""
255251
boolean ${ev.isNull} = false;
256252
${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
257253
if (${eval1.isNull} || ${eval2.isNull} || $test) {
258254
${ev.isNull} = true;
259255
} else {
260-
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
256+
${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive}));
261257
}
262258
"""
263259
}
264260
}
265261

266262
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
267263
override def symbol: String = "%"
268-
override def decimalMethod: String = "reminder"
264+
override def decimalMethod: String = "remainder"
269265

270266
override def nullable: Boolean = true
271267

@@ -305,19 +301,16 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
305301
} else {
306302
s"${eval2.primitive} == 0"
307303
}
308-
val method = if (left.dataType.isInstanceOf[DecimalType]) {
309-
s".$decimalMethod"
310-
} else {
311-
s"$symbol"
312-
}
304+
val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol "
305+
val javaType = ctx.javaType(left.dataType)
313306
eval1.code + eval2.code +
314307
s"""
315308
boolean ${ev.isNull} = false;
316309
${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
317310
if (${eval1.isNull} || ${eval2.isNull} || $test) {
318311
${ev.isNull} = true;
319312
} else {
320-
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
313+
${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive}));
321314
}
322315
"""
323316
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala

Lines changed: 87 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -17,109 +17,110 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.scalatest.Matchers._
21-
2220
import org.apache.spark.SparkFunSuite
2321
import org.apache.spark.sql.catalyst.dsl.expressions._
2422
import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType}
2523

2624

2725
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
2826

29-
test("arithmetic") {
30-
val row = create_row(1, 2, 3, null)
31-
val c1 = 'a.int.at(0)
32-
val c2 = 'a.int.at(1)
33-
val c3 = 'a.int.at(2)
34-
val c4 = 'a.int.at(3)
35-
36-
checkEvaluation(UnaryMinus(c1), -1, row)
37-
checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100)
38-
39-
checkEvaluation(Add(c1, c4), null, row)
40-
checkEvaluation(Add(c1, c2), 3, row)
41-
checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row)
42-
checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row)
43-
checkEvaluation(
44-
Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row)
45-
46-
checkEvaluation(-c1, -1, row)
47-
checkEvaluation(c1 + c2, 3, row)
48-
checkEvaluation(c1 - c2, -1, row)
49-
checkEvaluation(c1 * c2, 2, row)
50-
checkEvaluation(c1 / c2, 0, row)
51-
checkEvaluation(c1 % c2, 1, row)
27+
/**
28+
* Runs through the testFunc for all numeric data types.
29+
*
30+
* @param testFunc a test function that accepts a conversion function to convert an integer
31+
* into another data type.
32+
*/
33+
private def testNumericDataTypes(testFunc: (Int => Any) => Unit): Unit = {
34+
testFunc(_.toByte)
35+
testFunc(_.toShort)
36+
testFunc(identity)
37+
testFunc(_.toLong)
38+
testFunc(_.toFloat)
39+
testFunc(_.toDouble)
40+
testFunc(Decimal(_))
5241
}
5342

54-
test("fractional arithmetic") {
55-
val row = create_row(1.1, 2.0, 3.1, null)
56-
val c1 = 'a.double.at(0)
57-
val c2 = 'a.double.at(1)
58-
val c3 = 'a.double.at(2)
59-
val c4 = 'a.double.at(3)
60-
61-
checkEvaluation(UnaryMinus(c1), -1.1, row)
62-
checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0)
63-
checkEvaluation(Add(c1, c4), null, row)
64-
checkEvaluation(Add(c1, c2), 3.1, row)
65-
checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row)
66-
checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row)
67-
checkEvaluation(
68-
Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row)
69-
70-
checkEvaluation(-c1, -1.1, row)
71-
checkEvaluation(c1 + c2, 3.1, row)
72-
checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row)
73-
checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row)
74-
checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row)
75-
checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row)
43+
test("+ (Add)") {
44+
testNumericDataTypes { convert =>
45+
val left = Literal(convert(1))
46+
val right = Literal(convert(2))
47+
checkEvaluation(Add(left, right), convert(3))
48+
checkEvaluation(Add(Literal.create(null, left.dataType), right), null)
49+
checkEvaluation(Add(left, Literal.create(null, right.dataType)), null)
50+
}
7651
}
7752

78-
test("Abs") {
79-
def testAbs(convert: (Int) => Any): Unit = {
80-
checkEvaluation(Abs(Literal(convert(0))), convert(0))
81-
checkEvaluation(Abs(Literal(convert(1))), convert(1))
82-
checkEvaluation(Abs(Literal(convert(-1))), convert(1))
53+
test("- (UnaryMinus)") {
54+
testNumericDataTypes { convert =>
55+
val input = Literal(convert(1))
56+
val dataType = input.dataType
57+
checkEvaluation(UnaryMinus(input), convert(-1))
58+
checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
8359
}
84-
testAbs(_.toByte)
85-
testAbs(_.toShort)
86-
testAbs(identity)
87-
testAbs(_.toLong)
88-
testAbs(_.toFloat)
89-
testAbs(_.toDouble)
90-
testAbs(Decimal(_))
9160
}
9261

93-
test("Divide") {
94-
checkEvaluation(Divide(Literal(2), Literal(1)), 2)
95-
checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
62+
test("- (Minus)") {
63+
testNumericDataTypes { convert =>
64+
val left = Literal(convert(1))
65+
val right = Literal(convert(2))
66+
checkEvaluation(Subtract(left, right), convert(-1))
67+
checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null)
68+
checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null)
69+
}
70+
}
71+
72+
test("* (Multiply)") {
73+
testNumericDataTypes { convert =>
74+
val left = Literal(convert(1))
75+
val right = Literal(convert(2))
76+
checkEvaluation(Multiply(left, right), convert(2))
77+
checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null)
78+
checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null)
79+
}
80+
}
81+
82+
test("/ (Divide) basic") {
83+
testNumericDataTypes { convert =>
84+
val left = Literal(convert(2))
85+
val right = Literal(convert(1))
86+
val dataType = left.dataType
87+
checkEvaluation(Divide(left, right), convert(2))
88+
checkEvaluation(Divide(Literal.create(null, dataType), right), null)
89+
checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null)
90+
checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero
91+
}
92+
}
93+
94+
test("/ (Divide) for integral type") {
95+
checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
96+
checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
9697
checkEvaluation(Divide(Literal(1), Literal(2)), 0)
97-
checkEvaluation(Divide(Literal(1), Literal(0)), null)
98-
checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null)
99-
checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null)
100-
checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null)
101-
checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null)
102-
checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null)
103-
checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null)
104-
checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null)
105-
checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)),
106-
null)
98+
checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong)
10799
}
108100

109-
test("Remainder") {
110-
checkEvaluation(Remainder(Literal(2), Literal(1)), 0)
111-
checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0)
112-
checkEvaluation(Remainder(Literal(1), Literal(2)), 1)
113-
checkEvaluation(Remainder(Literal(1), Literal(0)), null)
114-
checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null)
115-
checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null)
116-
checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null)
117-
checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null)
118-
checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null)
119-
checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null)
120-
checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null)
121-
checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)),
122-
null)
101+
test("/ (Divide) for floating point") {
102+
checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f)
103+
checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
104+
checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5))
105+
}
106+
107+
test("% (Remainder)") {
108+
testNumericDataTypes { convert =>
109+
val left = Literal(convert(1))
110+
val right = Literal(convert(2))
111+
checkEvaluation(Remainder(left, right), convert(1))
112+
checkEvaluation(Remainder(Literal.create(null, left.dataType), right), null)
113+
checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null)
114+
checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0
115+
}
116+
}
117+
118+
test("Abs") {
119+
testNumericDataTypes { convert =>
120+
checkEvaluation(Abs(Literal(convert(0))), convert(0))
121+
checkEvaluation(Abs(Literal(convert(1))), convert(1))
122+
checkEvaluation(Abs(Literal(convert(-1))), convert(1))
123+
}
123124
}
124125

125126
test("MaxOf") {

0 commit comments

Comments
 (0)