|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.expressions |
19 | 19 |
|
20 | | -import org.scalatest.Matchers._ |
21 | | - |
22 | 20 | import org.apache.spark.SparkFunSuite |
23 | 21 | import org.apache.spark.sql.catalyst.dsl.expressions._ |
24 | 22 | import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType} |
25 | 23 |
|
26 | 24 |
|
27 | 25 | class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { |
28 | 26 |
|
29 | | - test("arithmetic") { |
30 | | - val row = create_row(1, 2, 3, null) |
31 | | - val c1 = 'a.int.at(0) |
32 | | - val c2 = 'a.int.at(1) |
33 | | - val c3 = 'a.int.at(2) |
34 | | - val c4 = 'a.int.at(3) |
35 | | - |
36 | | - checkEvaluation(UnaryMinus(c1), -1, row) |
37 | | - checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) |
38 | | - |
39 | | - checkEvaluation(Add(c1, c4), null, row) |
40 | | - checkEvaluation(Add(c1, c2), 3, row) |
41 | | - checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) |
42 | | - checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) |
43 | | - checkEvaluation( |
44 | | - Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) |
45 | | - |
46 | | - checkEvaluation(-c1, -1, row) |
47 | | - checkEvaluation(c1 + c2, 3, row) |
48 | | - checkEvaluation(c1 - c2, -1, row) |
49 | | - checkEvaluation(c1 * c2, 2, row) |
50 | | - checkEvaluation(c1 / c2, 0, row) |
51 | | - checkEvaluation(c1 % c2, 1, row) |
| 27 | + /** |
| 28 | + * Runs through the testFunc for all numeric data types. |
| 29 | + * |
| 30 | + * @param testFunc a test function that accepts a conversion function to convert an integer |
| 31 | + * into another data type. |
| 32 | + */ |
| 33 | + private def testNumericDataTypes(testFunc: (Int => Any) => Unit): Unit = { |
| 34 | + testFunc(_.toByte) |
| 35 | + testFunc(_.toShort) |
| 36 | + testFunc(identity) |
| 37 | + testFunc(_.toLong) |
| 38 | + testFunc(_.toFloat) |
| 39 | + testFunc(_.toDouble) |
| 40 | + testFunc(Decimal(_)) |
| 41 | + } |
| 42 | + |
| 43 | + test("+ (Add)") { |
| 44 | + testNumericDataTypes { convert => |
| 45 | + val left = Literal(convert(1)) |
| 46 | + val right = Literal(convert(2)) |
| 47 | + checkEvaluation(Add(left, right), convert(3)) |
| 48 | + checkEvaluation(Add(Literal.create(null, left.dataType), right), null) |
| 49 | + checkEvaluation(Add(left, Literal.create(null, right.dataType)), null) |
| 50 | + } |
| 51 | + } |
| 52 | + |
| 53 | + test("- (UnaryMinus)") { |
| 54 | + testNumericDataTypes { convert => |
| 55 | + val input = Literal(convert(1)) |
| 56 | + val dataType = input.dataType |
| 57 | + checkEvaluation(UnaryMinus(input), convert(-1)) |
| 58 | + checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null) |
| 59 | + } |
| 60 | + } |
| 61 | + |
| 62 | + test("- (Minus)") { |
| 63 | + testNumericDataTypes { convert => |
| 64 | + val left = Literal(convert(1)) |
| 65 | + val right = Literal(convert(2)) |
| 66 | + checkEvaluation(Subtract(left, right), convert(-1)) |
| 67 | + checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null) |
| 68 | + checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null) |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + test("* (Multiply)") { |
| 73 | + testNumericDataTypes { convert => |
| 74 | + val left = Literal(convert(1)) |
| 75 | + val right = Literal(convert(2)) |
| 76 | + checkEvaluation(Multiply(left, right), convert(2)) |
| 77 | + checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null) |
| 78 | + checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null) |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + test("/ (Divide) basic") { |
| 83 | + testNumericDataTypes { convert => |
| 84 | + val left = Literal(convert(2)) |
| 85 | + val right = Literal(convert(1)) |
| 86 | + val dataType = left.dataType |
| 87 | + checkEvaluation(Divide(left, right), convert(2)) |
| 88 | + checkEvaluation(Divide(Literal.create(null, dataType), right), null) |
| 89 | + checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) |
| 90 | + checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero |
| 91 | + } |
52 | 92 | } |
53 | 93 |
|
54 | | - test("fractional arithmetic") { |
55 | | - val row = create_row(1.1, 2.0, 3.1, null) |
56 | | - val c1 = 'a.double.at(0) |
57 | | - val c2 = 'a.double.at(1) |
58 | | - val c3 = 'a.double.at(2) |
59 | | - val c4 = 'a.double.at(3) |
60 | | - |
61 | | - checkEvaluation(UnaryMinus(c1), -1.1, row) |
62 | | - checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) |
63 | | - checkEvaluation(Add(c1, c4), null, row) |
64 | | - checkEvaluation(Add(c1, c2), 3.1, row) |
65 | | - checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) |
66 | | - checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) |
67 | | - checkEvaluation( |
68 | | - Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) |
69 | | - |
70 | | - checkEvaluation(-c1, -1.1, row) |
71 | | - checkEvaluation(c1 + c2, 3.1, row) |
72 | | - checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) |
73 | | - checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) |
74 | | - checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) |
75 | | - checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) |
| 94 | + test("/ (Divide) for integral type") { |
| 95 | + checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) |
| 96 | + checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) |
| 97 | + checkEvaluation(Divide(Literal(1), Literal(2)), 0) |
| 98 | + checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong) |
| 99 | + } |
| 100 | + |
| 101 | + test("/ (Divide) for floating point") { |
| 102 | + checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f) |
| 103 | + checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) |
| 104 | + checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) |
76 | 105 | } |
77 | 106 |
|
78 | 107 | test("Abs") { |
79 | | - def testAbs(convert: (Int) => Any): Unit = { |
| 108 | + testNumericDataTypes { convert => |
80 | 109 | checkEvaluation(Abs(Literal(convert(0))), convert(0)) |
81 | 110 | checkEvaluation(Abs(Literal(convert(1))), convert(1)) |
82 | 111 | checkEvaluation(Abs(Literal(convert(-1))), convert(1)) |
83 | 112 | } |
84 | | - testAbs(_.toByte) |
85 | | - testAbs(_.toShort) |
86 | | - testAbs(identity) |
87 | | - testAbs(_.toLong) |
88 | | - testAbs(_.toFloat) |
89 | | - testAbs(_.toDouble) |
90 | | - testAbs(Decimal(_)) |
91 | 113 | } |
92 | 114 |
|
93 | 115 | test("Divide") { |
|
0 commit comments