|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.expressions |
19 | 19 |
|
20 | | -import org.scalatest.Matchers._ |
21 | | - |
22 | 20 | import org.apache.spark.SparkFunSuite |
23 | 21 | import org.apache.spark.sql.catalyst.dsl.expressions._ |
24 | 22 | import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType} |
25 | 23 |
|
26 | 24 |
|
27 | 25 | class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { |
28 | 26 |
|
29 | | - test("arithmetic") { |
30 | | - val row = create_row(1, 2, 3, null) |
31 | | - val c1 = 'a.int.at(0) |
32 | | - val c2 = 'a.int.at(1) |
33 | | - val c3 = 'a.int.at(2) |
34 | | - val c4 = 'a.int.at(3) |
35 | | - |
36 | | - checkEvaluation(UnaryMinus(c1), -1, row) |
37 | | - checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) |
38 | | - |
39 | | - checkEvaluation(Add(c1, c4), null, row) |
40 | | - checkEvaluation(Add(c1, c2), 3, row) |
41 | | - checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) |
42 | | - checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) |
43 | | - checkEvaluation( |
44 | | - Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) |
45 | | - |
46 | | - checkEvaluation(-c1, -1, row) |
47 | | - checkEvaluation(c1 + c2, 3, row) |
48 | | - checkEvaluation(c1 - c2, -1, row) |
49 | | - checkEvaluation(c1 * c2, 2, row) |
50 | | - checkEvaluation(c1 / c2, 0, row) |
51 | | - checkEvaluation(c1 % c2, 1, row) |
| 27 | + /** |
| 28 | + * Runs through the testFunc for all numeric data types. |
| 29 | + * |
| 30 | + * @param testFunc a test function that accepts a conversion function to convert an integer |
| 31 | + * into another data type. |
| 32 | + */ |
| 33 | + private def testNumericDataTypes(testFunc: (Int => Any) => Unit): Unit = { |
| 34 | + testFunc(_.toByte) |
| 35 | + testFunc(_.toShort) |
| 36 | + testFunc(identity) |
| 37 | + testFunc(_.toLong) |
| 38 | + testFunc(_.toFloat) |
| 39 | + testFunc(_.toDouble) |
| 40 | + testFunc(Decimal(_)) |
52 | 41 | } |
53 | 42 |
|
54 | | - test("fractional arithmetic") { |
55 | | - val row = create_row(1.1, 2.0, 3.1, null) |
56 | | - val c1 = 'a.double.at(0) |
57 | | - val c2 = 'a.double.at(1) |
58 | | - val c3 = 'a.double.at(2) |
59 | | - val c4 = 'a.double.at(3) |
60 | | - |
61 | | - checkEvaluation(UnaryMinus(c1), -1.1, row) |
62 | | - checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) |
63 | | - checkEvaluation(Add(c1, c4), null, row) |
64 | | - checkEvaluation(Add(c1, c2), 3.1, row) |
65 | | - checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) |
66 | | - checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) |
67 | | - checkEvaluation( |
68 | | - Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) |
69 | | - |
70 | | - checkEvaluation(-c1, -1.1, row) |
71 | | - checkEvaluation(c1 + c2, 3.1, row) |
72 | | - checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) |
73 | | - checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) |
74 | | - checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) |
75 | | - checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) |
| 43 | + test("+ (Add)") { |
| 44 | + testNumericDataTypes { convert => |
| 45 | + val left = Literal(convert(1)) |
| 46 | + val right = Literal(convert(2)) |
| 47 | + checkEvaluation(Add(left, right), convert(3)) |
| 48 | + checkEvaluation(Add(Literal.create(null, left.dataType), right), null) |
| 49 | + checkEvaluation(Add(left, Literal.create(null, right.dataType)), null) |
| 50 | + } |
76 | 51 | } |
77 | 52 |
|
78 | | - test("Abs") { |
79 | | - def testAbs(convert: (Int) => Any): Unit = { |
80 | | - checkEvaluation(Abs(Literal(convert(0))), convert(0)) |
81 | | - checkEvaluation(Abs(Literal(convert(1))), convert(1)) |
82 | | - checkEvaluation(Abs(Literal(convert(-1))), convert(1)) |
| 53 | + test("- (UnaryMinus)") { |
| 54 | + testNumericDataTypes { convert => |
| 55 | + val input = Literal(convert(1)) |
| 56 | + val dataType = input.dataType |
| 57 | + checkEvaluation(UnaryMinus(input), convert(-1)) |
| 58 | + checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null) |
83 | 59 | } |
84 | | - testAbs(_.toByte) |
85 | | - testAbs(_.toShort) |
86 | | - testAbs(identity) |
87 | | - testAbs(_.toLong) |
88 | | - testAbs(_.toFloat) |
89 | | - testAbs(_.toDouble) |
90 | | - testAbs(Decimal(_)) |
91 | 60 | } |
92 | 61 |
|
93 | | - test("Divide") { |
94 | | - checkEvaluation(Divide(Literal(2), Literal(1)), 2) |
95 | | - checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) |
| 62 | + test("- (Minus)") { |
| 63 | + testNumericDataTypes { convert => |
| 64 | + val left = Literal(convert(1)) |
| 65 | + val right = Literal(convert(2)) |
| 66 | + checkEvaluation(Subtract(left, right), convert(-1)) |
| 67 | + checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null) |
| 68 | + checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null) |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + test("* (Multiply)") { |
| 73 | + testNumericDataTypes { convert => |
| 74 | + val left = Literal(convert(1)) |
| 75 | + val right = Literal(convert(2)) |
| 76 | + checkEvaluation(Multiply(left, right), convert(2)) |
| 77 | + checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null) |
| 78 | + checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null) |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + test("/ (Divide) basic") { |
| 83 | + testNumericDataTypes { convert => |
| 84 | + val left = Literal(convert(2)) |
| 85 | + val right = Literal(convert(1)) |
| 86 | + val dataType = left.dataType |
| 87 | + checkEvaluation(Divide(left, right), convert(2)) |
| 88 | + checkEvaluation(Divide(Literal.create(null, dataType), right), null) |
| 89 | + checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) |
| 90 | + checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + test("/ (Divide) for integral type") { |
| 95 | + checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) |
| 96 | + checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) |
96 | 97 | checkEvaluation(Divide(Literal(1), Literal(2)), 0) |
97 | | - checkEvaluation(Divide(Literal(1), Literal(0)), null) |
98 | | - checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) |
99 | | - checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) |
100 | | - checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) |
101 | | - checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) |
102 | | - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) |
103 | | - checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) |
104 | | - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) |
105 | | - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), |
106 | | - null) |
| 98 | + checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong) |
107 | 99 | } |
108 | 100 |
|
109 | | - test("Remainder") { |
110 | | - checkEvaluation(Remainder(Literal(2), Literal(1)), 0) |
111 | | - checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) |
112 | | - checkEvaluation(Remainder(Literal(1), Literal(2)), 1) |
113 | | - checkEvaluation(Remainder(Literal(1), Literal(0)), null) |
114 | | - checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) |
115 | | - checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) |
116 | | - checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) |
117 | | - checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) |
118 | | - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) |
119 | | - checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) |
120 | | - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) |
121 | | - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), |
122 | | - null) |
| 101 | + test("/ (Divide) for floating point") { |
| 102 | + checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f) |
| 103 | + checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) |
| 104 | + checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) |
| 105 | + } |
| 106 | + |
| 107 | + test("% (Remainder)") { |
| 108 | + testNumericDataTypes { convert => |
| 109 | + val left = Literal(convert(1)) |
| 110 | + val right = Literal(convert(2)) |
| 111 | + checkEvaluation(Remainder(left, right), convert(1)) |
| 112 | + checkEvaluation(Remainder(Literal.create(null, left.dataType), right), null) |
| 113 | + checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null) |
| 114 | + checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + test("Abs") { |
| 119 | + testNumericDataTypes { convert => |
| 120 | + checkEvaluation(Abs(Literal(convert(0))), convert(0)) |
| 121 | + checkEvaluation(Abs(Literal(convert(1))), convert(1)) |
| 122 | + checkEvaluation(Abs(Literal(convert(-1))), convert(1)) |
| 123 | + } |
123 | 124 | } |
124 | 125 |
|
125 | 126 | test("MaxOf") { |
|
0 commit comments