Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ object TypeCoercion {
IfCoercion ::
StackCoercion ::
Division ::
IntegralDivision ::
ImplicitTypeCasts ::
DateTimeOperations ::
WindowFrameCoercion ::
Expand Down Expand Up @@ -684,6 +685,23 @@ object TypeCoercion {
}
}

/**
 * Widens the operands of the DIV operator ([[IntegralDivide]]) to long.
 *
 * The DIV operator always returns a long-type value. This rule rewrites every
 * [[IntegralDivide]] whose children are resolved, casting each byte, short or int operand
 * to [[LongType]] and leaving operands of any other type untouched, to avoid overflow
 * during calculation.
 */
object IntegralDivision extends TypeCoercionRule {
  override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
    // Leave nodes with unresolved children alone; the operand data types are inspected below.
    case e if !e.childrenResolved => e
    case IntegralDivide(left, right) =>
      IntegralDivide(mayCastToLong(left), mayCastToLong(right))
  }

  /**
   * Returns `expr` wrapped in a cast to [[LongType]] when its data type is byte, short or
   * int; returns `expr` unchanged for every other data type.
   */
  private def mayCastToLong(expr: Expression): Expression = expr.dataType match {
    case _: ByteType | _: ShortType | _: IntegerType => Cast(expr, LongType)
    case _ => expr
  }
}

/**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ case class IntegralDivide(
left: Expression,
right: Expression) extends DivModLike {

override def inputType: AbstractDataType = TypeCollection(IntegralType, DecimalType)
override def inputType: AbstractDataType = TypeCollection(LongType, DecimalType)

override def dataType: DataType = LongType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,30 @@ class TypeCoercionSuite extends AnalysisTest {
Literal.create(null, DecimalType.SYSTEM_DEFAULT)))
}
}

test("SPARK-31761: byte, short and int should be cast to long for IntegralDivide's datatype") {
  // Every check below runs only the IntegralDivision rule against a single IntegralDivide.
  // Casts Byte to Long
  ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toByte, 1.toByte),
    IntegralDivide(Cast(2.toByte, LongType), Cast(1.toByte, LongType)))
  // Casts Short to Long
  ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1.toShort),
    IntegralDivide(Cast(2.toShort, LongType), Cast(1.toShort, LongType)))
  // Casts Integer to Long
  ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1),
    IntegralDivide(Cast(2, LongType), Cast(1, LongType)))
  // Long operands should be left unchanged
  ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1L), IntegralDivide(2L, 1L))
  // Only one of the operands is a byte
  ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1.toByte),
    IntegralDivide(2L, Cast(1.toByte, LongType)))
  // Only one of the operands is a short
  ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1L),
    IntegralDivide(Cast(2.toShort, LongType), 1L))
  // Only one of the operands is an int
  ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1L),
    IntegralDivide(Cast(2, LongType), 1L))
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
}
}

test("/ (Divide) for integral type") {
checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0L)
checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0L)
checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0L)
test("/ (Divide) for Long type") {
checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0L)
checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0L)
checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0L)
checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
| org.apache.spark.sql.catalyst.expressions.InputFileBlockLength | input_file_block_length | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.InputFileBlockStart | input_file_block_start | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.InputFileName | input_file_name | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(3 div 2):bigint> |
| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(CAST(3 AS BIGINT) div CAST(2 AS BIGINT)):bigint> |
| org.apache.spark.sql.catalyst.expressions.IsNaN | isnan | SELECT isnan(cast('NaN' as double)) | struct<isnan(CAST(NaN AS DOUBLE)):boolean> |
| org.apache.spark.sql.catalyst.expressions.IsNotNull | isnotnull | SELECT isnotnull(1) | struct<(1 IS NOT NULL):boolean> |
| org.apache.spark.sql.catalyst.expressions.IsNull | isnull | SELECT isnull(1) | struct<(1 IS NULL):boolean> |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,31 +157,31 @@ NULL
-- !query
select 5 div 2
-- !query schema
struct<(5 div 2):bigint>
struct<(CAST(5 AS BIGINT) div CAST(2 AS BIGINT)):bigint>
-- !query output
2


-- !query
select 5 div 0
-- !query schema
struct<(5 div 0):bigint>
struct<(CAST(5 AS BIGINT) div CAST(0 AS BIGINT)):bigint>
-- !query output
NULL


-- !query
select 5 div null
-- !query schema
struct<(5 div CAST(NULL AS INT)):bigint>
struct<(CAST(5 AS BIGINT) div CAST(NULL AS BIGINT)):bigint>
-- !query output
NULL


-- !query
select null div 5
-- !query schema
struct<(CAST(NULL AS INT) div 5):bigint>
struct<(CAST(NULL AS BIGINT) div CAST(5 AS BIGINT)):bigint>
-- !query output
NULL

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3495,6 +3495,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
assert(df4.schema.head.name === "randn(1)")
checkIfSeedExistsInExplain(df2)
}

test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type") {
  // Each query divides the minimum value of a narrow integral type by -1. The expected
  // answer is computed on a long so that it can hold the negated minimum.
  val overflowCases = Seq(
    ("Select -2147483648 DIV -1", Integer.MIN_VALUE.toLong * -1),
    ("select CAST(-128 as Byte) DIV CAST (-1 as Byte)", Byte.MinValue.toLong * -1),
    ("select CAST(-32768 as short) DIV CAST (-1 as short)", Short.MinValue.toLong * -1))

  overflowCases.foreach { case (query, expected) =>
    checkAnswer(sql(query), Seq(Row(expected)))
  }
}
}

case class Foo(bar: Option[String])