Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ object TypeCoercion {
IfCoercion ::
StackCoercion ::
Division ::
IntegralDivision ::
ImplicitTypeCasts ::
DateTimeOperations ::
WindowFrameCoercion ::
Expand Down Expand Up @@ -684,6 +685,24 @@ object TypeCoercion {
}
}

/**
* Since SPARK-16323 `IntegralDivide` always returns long-type value.
* This rule cast the integral inputs to long type, to avoid overflow during calculation.
*/
object IntegralDivision extends TypeCoercionRule {
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
case e if !e.childrenResolved => e
case d @ IntegralDivide(left, right) =>
IntegralDivide(castToLong(left), castToLong(right))
}

def castToLong(expr: Expression): Expression = expr.dataType match {
case _: ByteType | _: ShortType | _: IntegerType => Cast(expr, LongType)
case _ => expr
}
}

/**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ case class IntegralDivide(
left: Expression,
right: Expression) extends DivModLike {

override def inputType: AbstractDataType = TypeCollection(IntegralType, DecimalType)
override def inputType: AbstractDataType = TypeCollection(LongType, DecimalType)

override def dataType: DataType = LongType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3495,6 +3495,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
assert(df4.schema.head.name === "randn(1)")
checkIfSeedExistsInExplain(df2)
}

test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type ") {
checkAnswer(sql("Select -2147483648 DIV -1"), Seq(Row(Integer.MIN_VALUE.toLong * -1)))
checkAnswer(sql("select CAST(-128 as Byte) DIV CAST (-1 as Byte)"),
Seq(Row(Byte.MinValue.toLong * -1)))
checkAnswer(sql("select CAST(-32768 as short) DIV CAST (-1 as short)"),
Seq(Row(Short.MinValue.toLong * -1)))
}
}

case class Foo(bar: Option[String])