diff --git a/presto-main/src/main/java/com/facebook/presto/type/IntervalYearMonthOperators.java b/presto-main/src/main/java/com/facebook/presto/type/IntervalYearMonthOperators.java index 7dbd1b197a6ff..ba03ace1645e9 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/IntervalYearMonthOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/IntervalYearMonthOperators.java @@ -17,6 +17,7 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.type.AbstractIntType; import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.BlockIndex; import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; @@ -42,6 +43,7 @@ import static com.facebook.presto.common.function.OperatorType.NEGATION; import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.common.function.OperatorType.SUBTRACT; +import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Math.toIntExact; @@ -56,49 +58,81 @@ private IntervalYearMonthOperators() @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) public static long add(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right) { - return left + right; + try { + return Math.addExact((int) left, (int) right); + } + catch (ArithmeticException e) { + throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow adding interval year-month values: " + left + " + " + right); + } } @ScalarOperator(SUBTRACT) @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) public static long subtract(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right) { - return left - right; + try { + return Math.subtractExact((int) left, (int) right); + } + catch (ArithmeticException e) { + throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow subtracting interval year-month values: " + left + " - " + right); + } } @ScalarOperator(MULTIPLY) @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) - public static long multiplyByBigint(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.BIGINT) long right) + public static long multiplyByInteger(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.INTEGER) long right) { - return left * right; + try { + return Math.multiplyExact((int) left, (int) right); + } + catch (ArithmeticException e) { + throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying interval year-month value by integer: " + left + " * " + right); + } } @ScalarOperator(MULTIPLY) @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) public static long multiplyByDouble(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.DOUBLE) double right) { - return (long) (left * right); + long result = (long) (left * right); + if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) { + throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying interval year-month value by double: " + left + " * " + right); + } + return result; } @ScalarOperator(MULTIPLY) @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) - public static long bigintMultiply(@SqlType(StandardTypes.BIGINT) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right) + public static long integerMultiply(@SqlType(StandardTypes.INTEGER) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right) { - return left * right; + try { + return Math.multiplyExact((int) left, (int) right); + } + catch (ArithmeticException e) { + throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying integer by interval year-month value: " + left + " * " + right); + } } @ScalarOperator(MULTIPLY) @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) public static long doubleMultiply(@SqlType(StandardTypes.DOUBLE) double left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right) { - return (long) (left * right); + long result = (long) (left * right); + if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) { + throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying double by interval year-month value: " + left + " * " + right); + } + return result; } @ScalarOperator(DIVIDE) @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) public static long divideByDouble(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.DOUBLE) double right) { - return (long) (left / right); + long result = (long) (left / right); + if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) { + throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow dividing interval year-month value by double: " + left + " / " + right); + } + return result; } @ScalarOperator(NEGATION) diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestIntervalYearMonth.java b/presto-main/src/test/java/com/facebook/presto/type/TestIntervalYearMonth.java index 66c551519dd5f..2e2f80db353de 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestIntervalYearMonth.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestIntervalYearMonth.java @@ -28,6 +28,7 @@ public class TestIntervalYearMonth extends AbstractTestFunctions { private static final int MAX_SHORT = Short.MAX_VALUE; + private static final long MAX_INT_PLUS_1 = Integer.MAX_VALUE + 1L; @Test public void testObject() @@ -74,6 +75,7 @@ public void testInvalidLiteral() assertInvalidFunction("INTERVAL '124-X' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: 124-X"); assertInvalidFunction("INTERVAL '124--30' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: 124--30"); assertInvalidFunction("INTERVAL '--124--30' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: --124--30"); + assertInvalidFunction(format("INTERVAL '%s' MONTH", MAX_INT_PLUS_1), "Invalid INTERVAL MONTH value: " + MAX_INT_PLUS_1); } @Test @@ -82,6 +84,7 @@ public void testAdd() assertFunction("INTERVAL '3' MONTH + INTERVAL '3' MONTH", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(6)); assertFunction("INTERVAL '6' YEAR + INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(12 * 12)); assertFunction("INTERVAL '3' MONTH + INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((6 * 12) + (3))); + assertNumericOverflow(format("INTERVAL '%s' MONTH + INTERVAL '1' MONTH", Integer.MAX_VALUE), format("Overflow adding interval year-month values: %s + 1", Integer.MAX_VALUE)); } @Test @@ -90,6 +93,7 @@ public void testSubtract() assertFunction("INTERVAL '6' MONTH - INTERVAL '3' MONTH", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(3)); assertFunction("INTERVAL '9' YEAR - INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(3 * 12)); assertFunction("INTERVAL '3' MONTH - INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((3) - (6 * 12))); + assertNumericOverflow(format("-INTERVAL '%s' MONTH - INTERVAL '2' MONTH", Integer.MAX_VALUE), format("Overflow subtracting interval year-month values: -%s - 2", Integer.MAX_VALUE)); } @Test @@ -104,6 +108,13 @@ public void testMultiply() assertFunction("2 * INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(12 * 12)); assertFunction("INTERVAL '1' YEAR * 2.5", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((int) (2.5 * 12))); assertFunction("2.5 * INTERVAL '1' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((int) (2.5 * 12))); + + assertNumericOverflow(format("INTERVAL '%s' MONTH * 2", Integer.MAX_VALUE), format("Overflow multiplying interval year-month value by integer: %s * 2", Integer.MAX_VALUE)); + assertNumericOverflow(format("2 * INTERVAL '%s' MONTH", Integer.MAX_VALUE), format("Overflow multiplying integer by interval year-month value: 2 * %s", Integer.MAX_VALUE)); + assertNumericOverflow(format("INTERVAL '%s' MONTH * 2.0", Integer.MAX_VALUE), format("Overflow multiplying interval year-month value by double: %s * 2.0", Integer.MAX_VALUE)); + assertNumericOverflow(format("DOUBLE '2' * INTERVAL '%s' MONTH", Integer.MAX_VALUE), format("Overflow multiplying double by interval year-month value: 2.0 * %s", Integer.MAX_VALUE)); + assertNumericOverflow(format("INTERVAL '2' YEAR * %s", (long) Integer.MAX_VALUE + 1), format("Overflow multiplying interval year-month value by double: 24 * %s", (double) ((long) Integer.MAX_VALUE + 1))); + assertNumericOverflow(format("%s * INTERVAL '2' YEAR", (long) Integer.MAX_VALUE + 1), format("Overflow multiplying double by interval year-month value: %s * 24", (double) ((long) Integer.MAX_VALUE + 1))); } @Test @@ -114,6 +125,8 @@ public void testDivide() assertFunction("INTERVAL '3' YEAR / 2", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(18)); assertFunction("INTERVAL '4' YEAR / 4.8", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(10)); + + assertNumericOverflow(format("INTERVAL '%s' MONTH / 0.5", Integer.MAX_VALUE), format("Overflow dividing interval year-month value by double: %s / 0.5", Integer.MAX_VALUE)); } @Test