diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index d970bf466fb8..3484108a5503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -61,6 +61,7 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: + IntegralDivision :: ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: @@ -684,6 +685,23 @@ object TypeCoercion { } } + /** + * The DIV operator always returns long-type value. + * This rule cast the integral inputs to long type, to avoid overflow during calculation. + */ + object IntegralDivision extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case e if !e.childrenResolved => e + case d @ IntegralDivide(left, right) => + IntegralDivide(mayCastToLong(left), mayCastToLong(right)) + } + + private def mayCastToLong(expr: Expression): Expression = expr.dataType match { + case _: ByteType | _: ShortType | _: IntegerType => Cast(expr, LongType) + case _ => expr + } + } + /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 354845d5ccd8..7c521838447d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -412,7 +412,7 @@ case class IntegralDivide( left: Expression, right: Expression) extends DivModLike { - override def inputType: AbstractDataType = TypeCollection(IntegralType, DecimalType) + override def inputType: AbstractDataType = TypeCollection(LongType, DecimalType) override def dataType: DataType = LongType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index e37555f1c0ec..1ea1ddb8bbd0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1559,6 +1559,30 @@ class TypeCoercionSuite extends AnalysisTest { Literal.create(null, DecimalType.SYSTEM_DEFAULT))) } } + + test("SPARK-31761: byte, short and int should be cast to long for IntegralDivide's datatype") { + val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + // Casts Byte to Long + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toByte, 1.toByte), + IntegralDivide(Cast(2.toByte, LongType), Cast(1.toByte, LongType))) + // Casts Short to Long + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1.toShort), + IntegralDivide(Cast(2.toShort, LongType), Cast(1.toShort, LongType))) + // Casts Integer to Long + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1), + IntegralDivide(Cast(2, LongType), Cast(1, LongType))) + // should not be any change for Long data types + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1L), IntegralDivide(2L, 1L)) + // one of the operand is byte + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1.toByte), + IntegralDivide(2L, Cast(1.toByte, LongType))) + // one of the operand is short + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1L), + IntegralDivide(Cast(2.toShort, LongType), 1L)) + // one of the operand is int + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1L), + IntegralDivide(Cast(2, LongType), 1L)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 675f85f9e82e..f05598aeb535 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -173,13 +173,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("/ (Divide) for integral type") { - checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0L) - checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0L) - checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0L) + test("/ (Divide) for Long type") { checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0L) - checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0L) - checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0L) checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L) } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 3570fb61e288..3420bc521993 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -136,7 +136,7 @@ | org.apache.spark.sql.catalyst.expressions.InputFileBlockLength | input_file_block_length | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.InputFileBlockStart | input_file_block_start | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.InputFileName | input_file_name | N/A | N/A | -| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(3 div 2):bigint> | +| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(CAST(3 AS BIGINT) div CAST(2 AS BIGINT)):bigint> | | org.apache.spark.sql.catalyst.expressions.IsNaN | isnan | SELECT isnan(cast('NaN' as double)) | struct | | org.apache.spark.sql.catalyst.expressions.IsNotNull | isnotnull | SELECT isnotnull(1) | struct<(1 IS NOT NULL):boolean> | | org.apache.spark.sql.catalyst.expressions.IsNull | isnull | SELECT isnull(1) | struct<(1 IS NULL):boolean> | diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index a94a123b1b8a..9accc57d0bf6 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -157,7 +157,7 @@ NULL -- !query select 5 div 2 -- !query schema -struct<(5 div 2):bigint> +struct<(CAST(5 AS BIGINT) div CAST(2 AS BIGINT)):bigint> -- !query output 2 @@ -165,7 +165,7 @@ struct<(5 div 2):bigint> -- !query select 5 div 0 -- !query schema -struct<(5 div 0):bigint> +struct<(CAST(5 AS BIGINT) div CAST(0 AS BIGINT)):bigint> -- !query output NULL @@ -173,7 +173,7 @@ NULL -- !query select 5 div null -- !query schema -struct<(5 div CAST(NULL AS INT)):bigint> +struct<(CAST(5 AS BIGINT) div CAST(NULL AS BIGINT)):bigint> -- !query output NULL @@ -181,7 +181,7 @@ NULL -- !query select null div 5 -- !query schema -struct<(CAST(NULL AS INT) div 5):bigint> +struct<(CAST(NULL AS BIGINT) div CAST(5 AS BIGINT)):bigint> -- !query output NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 788a07370195..f7a904169d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3495,6 +3495,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark assert(df4.schema.head.name === "randn(1)") checkIfSeedExistsInExplain(df2) } + + test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type") { + checkAnswer(sql("Select -2147483648 DIV -1"), Seq(Row(Integer.MIN_VALUE.toLong * -1))) + checkAnswer(sql("select CAST(-128 as Byte) DIV CAST (-1 as Byte)"), + Seq(Row(Byte.MinValue.toLong * -1))) + checkAnswer(sql("select CAST(-32768 as short) DIV CAST (-1 as short)"), + Seq(Row(Short.MinValue.toLong * -1))) + } } case class Foo(bar: Option[String])