From 829cfe7a41f243cbef48746e13f2cc2b9cd6053e Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 24 Dec 2019 11:39:34 +0800 Subject: [PATCH 01/19] [SPARK-30341][SQL] Overflow check for interval arithmetic operations --- .../sql/catalyst/expressions/arithmetic.scala | 63 +++++++- .../expressions/intervalExpressions.scala | 15 +- .../sql/catalyst/util/IntervalUtils.scala | 19 ++- .../catalyst/util/IntervalUtilsSuite.scala | 20 +++ .../resources/sql-tests/inputs/interval.sql | 7 + .../sql-tests/results/ansi/interval.sql.out | 145 ++++++++++++------ .../sql-tests/results/interval.sql.out | 42 ++++- 7 files changed, 241 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 82a8e6d80a0b..190903b2c158 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -37,6 +37,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { private val checkOverflow = SQLConf.get.ansiEnabled + override def nullable: Boolean = true + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType @@ -75,12 +77,29 @@ case class UnaryMinus(child: Expression) extends UnaryExpression """}) case _: CalendarIntervalType => val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$iu.negate($c)") + nullSafeCodeGen(ctx, ev, interval => s""" + try { + ${ev.value} = $iu.negate($interval); + } catch (ArithmeticException e) { + if ($checkOverflow) { + throw new ArithmeticException("-($interval) caused interval overflow."); + } else { + ${ev.isNull} = true; + } + } + """) } protected override def nullSafeEval(input: Any): Any = 
dataType match { - case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) - case _ => numeric.negate(input) + case CalendarIntervalType => + try { + IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) + } catch { + case _: ArithmeticException if checkOverflow => + throw new ArithmeticException(s"$sql caused interval overflow") + case _: ArithmeticException => null + } + case _ => numeric.negate(input) } override def sql: String = s"(- ${child.sql})" @@ -139,6 +158,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType + override def nullable: Boolean = true + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** Name of the function for this expression on a [[Decimal]] type. */ @@ -160,7 +181,19 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") case CalendarIntervalType => val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") - defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)") + nullSafeCodeGen(ctx, ev, (eval1, eval2) => + s""" + |try { + | ${ev.value} = $iu.$calendarIntervalMethod($eval1, $eval2); + |} catch (ArithmeticException e) { + | if ($checkOverflow) { + | throw new ArithmeticException( + | "$eval1 $calendarIntervalMethod $eval2 caused interval overflow."); + | } else { + | ${ev.isNull} = true; + | } + |} + |""".stripMargin) // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { @@ -229,8 +262,15 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match { - case CalendarIntervalType => 
IntervalUtils.add( - input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) + case CalendarIntervalType => + try { + IntervalUtils.add( + input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) + } catch { + case _: ArithmeticException if checkOverflow => + throw new ArithmeticException(s"$sql causes interval overflow") + case _: ArithmeticException => null + } case _ => numeric.plus(input1, input2) } @@ -257,8 +297,15 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match { - case CalendarIntervalType => IntervalUtils.subtract( - input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) + case CalendarIntervalType => + try { + IntervalUtils.subtract( + input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) + } catch { + case _: ArithmeticException if checkOverflow => + throw new ArithmeticException(s"$sql caused interval overflow") + case _: ArithmeticException => null + } case _ => numeric.minus(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index c8a40d0435a5..d911365dc827 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -24,6 +24,7 @@ import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import 
org.apache.spark.unsafe.types.CalendarInterval @@ -118,6 +119,8 @@ abstract class IntervalNumOperation( operation: (CalendarInterval, Double) => CalendarInterval, operationName: String) extends BinaryExpression with ImplicitCastInputTypes with Serializable { + private val checkOverflow = SQLConf.get.ansiEnabled + override def left: Expression = interval override def right: Expression = num @@ -130,7 +133,9 @@ abstract class IntervalNumOperation( try { operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } catch { - case _: java.lang.ArithmeticException => null + case _: ArithmeticException if checkOverflow => + throw new ArithmeticException(s"$sql caused interval overflow.") + case _: ArithmeticException => null } } @@ -140,8 +145,12 @@ abstract class IntervalNumOperation( s""" try { ${ev.value} = $iu.$operationName($interval, $num); - } catch (java.lang.ArithmeticException e) { - ${ev.isNull} = true; + } catch (ArithmeticException e) { + if ($checkOverflow) { + throw new ArithmeticException("$prettyName($interval, $num) caused interval overflow."); + } else { + ${ev.isNull} = true; + } } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 2a60cfd52ca9..88119a2c5318 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -420,16 +420,19 @@ object IntervalUtils { * @return a new calendar interval instance with all it parameters negated from the origin one. 
*/ def negate(interval: CalendarInterval): CalendarInterval = { - new CalendarInterval(-interval.months, -interval.days, -interval.microseconds) + val months = Math.negateExact(interval.months) + val days = Math.negateExact(interval.days) + val microseconds = Math.negateExact(interval.microseconds) + new CalendarInterval(months, days, microseconds) } /** * Return a new calendar interval instance of the sum of two intervals. */ def add(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { - val months = left.months + right.months - val days = left.days + right.days - val microseconds = left.microseconds + right.microseconds + val months = Math.addExact(left.months, right.months) + val days = Math.addExact(left.days, right.days) + val microseconds = Math.addExact(left.microseconds, right.microseconds) new CalendarInterval(months, days, microseconds) } @@ -437,9 +440,9 @@ object IntervalUtils { * Return a new calendar interval instance of the left intervals minus the right one. 
*/ def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { - val months = left.months - right.months - val days = left.days - right.days - val microseconds = left.microseconds - right.microseconds + val months = Math.subtractExact(left.months, right.months) + val days = Math.subtractExact(left.days, right.days) + val microseconds = Math.subtractExact(left.microseconds, right.microseconds) new CalendarInterval(months, days, microseconds) } @@ -448,7 +451,7 @@ object IntervalUtils { } def divide(interval: CalendarInterval, num: Double): CalendarInterval = { - if (num == 0) throw new java.lang.ArithmeticException("divide by zero") + if (num == 0) throw new ArithmeticException("divide by zero") fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala index 15ba5f03d050..6fc377005770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -445,4 +445,24 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { checkFail("5 30-12", DAY, SECOND, "must match day-time format") checkFail("5 1:12:20", HOUR, MICROSECOND, "Cannot support (interval") } + + test("interval overflow check") { + intercept[ArithmeticException](negate(new CalendarInterval(Int.MinValue, 0, 0))) + intercept[ArithmeticException](negate(CalendarInterval.MIN_VALUE)) + + intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1))) + intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0))) + intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0))) + + 
intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, + new CalendarInterval(0, 0, -1))) + intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, + new CalendarInterval(0, -1, 0))) + intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, + new CalendarInterval(-1, 0, 0))) + + intercept[ArithmeticException](multiply(CalendarInterval.MAX_VALUE, 2)) + + intercept[ArithmeticException](divide(CalendarInterval.MAX_VALUE, 0.5)) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql index 5b5eaab225d2..794d7cdb9a5f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql @@ -269,3 +269,10 @@ select interval 'interval \t 1\tday'; select interval 'interval\t1\tday'; select interval '1\t' day; select interval '1 ' day; + +-- interval overflow if (ansi) exception else NULL +select -(a) from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); +select a - b from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); +select b + interval '1 month' from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); +select a * 2 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); +select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 8c65c5ece7e5..7314174d141f 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 138 +-- Number of queries: 143 -- !query 0 @@ -207,9 +207,10 @@ struct +struct<> -- !query 25 
output -NULL +java.lang.ArithmeticException +divide_interval(INTERVAL '2 seconds', CAST(0 AS DOUBLE)) caused interval overflow. -- !query 26 @@ -994,17 +995,19 @@ struct +struct<> -- !query 105 output -NULL +java.lang.ArithmeticException +divide_interval(agg_mutableStateArray_0[0], agg_value_4) caused interval overflow. -- !query 106 select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0 -- !query 106 schema -struct +struct<> -- !query 106 output -NULL +java.lang.ArithmeticException +divide_interval(agg_mutableStateArray_0[0], agg_value_4) caused interval overflow. -- !query 107 @@ -1046,11 +1049,10 @@ select from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) group by i -- !query 111 schema -struct +struct<> -- !query 111 output -1 -1 days -2 2 seconds -3 NULL +java.lang.ArithmeticException +divide_interval(agg_value_9, agg_value_13) caused interval overflow. -- !query 112 @@ -1070,12 +1072,10 @@ SELECT avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v) -- !query 113 schema -struct +struct<> -- !query 113 output -1 1.5 seconds -1 2 seconds -2 NULL -2 NULL +java.lang.ArithmeticException +divide_interval(value_1, value_2) caused interval overflow. -- !query 114 @@ -1216,34 +1216,79 @@ struct -- !query 126 -select 1 year 2 days +select -(a) from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 126 schema -struct +struct<> -- !query 126 output -1 years 2 days +java.lang.ArithmeticException +-(localtablescan_value_0) caused interval overflow. 
-- !query 127 -select '10-9' year to month +select a - b from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 127 schema -struct +struct<> -- !query 127 output -10 years 9 months +java.lang.ArithmeticException +localtablescan_value_0 subtract localtablescan_value_1 caused interval overflow. -- !query 128 -select '20 15:40:32.99899999' day to second +select b + interval '1 month' from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 128 schema -struct +struct<> -- !query 128 output -20 days 15 hours 40 minutes 32.998999 seconds +java.lang.ArithmeticException +localtablescan_value_1 add ((CalendarInterval) references[1] /* literal */) caused interval overflow. -- !query 129 -select 30 day day +select a * 2 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 129 schema struct<> -- !query 129 output +java.lang.ArithmeticException +multiply_interval(localtablescan_value_0, 2.0D) caused interval overflow. + + +-- !query 130 +select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 130 schema +struct<> +-- !query 130 output +java.lang.ArithmeticException +divide_interval(localtablescan_value_0, 0.5D) caused interval overflow. 
+ + +-- !query 131 +select 1 year 2 days +-- !query 131 schema +struct +-- !query 131 output +1 years 2 days + + +-- !query 132 +select '10-9' year to month +-- !query 132 schema +struct +-- !query 132 output +10 years 9 months + + +-- !query 133 +select '20 15:40:32.99899999' day to second +-- !query 133 schema +struct +-- !query 133 output +20 days 15 hours 40 minutes 32.998999 seconds + + +-- !query 134 +select 30 day day +-- !query 134 schema +struct<> +-- !query 134 output org.apache.spark.sql.catalyst.parser.ParseException no viable alternative at input 'day'(line 1, pos 14) @@ -1253,27 +1298,27 @@ select 30 day day --------------^^^ --- !query 130 +-- !query 135 select date'2012-01-01' - '2-2' year to month --- !query 130 schema +-- !query 135 schema struct --- !query 130 output +-- !query 135 output 2009-11-01 --- !query 131 +-- !query 136 select 1 month - 1 day --- !query 131 schema +-- !query 136 schema struct --- !query 131 output +-- !query 136 output 1 months -1 days --- !query 132 +-- !query 137 select 1 year to month --- !query 132 schema +-- !query 137 schema struct<> --- !query 132 output +-- !query 137 output org.apache.spark.sql.catalyst.parser.ParseException The value of from-to unit must be a string(line 1, pos 7) @@ -1283,11 +1328,11 @@ select 1 year to month -------^^^ --- !query 133 +-- !query 138 select '1' year to second --- !query 133 schema +-- !query 138 schema struct<> --- !query 133 output +-- !query 138 output org.apache.spark.sql.catalyst.parser.ParseException Intervals FROM year TO second are not supported.(line 1, pos 7) @@ -1297,11 +1342,11 @@ select '1' year to second -------^^^ --- !query 134 +-- !query 139 select 1 year '2-1' year to month --- !query 134 schema +-- !query 139 schema struct<> --- !query 134 output +-- !query 139 output org.apache.spark.sql.catalyst.parser.ParseException Can only have a single from-to unit in the interval literal syntax(line 1, pos 14) @@ -1311,11 +1356,11 @@ select 1 year '2-1' year to month 
--------------^^^ --- !query 135 +-- !query 140 select (-30) day --- !query 135 schema +-- !query 140 schema struct<> --- !query 135 output +-- !query 140 output org.apache.spark.sql.catalyst.parser.ParseException no viable alternative at input 'day'(line 1, pos 13) @@ -1325,11 +1370,11 @@ select (-30) day -------------^^^ --- !query 136 +-- !query 141 select (a + 1) day --- !query 136 schema +-- !query 141 schema struct<> --- !query 136 output +-- !query 141 output org.apache.spark.sql.catalyst.parser.ParseException no viable alternative at input 'day'(line 1, pos 15) @@ -1339,11 +1384,11 @@ select (a + 1) day ---------------^^^ --- !query 137 +-- !query 142 select 30 day day day --- !query 137 schema +-- !query 142 schema struct<> --- !query 137 output +-- !query 142 output org.apache.spark.sql.catalyst.parser.ParseException no viable alternative at input 'day'(line 1, pos 14) diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index ff0a3ff74f1e..13722d942ec2 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 126 +-- Number of queries: 131 -- !query 0 @@ -1197,3 +1197,43 @@ select interval '1 ' day struct -- !query 125 output 1 days + + +-- !query 126 +select -(a) from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 126 schema +struct<(- a):interval> +-- !query 126 output +NULL + + +-- !query 127 +select a - b from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 127 schema +struct<(a - b):interval> +-- !query 127 output +NULL + + +-- !query 128 +select b + interval '1 month' from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 128 schema +struct<(b + INTERVAL '1 
months'):interval> +-- !query 128 output +NULL + + +-- !query 129 +select a * 2 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 129 schema +struct +-- !query 129 output +NULL + + +-- !query 130 +select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 130 schema +struct +-- !query 130 output +NULL From 516080cbea19905fbfa177abced0f5853792e9fe Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 24 Dec 2019 13:34:49 +0800 Subject: [PATCH 02/19] override nullable, fix tests --- .../spark/sql/catalyst/expressions/arithmetic.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 190903b2c158..1574cf8e5f98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -158,8 +158,6 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType - override def nullable: Boolean = true - override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** Name of the function for this expression on a [[Decimal]] type. 
*/ @@ -251,6 +249,11 @@ object BinaryArithmetic { """) case class Add(left: Expression, right: Expression) extends BinaryArithmetic { + override def nullable: Boolean = dataType match { + case CalendarIntervalType if !checkOverflow => true + case _ => super.nullable + } + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "+" @@ -286,6 +289,11 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { """) case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { + override def nullable: Boolean = dataType match { + case CalendarIntervalType if !checkOverflow => true + case _ => false + } + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "-" From 67767c043671a2c0960494ab06e022317cec12d1 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 24 Dec 2019 13:39:21 +0800 Subject: [PATCH 03/19] fix --- .../apache/spark/sql/catalyst/expressions/arithmetic.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1574cf8e5f98..68d94237c639 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -37,7 +37,10 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { private val checkOverflow = SQLConf.get.ansiEnabled - override def nullable: Boolean = true + override def nullable: Boolean = dataType match { + case CalendarIntervalType if !checkOverflow => true + case _ => super.nullable + } override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -291,7 +294,7 @@ case class Subtract(left: 
Expression, right: Expression) extends BinaryArithmeti override def nullable: Boolean = dataType match { case CalendarIntervalType if !checkOverflow => true - case _ => false + case _ => super.nullable } override def inputType: AbstractDataType = TypeCollection.NumericAndInterval From 729337756261a26a2dc376ede9db05a52f238c76 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 24 Dec 2019 15:32:12 +0800 Subject: [PATCH 04/19] fix tests --- .../sql/catalyst/expressions/arithmetic.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 68d94237c639..6f35d0ef7172 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -38,8 +38,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression private val checkOverflow = SQLConf.get.ansiEnabled override def nullable: Boolean = dataType match { - case CalendarIntervalType if !checkOverflow => true - case _ => super.nullable + case CalendarIntervalType => true + case _ => child.nullable } override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -97,11 +97,11 @@ case class UnaryMinus(child: Expression) extends UnaryExpression case CalendarIntervalType => try { IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) - } catch { - case _: ArithmeticException if checkOverflow => - throw new ArithmeticException(s"$sql caused interval overflow") - case _: ArithmeticException => null - } + } catch { + case _: ArithmeticException if checkOverflow => + throw new ArithmeticException(s"$sql caused interval overflow") + case _: ArithmeticException => null + } case _ => numeric.negate(input) } @@ -253,7 +253,7 @@ object BinaryArithmetic { case 
class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = dataType match { - case CalendarIntervalType if !checkOverflow => true + case CalendarIntervalType => true case _ => super.nullable } @@ -293,7 +293,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = dataType match { - case CalendarIntervalType if !checkOverflow => true + case CalendarIntervalType => true case _ => super.nullable } From f100d8825684885ccb9518da3d252a38fdd51517 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 24 Dec 2019 19:36:14 +0800 Subject: [PATCH 05/19] use old overflow style --- .../sql/catalyst/expressions/arithmetic.scala | 84 ++++------------ .../expressions/intervalExpressions.scala | 39 +++++--- .../sql/catalyst/util/IntervalUtils.scala | 97 ++++++++++++++++++- .../catalyst/util/IntervalUtilsSuite.scala | 93 +++++++++++------- .../resources/sql-tests/inputs/interval.sql | 2 +- .../sql-tests/results/ansi/interval.sql.out | 22 ++--- .../sql-tests/results/interval.sql.out | 30 +++--- 7 files changed, 229 insertions(+), 138 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 6f35d0ef7172..f12892f41175 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -37,11 +37,6 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { private val checkOverflow = SQLConf.get.ansiEnabled - override def nullable: Boolean = dataType match { - case CalendarIntervalType => true - case _ => child.nullable - } - override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType @@ -80,28 +75,19 @@ case class UnaryMinus(child: Expression) extends UnaryExpression """}) case _: CalendarIntervalType => val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") - nullSafeCodeGen(ctx, ev, interval => s""" - try { - ${ev.value} = $iu.negate($interval); - } catch (ArithmeticException e) { - if ($checkOverflow) { - throw new ArithmeticException("-($interval) caused interval overflow."); - } else { - ${ev.isNull} = true; - } + defineCodeGen(ctx, ev, + interval => if (checkOverflow) { + s"$iu.negate($interval)" + } else { + s"$iu.safeNegate($interval)" } - """) + ) } protected override def nullSafeEval(input: Any): Any = dataType match { - case CalendarIntervalType => - try { + case CalendarIntervalType if checkOverflow => IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) - } catch { - case _: ArithmeticException if checkOverflow => - throw new ArithmeticException(s"$sql caused interval overflow") - case _: ArithmeticException => null - } + case CalendarIntervalType => IntervalUtils.safeNegate(input.asInstanceOf[CalendarInterval]) case _ => numeric.negate(input) } @@ -182,19 +168,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") case CalendarIntervalType => val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") - nullSafeCodeGen(ctx, ev, (eval1, eval2) => - s""" - |try { - | ${ev.value} = $iu.$calendarIntervalMethod($eval1, $eval2); - |} catch (ArithmeticException e) { - | if ($checkOverflow) { - | throw new ArithmeticException( - | "$eval1 $calendarIntervalMethod $eval2 caused interval overflow."); - | } else { - | ${ev.isNull} = true; - | } - |} - |""".stripMargin) + defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)") // byte and short are casted into int when add, minus, times or divide case 
ByteType | ShortType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { @@ -252,31 +226,23 @@ object BinaryArithmetic { """) case class Add(left: Expression, right: Expression) extends BinaryArithmetic { - override def nullable: Boolean = dataType match { - case CalendarIntervalType => true - case _ => super.nullable - } - override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "+" override def decimalMethod: String = "$plus" - override def calendarIntervalMethod: String = "add" + override def calendarIntervalMethod: String = if (checkOverflow) "add" else "safeAdd" private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match { + case CalendarIntervalType if checkOverflow => + IntervalUtils.add( + input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case CalendarIntervalType => - try { - IntervalUtils.add( - input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) - } catch { - case _: ArithmeticException if checkOverflow => - throw new ArithmeticException(s"$sql causes interval overflow") - case _: ArithmeticException => null - } + IntervalUtils.safeAdd( + input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case _ => numeric.plus(input1, input2) } @@ -292,31 +258,23 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { """) case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { - override def nullable: Boolean = dataType match { - case CalendarIntervalType => true - case _ => super.nullable - } - override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "-" override def decimalMethod: String = "$minus" - override def calendarIntervalMethod: String = "subtract" + override def calendarIntervalMethod: String = if (checkOverflow) "subtract" else 
"safeSubtract" private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match { + case CalendarIntervalType if checkOverflow => + IntervalUtils.subtract( + input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case CalendarIntervalType => - try { - IntervalUtils.subtract( - input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) - } catch { - case _: ArithmeticException if checkOverflow => - throw new ArithmeticException(s"$sql caused interval overflow") - case _: ArithmeticException => null - } + IntervalUtils.safeSubtract( + input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case _ => numeric.minus(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index d911365dc827..643027dc24b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -113,13 +113,14 @@ object ExtractIntervalPart { } } -abstract class IntervalNumOperation( - interval: Expression, - num: Expression, - operation: (CalendarInterval, Double) => CalendarInterval, - operationName: String) +abstract class IntervalNumOperation(interval: Expression, num: Expression) extends BinaryExpression with ImplicitCastInputTypes with Serializable { - private val checkOverflow = SQLConf.get.ansiEnabled + + protected val checkOverflow: Boolean = SQLConf.get.ansiEnabled + + protected def operation(interval: CalendarInterval, num: Double): CalendarInterval + + protected val operationName: String override def left: Expression = interval override def right: Expression = num @@ -133,9 +134,7 @@ abstract class IntervalNumOperation( 
try { operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } catch { - case _: ArithmeticException if checkOverflow => - throw new ArithmeticException(s"$sql caused interval overflow.") - case _: ArithmeticException => null + case _: ArithmeticException if (!checkOverflow) => null } } @@ -147,7 +146,7 @@ abstract class IntervalNumOperation( ${ev.value} = $iu.$operationName($interval, $num); } catch (ArithmeticException e) { if ($checkOverflow) { - throw new ArithmeticException("$prettyName($interval, $num) caused interval overflow."); + throw e; } else { ${ev.isNull} = true; } @@ -156,14 +155,28 @@ abstract class IntervalNumOperation( }) } - override def prettyName: String = operationName + "_interval" + override def prettyName: String = operationName.stripPrefix("safe").toLowerCase() + "_interval" } case class MultiplyInterval(interval: Expression, num: Expression) - extends IntervalNumOperation(interval, num, multiply, "multiply") + extends IntervalNumOperation(interval, num) { + + override protected def operation(interval: CalendarInterval, num: Double): CalendarInterval = { + if (checkOverflow) multiply(interval, num) else safeMultiply(interval, num) + } + + override protected val operationName: String = if (checkOverflow) "multiply" else "safeMultiply" +} case class DivideInterval(interval: Expression, num: Expression) - extends IntervalNumOperation(interval, num, divide, "divide") + extends IntervalNumOperation(interval, num) { + + override protected def operation(interval: CalendarInterval, num: Double): CalendarInterval = { + if (checkOverflow) divide(interval, num) else safeDivide(interval, num) + } + + override protected val operationName: String = if (checkOverflow) "divide" else "safeDivide" +} // scalastyle:off line.size.limit @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala 
index 88119a2c5318..b1eec68a4395 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -401,6 +401,8 @@ object IntervalUtils { /** * Makes an interval from months, days and micros with the fractional part by * adding the month fraction to days and the days fraction to micros. + * + * @throws ArithmeticException if the result overflows any field value */ private def fromDoubles( monthsWithFraction: Double, @@ -413,11 +415,41 @@ object IntervalUtils { new CalendarInterval(truncatedMonths, truncatedDays, micros.round) } + /** + * Makes an interval from months, days and micros with the fractional part by + * adding the month fraction to days and the days fraction to micros. + */ + private def safeFromDoubles( + monthsWithFraction: Double, + daysWithFraction: Double, + microsWithFraction: Double): CalendarInterval = { + val monthInLong = monthsWithFraction.toLong + val truncatedMonths = if (monthInLong > Int.MaxValue) { + Int.MaxValue + } else if (monthInLong < Int.MinValue) { + Int.MinValue + } else { + monthInLong.toInt + } + val days = daysWithFraction + DAYS_PER_MONTH * (monthsWithFraction - truncatedMonths) + val dayInLong = days.toLong + val truncatedDays = if (dayInLong > Int.MaxValue) { + Int.MaxValue + } else if (dayInLong < Int.MinValue) { + Int.MinValue + } else { + dayInLong.toInt + } + val micros = microsWithFraction + MICROS_PER_DAY * (days - truncatedDays) + new CalendarInterval(truncatedMonths, truncatedDays.toInt, micros.round) + } + /** * Unary minus, return the negated the calendar interval value. * * @param interval the interval to be negated * @return a new calendar interval instance with all it parameters negated from the origin one.
+ * @throws ArithmeticException if the result overflows any field value */ def negate(interval: CalendarInterval): CalendarInterval = { val months = Math.negateExact(interval.months) @@ -426,8 +458,21 @@ object IntervalUtils { new CalendarInterval(months, days, microseconds) } + /** + * Unary minus, return the negated the calendar interval value. + * + * @param interval the interval to be negated + * @return a new calendar interval instance with all it parameters negated from the origin one. + */ + def safeNegate(interval: CalendarInterval): CalendarInterval = { + new CalendarInterval(-interval.months, -interval.days, -interval.microseconds) + } + /** * Return a new calendar interval instance of the sum of two intervals. + * + * @throws ArithmeticException if the result overflows any field value + * */ def add(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = Math.addExact(left.months, right.months) @@ -437,7 +482,20 @@ object IntervalUtils { } /** - * Return a new calendar interval instance of the left intervals minus the right one. + * Return a new calendar interval instance of the sum of two intervals. + */ + def safeAdd(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { + val months = left.months + right.months + val days = left.days + right.days + val microseconds = left.microseconds + right.microseconds + new CalendarInterval(months, days, microseconds) + } + + /** + * Return a new calendar interval instance of the left interval minus the right one. + * + * @throws ArithmeticException if the result overflows any field value + * */ def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = Math.subtractExact(left.months, right.months) @@ -446,15 +504,52 @@ object IntervalUtils { new CalendarInterval(months, days, microseconds) } + /** + * Return a new calendar interval instance of the left interval minus the right one. 
+ */ + def safeSubtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { + val months = left.months - right.months + val days = left.days - right.days + val microseconds = left.microseconds - right.microseconds + new CalendarInterval(months, days, microseconds) + } + + /** + * Return a new calendar interval instance of the left interval times a multiplier. + * + * @throws ArithmeticException if the result overflows any field value + */ def multiply(interval: CalendarInterval, num: Double): CalendarInterval = { fromDoubles(num * interval.months, num * interval.days, num * interval.microseconds) } + /** + * Return a new calendar interval instance of the left interval times a multiplier. + */ + def safeMultiply(interval: CalendarInterval, num: Double): CalendarInterval = { + safeFromDoubles(num * interval.months, num * interval.days, num * interval.microseconds) + } + + /** + * Return a new calendar interval instance of the left interval divides by a dividend. + * + * @throws ArithmeticException if the result overflows any field value or divided by zero + */ def divide(interval: CalendarInterval, num: Double): CalendarInterval = { if (num == 0) throw new ArithmeticException("divide by zero") fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) } + /** + * Return a new calendar interval instance of the left interval divides by a dividend. + * + * @throws ArithmeticException if divided by zero + */ + def safeDivide(interval: CalendarInterval, num: Double): CalendarInterval = { + if (num == 0) throw new ArithmeticException("divide by zero") + safeFromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) + } + // `toString` implementation in CalendarInterval is the multi-units format currently. 
def toMultiUnitsString(interval: CalendarInterval): String = interval.toString diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala index 6fc377005770..2669406cc348 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -240,63 +240,73 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { test("negate") { assert(negate(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3)) + assert(safeNegate(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3)) } test("subtract one interval by another") { val input1 = new CalendarInterval(3, 1, 1 * MICROS_PER_HOUR) val input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR) - assert(new CalendarInterval(1, -3, -99 * MICROS_PER_HOUR) === subtract(input1, input2)) val input3 = new CalendarInterval(-10, -30, -81 * MICROS_PER_HOUR) val input4 = new CalendarInterval(75, 150, 200 * MICROS_PER_HOUR) - assert(new CalendarInterval(-85, -180, -281 * MICROS_PER_HOUR) === subtract(input3, input4)) + Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](subtract, safeSubtract) + .foreach { func => + assert(new CalendarInterval(1, -3, -99 * MICROS_PER_HOUR) === func(input1, input2)) + assert(new CalendarInterval(-85, -180, -281 * MICROS_PER_HOUR) === func(input3, input4)) + } } test("add two intervals") { val input1 = new CalendarInterval(3, 1, 1 * MICROS_PER_HOUR) val input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR) - assert(new CalendarInterval(5, 5, 101 * MICROS_PER_HOUR) === add(input1, input2)) - val input3 = new CalendarInterval(-10, -30, -81 * MICROS_PER_HOUR) val input4 = new CalendarInterval(75, 150, 200 * MICROS_PER_HOUR) - assert(new CalendarInterval(65, 120, 119 * MICROS_PER_HOUR) === 
add(input3, input4)) + Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](add, safeAdd).foreach { func => + assert(new CalendarInterval(5, 5, 101 * MICROS_PER_HOUR) === func(input1, input2)) + assert(new CalendarInterval(65, 120, 119 * MICROS_PER_HOUR) === func(input3, input4)) + } } test("multiply by num") { - var interval = new CalendarInterval(0, 0, 0) - assert(interval === multiply(interval, 0)) - interval = new CalendarInterval(123, 456, 789) - assert(new CalendarInterval(123 * 42, 456 * 42, 789 * 42) === multiply(interval, 42)) - interval = new CalendarInterval(-123, -456, -789) - assert(new CalendarInterval(-123 * 42, -456 * 42, -789 * 42) === multiply(interval, 42)) - assert(new CalendarInterval(1, 22, 12 * MICROS_PER_HOUR) === - multiply(new CalendarInterval(1, 5, 0), 1.5)) - assert(new CalendarInterval(2, 14, 12 * MICROS_PER_HOUR) === - multiply(new CalendarInterval(2, 2, 2 * MICROS_PER_HOUR), 1.2)) + Seq[(CalendarInterval, Double) => CalendarInterval](multiply, safeMultiply).foreach { func => + var interval = new CalendarInterval(0, 0, 0) + assert(interval === func(interval, 0)) + interval = new CalendarInterval(123, 456, 789) + assert(new CalendarInterval(123 * 42, 456 * 42, 789 * 42) === func(interval, 42)) + interval = new CalendarInterval(-123, -456, -789) + assert(new CalendarInterval(-123 * 42, -456 * 42, -789 * 42) === func(interval, 42)) + assert(new CalendarInterval(1, 22, 12 * MICROS_PER_HOUR) === + func(new CalendarInterval(1, 5, 0), 1.5)) + assert(new CalendarInterval(2, 14, 12 * MICROS_PER_HOUR) === + func(new CalendarInterval(2, 2, 2 * MICROS_PER_HOUR), 1.2)) + } + + assert(CalendarInterval.MAX_VALUE === + safeMultiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE)) try { multiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE) fail("Expected to throw an exception on months overflow") } catch { - case e: ArithmeticException => - assert(e.getMessage.contains("overflow")) + case e: ArithmeticException => 
assert(e.getMessage.contains("overflow")) } } test("divide by num") { - var interval = new CalendarInterval(0, 0, 0) - assert(interval === divide(interval, 10)) - interval = new CalendarInterval(1, 3, 30 * MICROS_PER_SECOND) - assert(new CalendarInterval(0, 16, 12 * MICROS_PER_HOUR + 15 * MICROS_PER_SECOND) === - divide(interval, 2)) - assert(new CalendarInterval(2, 6, MICROS_PER_MINUTE) === divide(interval, 0.5)) - interval = new CalendarInterval(-1, 0, -30 * MICROS_PER_SECOND) - assert(new CalendarInterval(0, -15, -15 * MICROS_PER_SECOND) === divide(interval, 2)) - assert(new CalendarInterval(-2, 0, -1 * MICROS_PER_MINUTE) === divide(interval, 0.5)) - try { - divide(new CalendarInterval(123, 456, 789), 0) - fail("Expected to throw an exception on divide by zero") - } catch { - case e: ArithmeticException => - assert(e.getMessage.contains("divide by zero")) + Seq[(CalendarInterval, Double) => CalendarInterval](divide, safeDivide).foreach { func => + var interval = new CalendarInterval(0, 0, 0) + assert(interval === func(interval, 10)) + interval = new CalendarInterval(1, 3, 30 * MICROS_PER_SECOND) + assert(new CalendarInterval(0, 16, 12 * MICROS_PER_HOUR + 15 * MICROS_PER_SECOND) === + func(interval, 2)) + assert(new CalendarInterval(2, 6, MICROS_PER_MINUTE) === func(interval, 0.5)) + interval = new CalendarInterval(-1, 0, -30 * MICROS_PER_SECOND) + assert(new CalendarInterval(0, -15, -15 * MICROS_PER_SECOND) === func(interval, 2)) + assert(new CalendarInterval(-2, 0, -1 * MICROS_PER_MINUTE) === func(interval, 0.5)) + try { + func(new CalendarInterval(123, 456, 789), 0) + fail("Expected to throw an exception on divide by zero") + } catch { + case e: ArithmeticException => assert(e.getMessage.contains("divide by zero")) + } } } @@ -448,11 +458,19 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { test("interval overflow check") { intercept[ArithmeticException](negate(new CalendarInterval(Int.MinValue, 0, 0))) + assert(safeNegate(new 
CalendarInterval(Int.MinValue, 0, 0)) === + new CalendarInterval(Int.MinValue, 0, 0)) intercept[ArithmeticException](negate(CalendarInterval.MIN_VALUE)) - + assert(safeNegate(CalendarInterval.MIN_VALUE) === CalendarInterval.MIN_VALUE) intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1))) intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0))) intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0))) + assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)) === + new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue)) + assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)) === + new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue)) + assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)) === + new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue)) intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1))) @@ -460,9 +478,16 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { new CalendarInterval(0, -1, 0))) intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0))) + assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1)) === + new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue)) + assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0)) === + new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue)) + assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0)) === + new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue)) intercept[ArithmeticException](multiply(CalendarInterval.MAX_VALUE, 2)) - + assert(safeMultiply(CalendarInterval.MAX_VALUE, 2) === CalendarInterval.MAX_VALUE) intercept[ArithmeticException](divide(CalendarInterval.MAX_VALUE, 0.5)) + 
assert(safeDivide(CalendarInterval.MAX_VALUE, 0.5) === CalendarInterval.MAX_VALUE) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql index 794d7cdb9a5f..6a98d3ecad13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql @@ -274,5 +274,5 @@ select interval '1 ' day; select -(a) from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); select a - b from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); select b + interval '1 month' from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); -select a * 2 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); +select a * 1.1 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 7314174d141f..eadb45012edc 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -210,7 +210,7 @@ select interval '2 seconds' / 0 struct<> -- !query 25 output java.lang.ArithmeticException -divide_interval(INTERVAL '2 seconds', CAST(0 AS DOUBLE)) caused interval overflow. +divide by zero -- !query 26 @@ -998,7 +998,7 @@ select avg(cast(v as interval)) from VALUES (null) t(v) struct<> -- !query 105 output java.lang.ArithmeticException -divide_interval(agg_mutableStateArray_0[0], agg_value_4) caused interval overflow. 
+divide by zero -- !query 106 @@ -1007,7 +1007,7 @@ select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) struct<> -- !query 106 output java.lang.ArithmeticException -divide_interval(agg_mutableStateArray_0[0], agg_value_4) caused interval overflow. +divide by zero -- !query 107 @@ -1052,7 +1052,7 @@ group by i struct<> -- !query 111 output java.lang.ArithmeticException -divide_interval(agg_value_9, agg_value_13) caused interval overflow. +divide by zero -- !query 112 @@ -1075,7 +1075,7 @@ FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v) struct<> -- !query 113 output java.lang.ArithmeticException -divide_interval(value_1, value_2) caused interval overflow. +divide by zero -- !query 114 @@ -1221,7 +1221,7 @@ select -(a) from values (interval '-2147483648 months', interval '2147483647 mon struct<> -- !query 126 output java.lang.ArithmeticException --(localtablescan_value_0) caused interval overflow. +integer overflow -- !query 127 @@ -1230,7 +1230,7 @@ select a - b from values (interval '-2147483648 months', interval '2147483647 mo struct<> -- !query 127 output java.lang.ArithmeticException -localtablescan_value_0 subtract localtablescan_value_1 caused interval overflow. +integer overflow -- !query 128 @@ -1239,16 +1239,16 @@ select b + interval '1 month' from values (interval '-2147483648 months', interv struct<> -- !query 128 output java.lang.ArithmeticException -localtablescan_value_1 add ((CalendarInterval) references[1] /* literal */) caused interval overflow. +integer overflow -- !query 129 -select a * 2 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +select a * 1.1 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 129 schema struct<> -- !query 129 output java.lang.ArithmeticException -multiply_interval(localtablescan_value_0, 2.0D) caused interval overflow. 
+integer overflow -- !query 130 @@ -1257,7 +1257,7 @@ select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 struct<> -- !query 130 output java.lang.ArithmeticException -divide_interval(localtablescan_value_0, 0.5D) caused interval overflow. +integer overflow -- !query 131 diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 13722d942ec2..664cc85ef85e 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -183,7 +183,7 @@ struct -- !query 22 select 3 * (timestamp'2019-10-15 10:11:12.001002' - date'2019-10-15') -- !query 22 schema -struct +struct -- !query 22 output 30 hours 33 minutes 36.003006 seconds @@ -191,7 +191,7 @@ struct +struct -- !query 23 output 6 months 21 days 0.000005 seconds @@ -199,7 +199,7 @@ struct +struct -- !query 24 output 16 hours @@ -207,7 +207,7 @@ struct +struct -- !query 25 output NULL @@ -215,7 +215,7 @@ NULL -- !query 26 select interval '2 seconds' / null -- !query 26 schema -struct +struct -- !query 26 output NULL @@ -223,7 +223,7 @@ NULL -- !query 27 select interval '2 seconds' * null -- !query 27 schema -struct +struct -- !query 27 output NULL @@ -231,7 +231,7 @@ NULL -- !query 28 select null * interval '2 seconds' -- !query 28 schema -struct +struct -- !query 28 output NULL @@ -1204,7 +1204,7 @@ select -(a) from values (interval '-2147483648 months', interval '2147483647 mon -- !query 126 schema struct<(- a):interval> -- !query 126 output -NULL +-178956970 years -8 months -- !query 127 @@ -1212,7 +1212,7 @@ select a - b from values (interval '-2147483648 months', interval '2147483647 mo -- !query 127 schema struct<(a - b):interval> -- !query 127 output -NULL +1 months -- !query 128 @@ -1220,20 +1220,20 @@ select b + interval '1 month' from values (interval '-2147483648 months', interv -- !query 128 schema struct<(b + INTERVAL 
'1 months'):interval> -- !query 128 output -NULL +-178956970 years -8 months -- !query 129 -select a * 2 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +select a * 1.1 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 129 schema -struct +struct -- !query 129 output -NULL +-178956970 years -8 months -2147483648 days -2562047788 hours -54.775808 seconds -- !query 130 select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 130 schema -struct +struct -- !query 130 output -NULL +-178956970 years -8 months -2147483648 days -2562047788 hours -54.775808 seconds From ba44c5a7a5cee3e46464103fb31e36aaa2a2c12d Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 24 Dec 2019 20:17:31 +0800 Subject: [PATCH 06/19] workaround for avg --- .../expressions/aggregate/Average.scala | 2 +- .../expressions/intervalExpressions.scala | 7 ++++-- .../sql-tests/results/ansi/interval.sql.out | 25 ++++++++++--------- .../sql-tests/results/interval.sql.out | 18 ++++++------- 4 files changed, 28 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 9bb048a9851e..afbbac293ca8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -81,7 +81,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case _: DecimalType => DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType) case CalendarIntervalType => - DivideInterval(sum.cast(resultType), count.cast(DoubleType)) + DivideInterval(sum.cast(resultType), count.cast(DoubleType), false) case _ => sum.cast(resultType) / 
count.cast(resultType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 643027dc24b9..e436ebdbf5c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -134,7 +134,7 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression) try { operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } catch { - case _: ArithmeticException if (!checkOverflow) => null + case _: ArithmeticException if !checkOverflow => null } } @@ -168,7 +168,10 @@ case class MultiplyInterval(interval: Expression, num: Expression) override protected val operationName: String = if (checkOverflow) "multiply" else "safeMultiply" } -case class DivideInterval(interval: Expression, num: Expression) +case class DivideInterval( + interval: Expression, + num: Expression, + override val checkOverflow: Boolean = SQLConf.get.ansiEnabled) extends IntervalNumOperation(interval, num) { override protected def operation(interval: CalendarInterval, num: Double): CalendarInterval = { diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index eadb45012edc..0fcabc8e977b 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -995,19 +995,17 @@ struct +struct -- !query 105 output -java.lang.ArithmeticException -divide by zero +NULL -- !query 106 select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0 -- !query 106 schema -struct<> +struct -- !query 106 output -java.lang.ArithmeticException -divide by zero +NULL -- !query 107 
@@ -1049,10 +1047,11 @@ select from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) group by i -- !query 111 schema -struct<> +struct -- !query 111 output -java.lang.ArithmeticException -divide by zero +1 -1 days +2 2 seconds +3 NULL -- !query 112 @@ -1072,10 +1071,12 @@ SELECT avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v) -- !query 113 schema -struct<> +struct -- !query 113 output -java.lang.ArithmeticException -divide by zero +1 1.5 seconds +1 2 seconds +2 NULL +2 NULL -- !query 114 diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 664cc85ef85e..178d462dc3ad 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -183,7 +183,7 @@ struct -- !query 22 select 3 * (timestamp'2019-10-15 10:11:12.001002' - date'2019-10-15') -- !query 22 schema -struct +struct -- !query 22 output 30 hours 33 minutes 36.003006 seconds @@ -191,7 +191,7 @@ struct +struct -- !query 23 output 6 months 21 days 0.000005 seconds @@ -199,7 +199,7 @@ struct +struct -- !query 24 output 16 hours @@ -207,7 +207,7 @@ struct +struct -- !query 25 output NULL @@ -215,7 +215,7 @@ NULL -- !query 26 select interval '2 seconds' / null -- !query 26 schema -struct +struct -- !query 26 output NULL @@ -223,7 +223,7 @@ NULL -- !query 27 select interval '2 seconds' * null -- !query 27 schema -struct +struct -- !query 27 output NULL @@ -231,7 +231,7 @@ NULL -- !query 28 select null * interval '2 seconds' -- !query 28 schema -struct +struct -- !query 28 output NULL @@ -1226,7 +1226,7 @@ struct<(b + INTERVAL '1 months'):interval> -- !query 129 select a * 1.1 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 129 schema -struct +struct -- !query 129 
output -178956970 years -8 months -2147483648 days -2562047788 hours -54.775808 seconds @@ -1234,6 +1234,6 @@ struct -- !query 130 select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 130 schema -struct +struct -- !query 130 output -178956970 years -8 months -2147483648 days -2562047788 hours -54.775808 seconds From 67645d4d3415735fcd464338ed6d0dcd52c32088 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 24 Dec 2019 20:46:12 +0800 Subject: [PATCH 07/19] divide 0 return null as other types --- .../expressions/aggregate/Average.scala | 2 +- .../expressions/intervalExpressions.scala | 69 ++++++++++++------- .../sql-tests/results/ansi/interval.sql.out | 5 +- 3 files changed, 49 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index afbbac293ca8..9bb048a9851e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -81,7 +81,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case _: DecimalType => DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType) case CalendarIntervalType => - DivideInterval(sum.cast(resultType), count.cast(DoubleType), false) + DivideInterval(sum.cast(resultType), count.cast(DoubleType)) case _ => sum.cast(resultType) / count.cast(resultType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index e436ebdbf5c0..f0793ccb78dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala 
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -118,10 +118,6 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression) protected val checkOverflow: Boolean = SQLConf.get.ansiEnabled - protected def operation(interval: CalendarInterval, num: Double): CalendarInterval - - protected val operationName: String - override def left: Expression = interval override def right: Expression = num @@ -129,10 +125,20 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression) override def dataType: DataType = CalendarIntervalType override def nullable: Boolean = true +} + +case class MultiplyInterval(interval: Expression, num: Expression) + extends IntervalNumOperation(interval, num) { + + override def prettyName: String = "multiply_interval" override def nullSafeEval(interval: Any, num: Any): Any = { try { - operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + if (checkOverflow) { + multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + } else { + safeMultiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + } } catch { case _: ArithmeticException if !checkOverflow => null } @@ -141,6 +147,7 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (interval, num) => { val iu = IntervalUtils.getClass.getName.stripSuffix("$") + val operationName = if (checkOverflow) "multiply" else "safeMultiply" s""" try { ${ev.value} = $iu.$operationName($interval, $num); @@ -154,31 +161,47 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression) """ }) } - - override def prettyName: String = operationName.stripPrefix("safe").toLowerCase() + "_interval" } -case class MultiplyInterval(interval: Expression, num: Expression) +case class DivideInterval(interval: Expression, num: Expression) extends 
IntervalNumOperation(interval, num) { - override protected def operation(interval: CalendarInterval, num: Double): CalendarInterval = { - if (checkOverflow) multiply(interval, num) else safeMultiply(interval, num) - } + override def prettyName: String = "divide_interval" - override protected val operationName: String = if (checkOverflow) "multiply" else "safeMultiply" -} - -case class DivideInterval( - interval: Expression, - num: Expression, - override val checkOverflow: Boolean = SQLConf.get.ansiEnabled) - extends IntervalNumOperation(interval, num) { - - override protected def operation(interval: CalendarInterval, num: Double): CalendarInterval = { - if (checkOverflow) divide(interval, num) else safeDivide(interval, num) + override def nullSafeEval(interval: Any, num: Any): Any = { + try { + if (num == 0) return null + if (checkOverflow) { + divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + } else { + safeDivide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + } + } catch { + case _: ArithmeticException if !checkOverflow => null + } } - override protected val operationName: String = if (checkOverflow) "divide" else "safeDivide" + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (interval, num) => { + val iu = IntervalUtils.getClass.getName.stripSuffix("$") + val operationName = if (checkOverflow) "divide" else "safeDivide" + s""" + try { + if ($num == 0) { + ${ev.isNull} = true; + } else { + ${ev.value} = $iu.$operationName($interval, $num); + } + } catch (ArithmeticException e) { + if ($checkOverflow) { + throw e; + } else { + ${ev.isNull} = true; + } + } + """ + }) + } } // scalastyle:off line.size.limit diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 0fcabc8e977b..fe8e4cfe78ff 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out 
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -207,10 +207,9 @@ struct +struct -- !query 25 output -java.lang.ArithmeticException -divide by zero +NULL -- !query 26 From 7671d83bb1ccc67fc2beef4a3a26fa44e06da8e7 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Wed, 25 Dec 2019 01:11:22 +0800 Subject: [PATCH 08/19] fix tests --- .../IntervalExpressionsSuite.scala | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index ddcb6a66832a..780744ba8ae4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -22,6 +22,7 @@ import scala.language.implicitConversions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.IntervalUtils.stringToInterval +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.Decimal import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -197,10 +198,14 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("multiply") { - def check(interval: String, num: Double, expected: String): Unit = { - checkEvaluation( - MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)), - if (expected == null) null else stringToInterval(expected)) + def check(interval: String, num: Double, expected: String, + modes: Seq[String] = Seq("true", "false")): Unit = { + modes.foreach { v => + withSQLConf(SQLConf.ANSI_ENABLED.key -> v) { + checkEvaluation(MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)), + if (expected == null) null else 
stringToInterval(expected)) + } + } } check("0 seconds", 10, "0 seconds") @@ -211,14 +216,22 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { check("-100 years -1 millisecond", 0.5, "-50 years -500 microseconds") check("2 months 4 seconds", -0.5, "-1 months -2 seconds") check("1 month 2 microseconds", 1.5, "1 months 15 days 3 microseconds") - check("2 months", Int.MaxValue, null) + check("2 months", Int.MaxValue, CalendarInterval.MAX_VALUE.toString, Seq("false")) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkExceptionInExpression[ArithmeticException]( + MultiplyInterval(Literal(stringToInterval("2 months")), Literal(Int.MaxValue.toDouble)), + "integer overflow") + } } test("divide") { def check(interval: String, num: Double, expected: String): Unit = { - checkEvaluation( - DivideInterval(Literal(stringToInterval(interval)), Literal(num)), - if (expected == null) null else stringToInterval(expected)) + Seq("true", "false").foreach { v => + withSQLConf(SQLConf.ANSI_ENABLED.key -> v) { + checkEvaluation(DivideInterval(Literal(stringToInterval(interval)), Literal(num)), + if (expected == null) null else stringToInterval(expected)) + } + } } check("0 seconds", 10, "0 seconds") From b6793816c16da1eea33dbb9d63e90a9b25394d45 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 27 Dec 2019 15:17:39 +0800 Subject: [PATCH 09/19] name with exact --- .../sql/catalyst/expressions/arithmetic.scala | 20 +++---- .../expressions/intervalExpressions.scala | 12 ++-- .../sql/catalyst/util/IntervalUtils.scala | 22 +++---- .../CollectionExpressionsSuite.scala | 12 ++-- .../catalyst/util/IntervalUtilsSuite.scala | 57 ++++++++++--------- 5 files changed, 63 insertions(+), 60 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index f12892f41175..cb3b16d751e4 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -77,17 +77,17 @@ case class UnaryMinus(child: Expression) extends UnaryExpression val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") defineCodeGen(ctx, ev, interval => if (checkOverflow) { - s"$iu.negate($interval)" + s"$iu.negateExact($interval)" } else { - s"$iu.safeNegate($interval)" + s"$iu.negate($interval)" } ) } protected override def nullSafeEval(input: Any): Any = dataType match { case CalendarIntervalType if checkOverflow => - IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) - case CalendarIntervalType => IntervalUtils.safeNegate(input.asInstanceOf[CalendarInterval]) + IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval]) + case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) case _ => numeric.negate(input) } @@ -232,16 +232,16 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def decimalMethod: String = "$plus" - override def calendarIntervalMethod: String = if (checkOverflow) "add" else "safeAdd" + override def calendarIntervalMethod: String = if (checkOverflow) "addExact" else "add" private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match { case CalendarIntervalType if checkOverflow => - IntervalUtils.add( + IntervalUtils.addExact( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case CalendarIntervalType => - IntervalUtils.safeAdd( + IntervalUtils.add( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case _ => numeric.plus(input1, input2) } @@ -264,16 +264,16 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def decimalMethod: String = "$minus" - override def 
calendarIntervalMethod: String = if (checkOverflow) "subtract" else "safeSubtract" + override def calendarIntervalMethod: String = if (checkOverflow) "subtractExact" else "subtract" private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match { case CalendarIntervalType if checkOverflow => - IntervalUtils.subtract( + IntervalUtils.subtractExact( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case CalendarIntervalType => - IntervalUtils.safeSubtract( + IntervalUtils.subtract( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case _ => numeric.minus(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index f0793ccb78dd..7a38c9c76f31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -135,9 +135,9 @@ case class MultiplyInterval(interval: Expression, num: Expression) override def nullSafeEval(interval: Any, num: Any): Any = { try { if (checkOverflow) { - multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + multiplyExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } else { - safeMultiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } } catch { case _: ArithmeticException if !checkOverflow => null @@ -147,7 +147,7 @@ case class MultiplyInterval(interval: Expression, num: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (interval, num) => { val iu = 
IntervalUtils.getClass.getName.stripSuffix("$") - val operationName = if (checkOverflow) "multiply" else "safeMultiply" + val operationName = if (checkOverflow) "multiplyExact" else "multiply" s""" try { ${ev.value} = $iu.$operationName($interval, $num); @@ -172,9 +172,9 @@ case class DivideInterval(interval: Expression, num: Expression) try { if (num == 0) return null if (checkOverflow) { - divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + divideExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } else { - safeDivide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } } catch { case _: ArithmeticException if !checkOverflow => null @@ -184,7 +184,7 @@ case class DivideInterval(interval: Expression, num: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (interval, num) => { val iu = IntervalUtils.getClass.getName.stripSuffix("$") - val operationName = if (checkOverflow) "divide" else "safeDivide" + val operationName = if (checkOverflow) "divideExact" else "divide" s""" try { if ($num == 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index b1eec68a4395..757fded9395c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -138,7 +138,7 @@ object IntervalUtils { assert(input.length == input.trim.length) input match { case yearMonthPattern("-", yearStr, monthStr) => - negate(toInterval(yearStr, monthStr)) + negateExact(toInterval(yearStr, monthStr)) case yearMonthPattern(_, yearStr, monthStr) => toInterval(yearStr, monthStr) case _ => @@ -451,7 +451,7 @@ object IntervalUtils { * @return a new calendar 
interval instance with all it parameters negated from the origin one. * @throws ArithmeticException if the result overflows any field value */ - def negate(interval: CalendarInterval): CalendarInterval = { + def negateExact(interval: CalendarInterval): CalendarInterval = { val months = Math.negateExact(interval.months) val days = Math.negateExact(interval.days) val microseconds = Math.negateExact(interval.microseconds) @@ -464,7 +464,7 @@ object IntervalUtils { * @param interval the interval to be negated * @return a new calendar interval instance with all it parameters negated from the origin one. */ - def safeNegate(interval: CalendarInterval): CalendarInterval = { + def negate(interval: CalendarInterval): CalendarInterval = { new CalendarInterval(-interval.months, -interval.days, -interval.microseconds) } @@ -474,7 +474,7 @@ object IntervalUtils { * @throws ArithmeticException if the result overflows any field value * */ - def add(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { + def addExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = Math.addExact(left.months, right.months) val days = Math.addExact(left.days, right.days) val microseconds = Math.addExact(left.microseconds, right.microseconds) @@ -484,7 +484,7 @@ object IntervalUtils { /** * Return a new calendar interval instance of the sum of two intervals. 
*/ - def safeAdd(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { + def add(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = left.months + right.months val days = left.days + right.days val microseconds = left.microseconds + right.microseconds @@ -497,7 +497,7 @@ object IntervalUtils { * @throws ArithmeticException if the result overflows any field value * */ - def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { + def subtractExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = Math.subtractExact(left.months, right.months) val days = Math.subtractExact(left.days, right.days) val microseconds = Math.subtractExact(left.microseconds, right.microseconds) @@ -507,7 +507,7 @@ object IntervalUtils { /** * Return a new calendar interval instance of the left interval minus the right one. */ - def safeSubtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { + def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = left.months - right.months val days = left.days - right.days val microseconds = left.microseconds - right.microseconds @@ -519,14 +519,14 @@ object IntervalUtils { * * @throws ArithmeticException if the result overflows any field value */ - def multiply(interval: CalendarInterval, num: Double): CalendarInterval = { + def multiplyExact(interval: CalendarInterval, num: Double): CalendarInterval = { fromDoubles(num * interval.months, num * interval.days, num * interval.microseconds) } /** * Return a new calendar interval instance of the left interval times a multiplier. 
*/ - def safeMultiply(interval: CalendarInterval, num: Double): CalendarInterval = { + def multiply(interval: CalendarInterval, num: Double): CalendarInterval = { safeFromDoubles(num * interval.months, num * interval.days, num * interval.microseconds) } @@ -535,7 +535,7 @@ object IntervalUtils { * * @throws ArithmeticException if the result overflows any field value or divided by zero */ - def divide(interval: CalendarInterval, num: Double): CalendarInterval = { + def divideExact(interval: CalendarInterval, num: Double): CalendarInterval = { if (num == 0) throw new ArithmeticException("divide by zero") fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) } @@ -545,7 +545,7 @@ object IntervalUtils { * * @throws ArithmeticException if divided by zero */ - def safeDivide(interval: CalendarInterval, num: Double): CalendarInterval = { + def divide(interval: CalendarInterval, num: Double): CalendarInterval = { if (num == 0) throw new ArithmeticException("divide by zero") safeFromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index cc9ebfe40942..9e98e146c7a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -733,7 +733,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-02 00:00:00")), Literal(Timestamp.valueOf("2018-01-01 00:00:00")), - Literal(negate(stringToInterval("interval 12 hours")))), + Literal(negateExact(stringToInterval("interval 12 hours")))), Seq( Timestamp.valueOf("2018-01-02 00:00:00"), 
Timestamp.valueOf("2018-01-01 12:00:00"), @@ -742,7 +742,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-02 00:00:00")), Literal(Timestamp.valueOf("2017-12-31 23:59:59")), - Literal(negate(stringToInterval("interval 12 hours")))), + Literal(negateExact(stringToInterval("interval 12 hours")))), Seq( Timestamp.valueOf("2018-01-02 00:00:00"), Timestamp.valueOf("2018-01-01 12:00:00"), @@ -760,7 +760,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-03-01 00:00:00")), Literal(Timestamp.valueOf("2018-01-01 00:00:00")), - Literal(negate(stringToInterval("interval 1 month")))), + Literal(negateExact(stringToInterval("interval 1 month")))), Seq( Timestamp.valueOf("2018-03-01 00:00:00"), Timestamp.valueOf("2018-02-01 00:00:00"), @@ -769,7 +769,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-03-03 00:00:00")), Literal(Timestamp.valueOf("2018-01-01 00:00:00")), - Literal(negate(stringToInterval("interval 1 month 1 day")))), + Literal(negateExact(stringToInterval("interval 1 month 1 day")))), Seq( Timestamp.valueOf("2018-03-03 00:00:00"), Timestamp.valueOf("2018-02-02 00:00:00"), @@ -815,7 +815,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2022-04-01 00:00:00")), Literal(Timestamp.valueOf("2017-01-01 00:00:00")), - Literal(negate(fromYearMonthString("1-5")))), + Literal(negateExact(fromYearMonthString("1-5")))), Seq( Timestamp.valueOf("2022-04-01 00:00:00.000"), Timestamp.valueOf("2020-11-01 00:00:00.000"), @@ -907,7 +907,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper new Sequence( Literal(Date.valueOf("1970-01-01")), 
Literal(Date.valueOf("1970-02-01")), - Literal(negate(stringToInterval("interval 1 month")))), + Literal(negateExact(stringToInterval("interval 1 month")))), EmptyRow, s"sequence boundaries: 0 to 2678400000000 by -${28 * MICROS_PER_DAY}") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala index 2669406cc348..7aba12b1c9fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -239,8 +239,8 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } test("negate") { + assert(negateExact(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3)) assert(negate(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3)) - assert(safeNegate(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3)) } test("subtract one interval by another") { @@ -248,7 +248,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { val input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR) val input3 = new CalendarInterval(-10, -30, -81 * MICROS_PER_HOUR) val input4 = new CalendarInterval(75, 150, 200 * MICROS_PER_HOUR) - Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](subtract, safeSubtract) + Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](subtractExact, subtract) .foreach { func => assert(new CalendarInterval(1, -3, -99 * MICROS_PER_HOUR) === func(input1, input2)) assert(new CalendarInterval(-85, -180, -281 * MICROS_PER_HOUR) === func(input3, input4)) @@ -260,14 +260,14 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { val input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR) val input3 = new CalendarInterval(-10, -30, -81 * MICROS_PER_HOUR) val input4 = new CalendarInterval(75, 150, 200 * 
MICROS_PER_HOUR) - Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](add, safeAdd).foreach { func => + Seq[(CalendarInterval, CalendarInterval) => CalendarInterval](addExact, add).foreach { func => assert(new CalendarInterval(5, 5, 101 * MICROS_PER_HOUR) === func(input1, input2)) assert(new CalendarInterval(65, 120, 119 * MICROS_PER_HOUR) === func(input3, input4)) } } test("multiply by num") { - Seq[(CalendarInterval, Double) => CalendarInterval](multiply, safeMultiply).foreach { func => + Seq[(CalendarInterval, Double) => CalendarInterval](multiplyExact, multiply).foreach { func => var interval = new CalendarInterval(0, 0, 0) assert(interval === func(interval, 0)) interval = new CalendarInterval(123, 456, 789) @@ -281,9 +281,9 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } assert(CalendarInterval.MAX_VALUE === - safeMultiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE)) + multiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE)) try { - multiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE) + multiplyExact(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE) fail("Expected to throw an exception on months overflow") } catch { case e: ArithmeticException => assert(e.getMessage.contains("overflow")) @@ -291,7 +291,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } test("divide by num") { - Seq[(CalendarInterval, Double) => CalendarInterval](divide, safeDivide).foreach { func => + Seq[(CalendarInterval, Double) => CalendarInterval](divideExact, divide).foreach { func => var interval = new CalendarInterval(0, 0, 0) assert(interval === func(interval, 10)) interval = new CalendarInterval(1, 3, 30 * MICROS_PER_SECOND) @@ -457,37 +457,40 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } test("interval overflow check") { - intercept[ArithmeticException](negate(new CalendarInterval(Int.MinValue, 0, 0))) - assert(safeNegate(new CalendarInterval(Int.MinValue, 0, 0)) === + 
intercept[ArithmeticException](negateExact(new CalendarInterval(Int.MinValue, 0, 0))) + assert(negate(new CalendarInterval(Int.MinValue, 0, 0)) === new CalendarInterval(Int.MinValue, 0, 0)) - intercept[ArithmeticException](negate(CalendarInterval.MIN_VALUE)) - assert(safeNegate(CalendarInterval.MIN_VALUE) === CalendarInterval.MIN_VALUE) - intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1))) - intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0))) - intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0))) - assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)) === + intercept[ArithmeticException](negateExact(CalendarInterval.MIN_VALUE)) + assert(negate(CalendarInterval.MIN_VALUE) === CalendarInterval.MIN_VALUE) + intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE, + new CalendarInterval(0, 0, 1))) + intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE, + new CalendarInterval(0, 1, 0))) + intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE, + new CalendarInterval(1, 0, 0))) + assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)) === new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue)) - assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)) === + assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)) === new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue)) - assert(safeAdd(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)) === + assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)) === new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue)) - intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, + intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1))) - 
intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, + intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0))) - intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE, + intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0))) - assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1)) === + assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1)) === new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue)) - assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0)) === + assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0)) === new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue)) - assert(safeSubtract(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0)) === + assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0)) === new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue)) - intercept[ArithmeticException](multiply(CalendarInterval.MAX_VALUE, 2)) - assert(safeMultiply(CalendarInterval.MAX_VALUE, 2) === CalendarInterval.MAX_VALUE) - intercept[ArithmeticException](divide(CalendarInterval.MAX_VALUE, 0.5)) - assert(safeDivide(CalendarInterval.MAX_VALUE, 0.5) === CalendarInterval.MAX_VALUE) + intercept[ArithmeticException](multiplyExact(CalendarInterval.MAX_VALUE, 2)) + assert(multiply(CalendarInterval.MAX_VALUE, 2) === CalendarInterval.MAX_VALUE) + intercept[ArithmeticException](divideExact(CalendarInterval.MAX_VALUE, 0.5)) + assert(divide(CalendarInterval.MAX_VALUE, 0.5) === CalendarInterval.MAX_VALUE) } } From a8d29b6eaa67a4471c74a0d0b652b001d66633bc Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 27 Dec 2019 20:49:41 +0800 Subject: [PATCH 10/19] rm try-catch --- .../expressions/intervalExpressions.scala | 82 ++++++------------- 1 file changed, 27 insertions(+), 55 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 7a38c9c76f31..78964803f3a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -113,93 +113,65 @@ object ExtractIntervalPart { } } -abstract class IntervalNumOperation(interval: Expression, num: Expression) +abstract class IntervalNumOperation(interval: Expression, num: Expression, operationName: String) extends BinaryExpression with ImplicitCastInputTypes with Serializable { protected val checkOverflow: Boolean = SQLConf.get.ansiEnabled + protected val operation: String = + IntervalUtils.getClass.getName.stripSuffix("$") + "." + { + if (checkOverflow) operationName + "Exact" else operationName + } + override def left: Expression = interval override def right: Expression = num override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType, DoubleType) + override def dataType: DataType = CalendarIntervalType override def nullable: Boolean = true + + override def prettyName: String = operationName + "_interval" } case class MultiplyInterval(interval: Expression, num: Expression) - extends IntervalNumOperation(interval, num) { - - override def prettyName: String = "multiply_interval" + extends IntervalNumOperation(interval, num, "multiply") { override def nullSafeEval(interval: Any, num: Any): Any = { - try { - if (checkOverflow) { - multiplyExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) - } else { - multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) - } - } catch { - case _: ArithmeticException if !checkOverflow => null + if (checkOverflow) { + multiplyExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + } else { + 
multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (interval, num) => { - val iu = IntervalUtils.getClass.getName.stripSuffix("$") - val operationName = if (checkOverflow) "multiplyExact" else "multiply" - s""" - try { - ${ev.value} = $iu.$operationName($interval, $num); - } catch (ArithmeticException e) { - if ($checkOverflow) { - throw e; - } else { - ${ev.isNull} = true; - } - } - """ - }) + defineCodeGen(ctx, ev, (interval, num) => s"$operation($interval, $num)") } } case class DivideInterval(interval: Expression, num: Expression) - extends IntervalNumOperation(interval, num) { - - override def prettyName: String = "divide_interval" + extends IntervalNumOperation(interval, num, "divide") { override def nullSafeEval(interval: Any, num: Any): Any = { - try { - if (num == 0) return null - if (checkOverflow) { - divideExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) - } else { - divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) - } - } catch { - case _: ArithmeticException if !checkOverflow => null + if (num == 0) return null + if (checkOverflow) { + divideExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + } else { + divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (interval, num) => { - val iu = IntervalUtils.getClass.getName.stripSuffix("$") - val operationName = if (checkOverflow) "divideExact" else "divide" s""" - try { - if ($num == 0) { - ${ev.isNull} = true; - } else { - ${ev.value} = $iu.$operationName($interval, $num); - } - } catch (ArithmeticException e) { - if ($checkOverflow) { - throw e; - } else { - ${ev.isNull} = true; - } - } - """ + |if ($num == 0) { + | ${ev.isNull} = true; + |} else { + | ${ev.value} = $operation($interval, 
$num); + |} + |""".stripMargin }) } } From ccb1fd50faa4620e1d9f227da044f74c44472c14 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 30 Dec 2019 16:09:33 +0800 Subject: [PATCH 11/19] regen g f --- .../sql-tests/results/ansi/interval.sql.out | 115 ++++++++++++------ .../sql-tests/results/interval.sql.out | 42 ++++++- 2 files changed, 121 insertions(+), 36 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 3da83c1ff6a7..5697a5a6e882 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 126 +-- Number of queries: 131 -- !query 0 @@ -1120,34 +1120,79 @@ struct -- !query 114 -select 1 year 2 days +select -(a) from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 114 schema -struct +struct<> -- !query 114 output -1 years 2 days +java.lang.ArithmeticException +integer overflow -- !query 115 -select '10-9' year to month +select a - b from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 115 schema -struct +struct<> -- !query 115 output -10 years 9 months +java.lang.ArithmeticException +integer overflow -- !query 116 -select '20 15:40:32.99899999' day to second +select b + interval '1 month' from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 116 schema -struct +struct<> -- !query 116 output -20 days 15 hours 40 minutes 32.998999 seconds +java.lang.ArithmeticException +integer overflow -- !query 117 -select 30 day day +select a * 1.1 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 117 schema struct<> -- !query 117 output +java.lang.ArithmeticException +integer overflow + + +-- !query 118 +select a / 0.5 from 
values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 118 schema +struct<> +-- !query 118 output +java.lang.ArithmeticException +integer overflow + + +-- !query 119 +select 1 year 2 days +-- !query 119 schema +struct +-- !query 119 output +1 years 2 days + + +-- !query 120 +select '10-9' year to month +-- !query 120 schema +struct +-- !query 120 output +10 years 9 months + + +-- !query 121 +select '20 15:40:32.99899999' day to second +-- !query 121 schema +struct +-- !query 121 output +20 days 15 hours 40 minutes 32.998999 seconds + + +-- !query 122 +select 30 day day +-- !query 122 schema +struct<> +-- !query 122 output org.apache.spark.sql.catalyst.parser.ParseException no viable alternative at input 'day'(line 1, pos 14) @@ -1157,27 +1202,27 @@ select 30 day day --------------^^^ --- !query 118 +-- !query 123 select date'2012-01-01' - '2-2' year to month --- !query 118 schema +-- !query 123 schema struct --- !query 118 output +-- !query 123 output 2009-11-01 --- !query 119 +-- !query 124 select 1 month - 1 day --- !query 119 schema +-- !query 124 schema struct --- !query 119 output +-- !query 124 output 1 months -1 days --- !query 120 +-- !query 125 select 1 year to month --- !query 120 schema +-- !query 125 schema struct<> --- !query 120 output +-- !query 125 output org.apache.spark.sql.catalyst.parser.ParseException The value of from-to unit must be a string(line 1, pos 7) @@ -1187,11 +1232,11 @@ select 1 year to month -------^^^ --- !query 121 +-- !query 126 select '1' year to second --- !query 121 schema +-- !query 126 schema struct<> --- !query 121 output +-- !query 126 output org.apache.spark.sql.catalyst.parser.ParseException Intervals FROM year TO second are not supported.(line 1, pos 7) @@ -1201,11 +1246,11 @@ select '1' year to second -------^^^ --- !query 122 +-- !query 127 select 1 year '2-1' year to month --- !query 122 schema +-- !query 127 schema struct<> --- !query 122 output +-- !query 127 output 
org.apache.spark.sql.catalyst.parser.ParseException Can only have a single from-to unit in the interval literal syntax(line 1, pos 14) @@ -1215,11 +1260,11 @@ select 1 year '2-1' year to month --------------^^^ --- !query 123 +-- !query 128 select (-30) day --- !query 123 schema +-- !query 128 schema struct<> --- !query 123 output +-- !query 128 output org.apache.spark.sql.catalyst.parser.ParseException no viable alternative at input 'day'(line 1, pos 13) @@ -1229,11 +1274,11 @@ select (-30) day -------------^^^ --- !query 124 +-- !query 129 select (a + 1) day --- !query 124 schema +-- !query 129 schema struct<> --- !query 124 output +-- !query 129 output org.apache.spark.sql.catalyst.parser.ParseException no viable alternative at input 'day'(line 1, pos 15) @@ -1243,11 +1288,11 @@ select (a + 1) day ---------------^^^ --- !query 125 +-- !query 130 select 30 day day day --- !query 125 schema +-- !query 130 schema struct<> --- !query 125 output +-- !query 130 output org.apache.spark.sql.catalyst.parser.ParseException no viable alternative at input 'day'(line 1, pos 14) diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index b178c18af77c..1d5cfa1687d7 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 114 +-- Number of queries: 119 -- !query 0 @@ -1101,3 +1101,43 @@ select interval '1 ' day struct -- !query 113 output 1 days + + +-- !query 114 +select -(a) from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 114 schema +struct<(- a):interval> +-- !query 114 output +-178956970 years -8 months + + +-- !query 115 +select a - b from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 115 schema +struct<(a - b):interval> +-- 
!query 115 output +1 months + + +-- !query 116 +select b + interval '1 month' from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 116 schema +struct<(b + INTERVAL '1 months'):interval> +-- !query 116 output +-178956970 years -8 months + + +-- !query 117 +select a * 1.1 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 117 schema +struct +-- !query 117 output +-178956970 years -8 months -2147483648 days -2562047788 hours -54.775808 seconds + + +-- !query 118 +select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) +-- !query 118 schema +struct +-- !query 118 output +-178956970 years -8 months -2147483648 days -2562047788 hours -54.775808 seconds From d704c747c6548b9cfcb64a0f7fd97f890332aa1c Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 30 Dec 2019 17:21:45 +0800 Subject: [PATCH 12/19] ..Exact for ansi off --- .../expressions/intervalExpressions.scala | 22 ++------ .../sql/catalyst/util/IntervalUtils.scala | 17 ------ .../IntervalExpressionsSuite.scala | 19 ++++--- .../catalyst/util/IntervalUtilsSuite.scala | 56 ++++++++----------- 4 files changed, 40 insertions(+), 74 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 93cfcedf64c4..3bcdc0411459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -116,10 +116,8 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression, opera protected val checkOverflow: Boolean = SQLConf.get.ansiEnabled - protected val operation: String = - IntervalUtils.getClass.getName.stripSuffix("$") + "." 
+ { - if (checkOverflow) operationName + "Exact" else operationName - } + protected val methodStr: String = + IntervalUtils.getClass.getName.stripSuffix("$") + "." + operationName + "Exact" override def left: Expression = interval override def right: Expression = num @@ -137,15 +135,11 @@ case class MultiplyInterval(interval: Expression, num: Expression) extends IntervalNumOperation(interval, num, "multiply") { override def nullSafeEval(interval: Any, num: Any): Any = { - if (checkOverflow) { - multiplyExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) - } else { - multiply(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) - } + multiplyExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (interval, num) => s"$operation($interval, $num)") + defineCodeGen(ctx, ev, (interval, num) => s"$methodStr($interval, $num)") } } @@ -154,11 +148,7 @@ case class DivideInterval(interval: Expression, num: Expression) override def nullSafeEval(interval: Any, num: Any): Any = { if (num == 0) return null - if (checkOverflow) { - divideExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) - } else { - divide(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) - } + divideExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -167,7 +157,7 @@ case class DivideInterval(interval: Expression, num: Expression) |if ($num == 0) { | ${ev.isNull} = true; |} else { - | ${ev.value} = $operation($interval, $num); + | ${ev.value} = $methodStr($interval, $num); |} |""".stripMargin }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 6c08ae410814..9a4c54700da8 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -523,13 +523,6 @@ object IntervalUtils { fromDoubles(num * interval.months, num * interval.days, num * interval.microseconds) } - /** - * Return a new calendar interval instance of the left interval times a multiplier. - */ - def multiply(interval: CalendarInterval, num: Double): CalendarInterval = { - safeFromDoubles(num * interval.months, num * interval.days, num * interval.microseconds) - } - /** * Return a new calendar interval instance of the left interval divides by a dividend. * @@ -540,16 +533,6 @@ object IntervalUtils { fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) } - /** - * Return a new calendar interval instance of the left interval divides by a dividend. - * - * @throws ArithmeticException if divided by zero - */ - def divide(interval: CalendarInterval, num: Double): CalendarInterval = { - if (num == 0) throw new ArithmeticException("divide by zero") - safeFromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) - } - // `toString` implementation in CalendarInterval is the multi-units format currently. 
def toMultiUnitsString(interval: CalendarInterval): String = interval.toString diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 780744ba8ae4..6a5c1591cf3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -199,11 +199,17 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("multiply") { def check(interval: String, num: Double, expected: String, - modes: Seq[String] = Seq("true", "false")): Unit = { + modes: Seq[String] = Seq("true", "false"), checkException: Boolean = false): Unit = { modes.foreach { v => withSQLConf(SQLConf.ANSI_ENABLED.key -> v) { - checkEvaluation(MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)), - if (expected == null) null else stringToInterval(expected)) + if (checkException) { + checkExceptionInExpression[ArithmeticException]( + MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)), + "integer overflow") + } else { + checkEvaluation(MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)), + if (expected == null) null else stringToInterval(expected)) + } } } } @@ -216,12 +222,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { check("-100 years -1 millisecond", 0.5, "-50 years -500 microseconds") check("2 months 4 seconds", -0.5, "-1 months -2 seconds") check("1 month 2 microseconds", 1.5, "1 months 15 days 3 microseconds") - check("2 months", Int.MaxValue, CalendarInterval.MAX_VALUE.toString, Seq("false")) - withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { - checkExceptionInExpression[ArithmeticException]( - MultiplyInterval(Literal(stringToInterval("2 months")), 
Literal(Int.MaxValue.toDouble)), - "integer overflow") - } + check("2 months", Int.MaxValue, CalendarInterval.MAX_VALUE.toString, checkException = true) } test("divide") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala index d5449b4d8a34..47b7d402a202 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -267,21 +267,17 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } test("multiply by num") { - Seq[(CalendarInterval, Double) => CalendarInterval](multiplyExact, multiply).foreach { func => - var interval = new CalendarInterval(0, 0, 0) - assert(interval === func(interval, 0)) - interval = new CalendarInterval(123, 456, 789) - assert(new CalendarInterval(123 * 42, 456 * 42, 789 * 42) === func(interval, 42)) - interval = new CalendarInterval(-123, -456, -789) - assert(new CalendarInterval(-123 * 42, -456 * 42, -789 * 42) === func(interval, 42)) - assert(new CalendarInterval(1, 22, 12 * MICROS_PER_HOUR) === - func(new CalendarInterval(1, 5, 0), 1.5)) - assert(new CalendarInterval(2, 14, 12 * MICROS_PER_HOUR) === - func(new CalendarInterval(2, 2, 2 * MICROS_PER_HOUR), 1.2)) - } + var interval = new CalendarInterval(0, 0, 0) + assert(interval === multiplyExact(interval, 0)) + interval = new CalendarInterval(123, 456, 789) + assert(new CalendarInterval(123 * 42, 456 * 42, 789 * 42) === multiplyExact(interval, 42)) + interval = new CalendarInterval(-123, -456, -789) + assert(new CalendarInterval(-123 * 42, -456 * 42, -789 * 42) === multiplyExact(interval, 42)) + assert(new CalendarInterval(1, 22, 12 * MICROS_PER_HOUR) === + multiplyExact(new CalendarInterval(1, 5, 0), 1.5)) + assert(new CalendarInterval(2, 14, 12 * MICROS_PER_HOUR) === + multiplyExact(new 
CalendarInterval(2, 2, 2 * MICROS_PER_HOUR), 1.2)) - assert(CalendarInterval.MAX_VALUE === - multiply(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE)) try { multiplyExact(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE) fail("Expected to throw an exception on months overflow") @@ -291,22 +287,20 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } test("divide by num") { - Seq[(CalendarInterval, Double) => CalendarInterval](divideExact, divide).foreach { func => - var interval = new CalendarInterval(0, 0, 0) - assert(interval === func(interval, 10)) - interval = new CalendarInterval(1, 3, 30 * MICROS_PER_SECOND) - assert(new CalendarInterval(0, 16, 12 * MICROS_PER_HOUR + 15 * MICROS_PER_SECOND) === - func(interval, 2)) - assert(new CalendarInterval(2, 6, MICROS_PER_MINUTE) === func(interval, 0.5)) - interval = new CalendarInterval(-1, 0, -30 * MICROS_PER_SECOND) - assert(new CalendarInterval(0, -15, -15 * MICROS_PER_SECOND) === func(interval, 2)) - assert(new CalendarInterval(-2, 0, -1 * MICROS_PER_MINUTE) === func(interval, 0.5)) - try { - func(new CalendarInterval(123, 456, 789), 0) - fail("Expected to throw an exception on divide by zero") - } catch { - case e: ArithmeticException => assert(e.getMessage.contains("divide by zero")) - } + var interval = new CalendarInterval(0, 0, 0) + assert(interval === divideExact(interval, 10)) + interval = new CalendarInterval(1, 3, 30 * MICROS_PER_SECOND) + assert(new CalendarInterval(0, 16, 12 * MICROS_PER_HOUR + 15 * MICROS_PER_SECOND) === + divideExact(interval, 2)) + assert(new CalendarInterval(2, 6, MICROS_PER_MINUTE) === divideExact(interval, 0.5)) + interval = new CalendarInterval(-1, 0, -30 * MICROS_PER_SECOND) + assert(new CalendarInterval(0, -15, -15 * MICROS_PER_SECOND) === divideExact(interval, 2)) + assert(new CalendarInterval(-2, 0, -1 * MICROS_PER_MINUTE) === divideExact(interval, 0.5)) + try { + divideExact(new CalendarInterval(123, 456, 789), 0) + fail("Expected to throw an exception on 
divide by zero") + } catch { + case e: ArithmeticException => assert(e.getMessage.contains("divide by zero")) } } @@ -464,8 +458,6 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue)) intercept[ArithmeticException](multiplyExact(CalendarInterval.MAX_VALUE, 2)) - assert(multiply(CalendarInterval.MAX_VALUE, 2) === CalendarInterval.MAX_VALUE) intercept[ArithmeticException](divideExact(CalendarInterval.MAX_VALUE, 0.5)) - assert(divide(CalendarInterval.MAX_VALUE, 0.5) === CalendarInterval.MAX_VALUE) } } From 512570eb0b06189e5407ad3cd1b52ff538bc2ee6 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 30 Dec 2019 17:23:52 +0800 Subject: [PATCH 13/19] regen g f --- .../sql/catalyst/expressions/intervalExpressions.scala | 2 -- .../test/resources/sql-tests/results/interval.sql.out | 10 ++++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 3bcdc0411459..24dc88b366b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -114,8 +114,6 @@ object ExtractIntervalPart { abstract class IntervalNumOperation(interval: Expression, num: Expression, operationName: String) extends BinaryExpression with ImplicitCastInputTypes with Serializable { - protected val checkOverflow: Boolean = SQLConf.get.ansiEnabled - protected val methodStr: String = IntervalUtils.getClass.getName.stripSuffix("$") + "." 
+ operationName + "Exact" diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 1d5cfa1687d7..a98ada7eb2b6 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1130,14 +1130,16 @@ struct<(b + INTERVAL '1 months'):interval> -- !query 117 select a * 1.1 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 117 schema -struct +struct<> -- !query 117 output --178956970 years -8 months -2147483648 days -2562047788 hours -54.775808 seconds +java.lang.ArithmeticException +integer overflow -- !query 118 select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b) -- !query 118 schema -struct +struct<> -- !query 118 output --178956970 years -8 months -2147483648 days -2562047788 hours -54.775808 seconds +java.lang.ArithmeticException +integer overflow From 56519596e4ebc886fae5be302bd56bab14514b21 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 30 Dec 2019 17:28:31 +0800 Subject: [PATCH 14/19] import --- .../spark/sql/catalyst/expressions/intervalExpressions.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 24dc88b366b0..51ff474bcca3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -22,7 +22,6 @@ import java.util.Locale import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ -import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval From 92e2668c11f4d26205690d972b28b3a2d92eddf5 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 31 Dec 2019 10:10:22 +0800 Subject: [PATCH 15/19] clean up --- .../sql/catalyst/util/IntervalUtils.scala | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 9a4c54700da8..bd82137a5cac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -415,40 +415,9 @@ object IntervalUtils { new CalendarInterval(truncatedMonths, truncatedDays, micros.round) } - /** - * Makes an interval from months, days and micros with the fractional part by - * adding the month fraction to days and the days fraction to micros. - */ - private def safeFromDoubles( - monthsWithFraction: Double, - daysWithFraction: Double, - microsWithFraction: Double): CalendarInterval = { - val monthInLong = monthsWithFraction.toLong - val truncatedMonths = if (monthInLong > Int.MaxValue) { - Int.MaxValue - } else if (monthInLong < Int.MinValue) { - Int.MinValue - } else { - monthInLong.toInt - } - val days = daysWithFraction + DAYS_PER_MONTH * (monthsWithFraction - truncatedMonths) - val dayInLong = days.toLong - val truncatedDays = if (dayInLong > Int.MaxValue) { - Int.MaxValue - } else if (monthInLong < Int.MinValue) { - Int.MinValue - } else { - dayInLong.toInt - } - val micros = microsWithFraction + MICROS_PER_DAY * (days - truncatedDays) - new CalendarInterval(truncatedMonths, truncatedDays.toInt, micros.round) - } - /** * Unary minus, return the negated the calendar interval value. 
* - * @param interval the interval to be negated - * @return a new calendar interval instance with all it parameters negated from the origin one. * @throws ArithmeticException if the result overflows any field value */ def negateExact(interval: CalendarInterval): CalendarInterval = { @@ -461,8 +430,6 @@ object IntervalUtils { /** * Unary minus, return the negated the calendar interval value. * - * @param interval the interval to be negated - * @return a new calendar interval instance with all it parameters negated from the origin one. */ def negate(interval: CalendarInterval): CalendarInterval = { new CalendarInterval(-interval.months, -interval.days, -interval.microseconds) From aba10b22bc3928dd88cb20ae0db340413ec0b3b6 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 2 Jan 2020 13:46:57 +0800 Subject: [PATCH 16/19] address comments --- .../expressions/intervalExpressions.scala | 57 +++++++++---------- .../sql/catalyst/util/IntervalUtils.scala | 2 - .../IntervalExpressionsSuite.scala | 28 ++++++--- 3 files changed, 47 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 51ff474bcca3..a60f516af613 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -113,9 +113,6 @@ object ExtractIntervalPart { abstract class IntervalNumOperation(interval: Expression, num: Expression, operationName: String) extends BinaryExpression with ImplicitCastInputTypes with Serializable { - protected val methodStr: String = - IntervalUtils.getClass.getName.stripSuffix("$") + "." 
+ operationName + "Exact" - override def left: Expression = interval override def right: Expression = num @@ -125,41 +122,39 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression, opera override def nullable: Boolean = true - override def prettyName: String = operationName + "_interval" -} - -case class MultiplyInterval(interval: Expression, num: Expression) - extends IntervalNumOperation(interval, num, "multiply") { - - override def nullSafeEval(interval: Any, num: Any): Any = { - multiplyExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + override def nullSafeEval(interval: Any, num: Any): Any = operationName match { + case "multiply" => + multiplyExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + case "divide" => + if (num == 0) return null + divideExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (interval, num) => s"$methodStr($interval, $num)") + val methodStr: String = + IntervalUtils.getClass.getName.stripSuffix("$") + "." 
+ operationName + "Exact" + operationName match { + case "multiply" => defineCodeGen(ctx, ev, (interval, num) => s"$methodStr($interval, $num)") + case "divide" => nullSafeCodeGen(ctx, ev, (interval, num) => { + s""" + |if ($num == 0) { + | ${ev.isNull} = true; + |} else { + | ${ev.value} = $methodStr($interval, $num); + |} + |""".stripMargin + }) + } } -} -case class DivideInterval(interval: Expression, num: Expression) - extends IntervalNumOperation(interval, num, "divide") { + override def prettyName: String = operationName + "_interval" +} - override def nullSafeEval(interval: Any, num: Any): Any = { - if (num == 0) return null - divideExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) - } +case class MultiplyInterval(interval: Expression, num: Expression) + extends IntervalNumOperation(interval, num, "multiply") - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (interval, num) => { - s""" - |if ($num == 0) { - | ${ev.isNull} = true; - |} else { - | ${ev.value} = $methodStr($interval, $num); - |} - |""".stripMargin - }) - } -} +case class DivideInterval(interval: Expression, num: Expression) + extends IntervalNumOperation(interval, num, "divide") // scalastyle:off line.size.limit @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index bd82137a5cac..e4dc4ad399ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -429,7 +429,6 @@ object IntervalUtils { /** * Unary minus, return the negated the calendar interval value. 
- * */ def negate(interval: CalendarInterval): CalendarInterval = { new CalendarInterval(-interval.months, -interval.days, -interval.microseconds) @@ -462,7 +461,6 @@ object IntervalUtils { * Return a new calendar interval instance of the left interval minus the right one. * * @throws ArithmeticException if the result overflows any field value - * */ def subtractExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = Math.subtractExact(left.months, right.months) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 6a5c1591cf3f..838781325741 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -198,9 +198,12 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("multiply") { - def check(interval: String, num: Double, expected: String, - modes: Seq[String] = Seq("true", "false"), checkException: Boolean = false): Unit = { - modes.foreach { v => + def check( + interval: String, + num: Double, + expected: String, + checkException: Boolean = false): Unit = { + Seq("true", "false").foreach { v => withSQLConf(SQLConf.ANSI_ENABLED.key -> v) { if (checkException) { checkExceptionInExpression[ArithmeticException]( @@ -222,15 +225,25 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { check("-100 years -1 millisecond", 0.5, "-50 years -500 microseconds") check("2 months 4 seconds", -0.5, "-1 months -2 seconds") check("1 month 2 microseconds", 1.5, "1 months 15 days 3 microseconds") - check("2 months", Int.MaxValue, CalendarInterval.MAX_VALUE.toString, checkException = true) + check("2 months", Int.MaxValue, null, checkException = true) } 
test("divide") { - def check(interval: String, num: Double, expected: String): Unit = { + def check( + interval: String, + num: Double, + expected: String, + checkException: Boolean = false): Unit = { Seq("true", "false").foreach { v => withSQLConf(SQLConf.ANSI_ENABLED.key -> v) { - checkEvaluation(DivideInterval(Literal(stringToInterval(interval)), Literal(num)), - if (expected == null) null else stringToInterval(expected)) + if (checkException) { + checkExceptionInExpression[ArithmeticException]( + DivideInterval(Literal(stringToInterval(interval)), Literal(num)), + "integer overflow") + } else { + checkEvaluation(DivideInterval(Literal(stringToInterval(interval)), Literal(num)), + if (expected == null) null else stringToInterval(expected)) + } } } } @@ -243,6 +256,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { check("-1 month 2 microseconds", -0.25, "4 months -8 microseconds") check("1 month 3 microsecond", 1.5, "20 days 2 microseconds") check("1 second", 0, null) + check(s"${Int.MaxValue} months", 0.9, null, checkException = true) } test("make interval") { From e37caaf4115da2cdb7843b4595ec779bad9342b2 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 2 Jan 2020 16:35:42 +0800 Subject: [PATCH 17/19] unify --- .../expressions/aggregate/Average.scala | 6 ++-- .../sql/catalyst/expressions/arithmetic.scala | 9 ++--- .../expressions/intervalExpressions.scala | 36 +++++++------------ .../IntervalExpressionsSuite.scala | 12 +++---- .../sql-tests/results/ansi/interval.sql.out | 5 +-- .../sql-tests/results/interval.sql.out | 5 +-- 6 files changed, 28 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 9bb048a9851e..996c548e1329 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult} +import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @ExpressionDescription( @@ -81,7 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case _: DecimalType => DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType) case CalendarIntervalType => - DivideInterval(sum.cast(resultType), count.cast(DoubleType)) + val newCount = If(EqualTo(count, Literal(0L)), Literal(null, LongType), count) + DivideInterval(sum.cast(resultType), newCount.cast(DoubleType)) case _ => sum.cast(resultType) / count.cast(resultType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index cb3b16d751e4..debd7c89adb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -75,13 +75,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression """}) case _: CalendarIntervalType => val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") - defineCodeGen(ctx, ev, - interval => if (checkOverflow) { - s"$iu.negateExact($interval)" - } else { - s"$iu.negate($interval)" - } - ) + val method = if (checkOverflow) "negateExact" else "negate" + defineCodeGen(ctx, ev, c => s"$iu.$method($c)") } protected override def nullSafeEval(input: Any): 
Any = dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index a60f516af613..bbeeb2aa9c1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -110,7 +110,11 @@ object ExtractIntervalPart { } } -abstract class IntervalNumOperation(interval: Expression, num: Expression, operationName: String) +abstract class IntervalNumOperation( + interval: Expression, + num: Expression, + operation: (CalendarInterval, Double) => CalendarInterval, + operationName: String) extends BinaryExpression with ImplicitCastInputTypes with Serializable { override def left: Expression = interval @@ -122,39 +126,23 @@ abstract class IntervalNumOperation(interval: Expression, num: Expression, opera override def nullable: Boolean = true - override def nullSafeEval(interval: Any, num: Any): Any = operationName match { - case "multiply" => - multiplyExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) - case "divide" => - if (num == 0) return null - divideExact(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) + override def nullSafeEval(interval: Any, num: Any): Any = { + operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val methodStr: String = - IntervalUtils.getClass.getName.stripSuffix("$") + "." 
+ operationName + "Exact" - operationName match { - case "multiply" => defineCodeGen(ctx, ev, (interval, num) => s"$methodStr($interval, $num)") - case "divide" => nullSafeCodeGen(ctx, ev, (interval, num) => { - s""" - |if ($num == 0) { - | ${ev.isNull} = true; - |} else { - | ${ev.value} = $methodStr($interval, $num); - |} - |""".stripMargin - }) - } + val iu = IntervalUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (interval, num) => s"$iu.$operationName($interval, $num)") } - override def prettyName: String = operationName + "_interval" + override def prettyName: String = operationName.stripSuffix("Exact") + "_interval" } case class MultiplyInterval(interval: Expression, num: Expression) - extends IntervalNumOperation(interval, num, "multiply") + extends IntervalNumOperation(interval, num, multiplyExact, "multiplyExact") case class DivideInterval(interval: Expression, num: Expression) - extends IntervalNumOperation(interval, num, "divide") + extends IntervalNumOperation(interval, num, divideExact, "divideExact") // scalastyle:off line.size.limit @ExpressionDescription( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 838781325741..4bad1e7fc76c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -207,8 +207,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { withSQLConf(SQLConf.ANSI_ENABLED.key -> v) { if (checkException) { checkExceptionInExpression[ArithmeticException]( - MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)), - "integer overflow") + MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)), expected) } else { 
checkEvaluation(MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)), if (expected == null) null else stringToInterval(expected)) @@ -225,7 +224,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { check("-100 years -1 millisecond", 0.5, "-50 years -500 microseconds") check("2 months 4 seconds", -0.5, "-1 months -2 seconds") check("1 month 2 microseconds", 1.5, "1 months 15 days 3 microseconds") - check("2 months", Int.MaxValue, null, checkException = true) + check("2 months", Int.MaxValue, "integer overflow", checkException = true) } test("divide") { @@ -238,8 +237,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { withSQLConf(SQLConf.ANSI_ENABLED.key -> v) { if (checkException) { checkExceptionInExpression[ArithmeticException]( - DivideInterval(Literal(stringToInterval(interval)), Literal(num)), - "integer overflow") + DivideInterval(Literal(stringToInterval(interval)), Literal(num)), expected) } else { checkEvaluation(DivideInterval(Literal(stringToInterval(interval)), Literal(num)), if (expected == null) null else stringToInterval(expected)) @@ -255,8 +253,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { check("2 years -8 seconds", 0.5, "4 years -16 seconds") check("-1 month 2 microseconds", -0.25, "4 months -8 microseconds") check("1 month 3 microsecond", 1.5, "20 days 2 microseconds") - check("1 second", 0, null) - check(s"${Int.MaxValue} months", 0.9, null, checkException = true) + check("1 second", 0, "divide by zero", checkException = true) + check(s"${Int.MaxValue} months", 0.9, "integer overflow", checkException = true) } test("make interval") { diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 5697a5a6e882..4fceb6b255b0 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -207,9 +207,10 @@ struct +struct<> -- !query 25 output -NULL +java.lang.ArithmeticException +divide by zero -- !query 26 diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index a98ada7eb2b6..1c84bb4502f0 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -207,9 +207,10 @@ struct +struct<> -- !query 25 output -NULL +java.lang.ArithmeticException +divide by zero -- !query 26 From 988b51c27fcc32ebfa8e8379ac4ae2913626c66d Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 2 Jan 2020 16:38:53 +0800 Subject: [PATCH 18/19] blank lines --- .../spark/sql/catalyst/expressions/intervalExpressions.scala | 2 -- .../org/apache/spark/sql/catalyst/util/IntervalUtils.scala | 1 - 2 files changed, 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index bbeeb2aa9c1f..831510e7f0f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -116,12 +116,10 @@ abstract class IntervalNumOperation( operation: (CalendarInterval, Double) => CalendarInterval, operationName: String) extends BinaryExpression with ImplicitCastInputTypes with Serializable { - override def left: Expression = interval override def right: Expression = num override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType, DoubleType) - override def dataType: DataType = CalendarIntervalType override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index e4dc4ad399ae..8763f24b05ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -438,7 +438,6 @@ object IntervalUtils { * Return a new calendar interval instance of the sum of two intervals. * * @throws ArithmeticException if the result overflows any field value - * */ def addExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = { val months = Math.addExact(left.months, right.months) From f80f0f37c0e26eba797ee9ab047e7df3ec4ffb1f Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 2 Jan 2020 17:09:13 +0800 Subject: [PATCH 19/19] refine tests --- .../IntervalExpressionsSuite.scala | 40 ++++++++----------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 4bad1e7fc76c..d31a0e210552 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -21,7 +21,7 @@ import scala.language.implicitConversions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeConstants._ -import org.apache.spark.sql.catalyst.util.IntervalUtils.stringToInterval +import org.apache.spark.sql.catalyst.util.IntervalUtils.{safeStringToInterval, stringToInterval} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.Decimal import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -198,19 +198,15 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("multiply") { - def 
check( - interval: String, - num: Double, - expected: String, - checkException: Boolean = false): Unit = { + def check(interval: String, num: Double, expected: String): Unit = { + val expr = MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)) + val expectedRes = safeStringToInterval(expected) Seq("true", "false").foreach { v => withSQLConf(SQLConf.ANSI_ENABLED.key -> v) { - if (checkException) { - checkExceptionInExpression[ArithmeticException]( - MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)), expected) + if (expectedRes == null) { + checkExceptionInExpression[ArithmeticException](expr, expected) } else { - checkEvaluation(MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)), - if (expected == null) null else stringToInterval(expected)) + checkEvaluation(expr, expectedRes) } } } @@ -224,23 +220,19 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { check("-100 years -1 millisecond", 0.5, "-50 years -500 microseconds") check("2 months 4 seconds", -0.5, "-1 months -2 seconds") check("1 month 2 microseconds", 1.5, "1 months 15 days 3 microseconds") - check("2 months", Int.MaxValue, "integer overflow", checkException = true) + check("2 months", Int.MaxValue, "integer overflow") } test("divide") { - def check( - interval: String, - num: Double, - expected: String, - checkException: Boolean = false): Unit = { + def check(interval: String, num: Double, expected: String): Unit = { + val expr = DivideInterval(Literal(stringToInterval(interval)), Literal(num)) + val expectedRes = safeStringToInterval(expected) Seq("true", "false").foreach { v => withSQLConf(SQLConf.ANSI_ENABLED.key -> v) { - if (checkException) { - checkExceptionInExpression[ArithmeticException]( - DivideInterval(Literal(stringToInterval(interval)), Literal(num)), expected) + if (expectedRes == null) { + checkExceptionInExpression[ArithmeticException](expr, expected) } else { - 
checkEvaluation(DivideInterval(Literal(stringToInterval(interval)), Literal(num)), - if (expected == null) null else stringToInterval(expected)) + checkEvaluation(expr, expectedRes) } } } @@ -253,8 +245,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { check("2 years -8 seconds", 0.5, "4 years -16 seconds") check("-1 month 2 microseconds", -0.25, "4 months -8 microseconds") check("1 month 3 microsecond", 1.5, "20 days 2 microseconds") - check("1 second", 0, "divide by zero", checkException = true) - check(s"${Int.MaxValue} months", 0.9, "integer overflow", checkException = true) + check("1 second", 0, "divide by zero") + check(s"${Int.MaxValue} months", 0.9, "integer overflow") } test("make interval") {