@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToDecimal, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -110,7 +110,7 @@ object Cast {
case (StringType, _: CalendarIntervalType) => true
case (StringType, _: AnsiIntervalType) => true

case (_: AnsiIntervalType, _: IntegralType) => true
case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true

case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
@@ -194,8 +194,7 @@ object Cast {

case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
case (_: DayTimeIntervalType, _: IntegralType) => true
case (_: YearMonthIntervalType, _: IntegralType) => true
case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true

case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
@@ -967,10 +966,17 @@ case class Cast(
* NOTE: this modifies `value` in-place, so don't call it on external data.
*/
private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
changePrecision(value, decimalType, !ansiEnabled)
}

private[this] def changePrecision(
value: Decimal,
decimalType: DecimalType,
nullOnOverflow: Boolean): Decimal = {
if (value.changePrecision(decimalType.precision, decimalType.scale)) {
value
} else {
if (!ansiEnabled) {
if (nullOnOverflow) {
null
} else {
throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
@@ -1015,6 +1021,18 @@ case class Cast(
} catch {
case _: NumberFormatException => null
}
case x: DayTimeIntervalType =>
buildCast[Long](_, dt =>
changePrecision(
value = dayTimeIntervalToDecimal(dt, x.endField),
decimalType = target,
nullOnOverflow = false))
case x: YearMonthIntervalType =>
buildCast[Int](_, ym =>
changePrecision(
value = Decimal(yearMonthIntervalToInt(ym, x.startField, x.endField)),
decimalType = target,
nullOnOverflow = false))
}

// DoubleConverter
@@ -1515,14 +1533,15 @@ case class Cast(
evPrim: ExprValue,
evNull: ExprValue,
canNullSafeCast: Boolean,
ctx: CodegenContext): Block = {
ctx: CodegenContext,
nullOnOverflow: Boolean): Block = {
if (canNullSafeCast) {
code"""
|$d.changePrecision(${decimalType.precision}, ${decimalType.scale});
|$evPrim = $d;
""".stripMargin
} else {
val overflowCode = if (!ansiEnabled) {
val overflowCode = if (nullOnOverflow) {
s"$evNull = true;"
} else {
s"""
@@ -1540,6 +1559,16 @@
}
}

private[this] def changePrecision(
d: ExprValue,
decimalType: DecimalType,
evPrim: ExprValue,
evNull: ExprValue,
canNullSafeCast: Boolean,
ctx: CodegenContext): Block = {
changePrecision(d, decimalType, evPrim, evNull, canNullSafeCast, ctx, !ansiEnabled)
}

private[this] def castToDecimalCode(
from: DataType,
target: DecimalType,
@@ -1605,6 +1634,22 @@ case class Cast(
$evNull = true;
}
"""
case x: DayTimeIntervalType =>
(c, evPrim, evNull) =>
val u = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
code"""
Decimal $tmp = $u.dayTimeIntervalToDecimal($c, (byte)${x.endField});
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx, false)}
"""
case x: YearMonthIntervalType =>
(c, evPrim, evNull) =>
val u = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
val tmpYm = ctx.freshVariable("tmpYm", classOf[Int])
code"""
int $tmpYm = $u.yearMonthIntervalToInt($c, (byte)${x.startField}, (byte)${x.endField});
Decimal $tmp = Decimal.apply($tmpYm);
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx, false)}
"""
}
}

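For orientation, a minimal sketch (not part of the patch) of what the interpreted year-month converter above is assumed to do: the interval payload is a total month count, which is wrapped in a Decimal and then adjusted to the target precision and scale. The values mirror the Period.ofYears(-1).minusMonths(1) case from the test data further down.

// Illustration only: assumed behavior of the YearMonthIntervalType -> DecimalType path.
import org.apache.spark.sql.types.Decimal

val totalMonths = -13                    // INTERVAL '-1-1' YEAR TO MONTH, end field MONTH
val d = Decimal(totalMonths)             // yearMonthIntervalToInt keeps the month count
val fits = d.changePrecision(8, 3)       // target DECIMAL(8, 3)
assert(fits && d.toString == "-13.000")  // matches Decimal(-13000, 8, 3) in the test data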
@@ -1346,6 +1346,15 @@ object IntervalUtils {
}
}

def dayTimeIntervalToDecimal(v: Long, endField: Byte): Decimal = {
endField match {
case DAY => Decimal(v / MICROS_PER_DAY)
case HOUR => Decimal(v / MICROS_PER_HOUR)
case MINUTE => Decimal(v / MICROS_PER_MINUTE)
case SECOND => Decimal(v, Decimal.MAX_LONG_DIGITS, 6)
}
}

def dayTimeIntervalToInt(v: Long, startField: Byte, endField: Byte): Int = {
val vLong = dayTimeIntervalToLong(v, startField, endField)
val vInt = vLong.toInt
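A minimal illustration of how the new dayTimeIntervalToDecimal is expected to behave: the interval payload is a microsecond count, a SECOND end field keeps all six fractional digits, and coarser end fields drop the fraction via integer division.

// Illustration only: 3600000000L is the microseconds-per-hour constant (MICROS_PER_HOUR).
import org.apache.spark.sql.types.Decimal

val micros = 29470001000L                  // INTERVAL '08:11:10.001' HOUR TO SECOND
val seconds = Decimal(micros, 18, 6)       // end field SECOND => 29470.001000
val hours = Decimal(micros / 3600000000L)  // end field HOUR   => 8 (fraction dropped)
assert(seconds.toString == "29470.001000" && hours.toString == "8")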
@@ -1272,4 +1272,37 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
"to restore the behavior before Spark 3.0."))
}
}

test("cast ANSI intervals to decimals") {
Seq(
Contributor:
We need examples with rounding:
INTERVAL '12.123' SECOND AS DECIMAL(3, 1) => 12.1
INTERVAL '12.005' SECOND AS DECIMAL(4, 2) => 12.01

Member Author:
No, we don't round and don't lose info. See the last check that I added:

select cast(interval '10.123' second as decimal(1, 0))
-- !query schema
struct<>
-- !query output
org.apache.spark.SparkArithmeticException
[CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== SQL(line 1, position 8) ==
select cast(interval '10.123' second as decimal(1, 0))
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Both of your cases are actually the same: the total number of digits is greater than 3 or 4.

Contributor:
This is different. My examples test success; yours tests a failure. Obviously, any number >= 10 and < 100 cannot even be approximated in a single digit,
but any number between 0 and 9 can be approximated in one digit.
I'm curious why you see truncation rather than rounding, given that decimal-to-decimal casts round.

(Duration.ZERO, DayTimeIntervalType(DAY), DecimalType(10, 3)) -> Decimal(0, 10, 3),
(Duration.ofHours(-1), DayTimeIntervalType(HOUR), DecimalType(10, 1)) -> Decimal(-10, 10, 1),
(Duration.ofMinutes(1), DayTimeIntervalType(MINUTE), DecimalType(8, 2)) -> Decimal(100, 8, 2),
(Duration.ofSeconds(59), DayTimeIntervalType(SECOND), DecimalType(6, 0)) -> Decimal(59, 6, 0),
(Duration.ofSeconds(-60).minusMillis(1), DayTimeIntervalType(SECOND),
DecimalType(10, 3)) -> Decimal(-60.001, 10, 3),
(Duration.ZERO, DayTimeIntervalType(DAY, SECOND), DecimalType(10, 6)) -> Decimal(0, 10, 6),
(Duration.ofHours(-23).minusMinutes(59).minusSeconds(59).minusNanos(123456000),
DayTimeIntervalType(HOUR, SECOND), DecimalType(18, 6)) -> Decimal(-86399.123456, 18, 6),
(Period.ZERO, YearMonthIntervalType(YEAR), DecimalType(5, 2)) -> Decimal(0, 5, 2),
(Period.ofMonths(-1), YearMonthIntervalType(MONTH),
DecimalType(8, 0)) -> Decimal(-1, 8, 0),
(Period.ofYears(-1).minusMonths(1), YearMonthIntervalType(YEAR, MONTH),
DecimalType(8, 3)) -> Decimal(-13000, 8, 3)
).foreach { case ((duration, intervalType, targetType), expected) =>
checkEvaluation(
Cast(Literal.create(duration, intervalType), targetType),
expected)
}

dayTimeIntervalTypes.foreach { it =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
Cast(child, DecimalType.USER_DEFAULT), it)
}

yearMonthIntervalTypes.foreach { it =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
Cast(child, DecimalType.USER_DEFAULT), it)
}
}
}
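On the truncation-versus-rounding question raised in the thread above, a small hedged check of the assumed Decimal behavior (Spark's Decimal.changePrecision is assumed to round half-up when reducing the scale), consistent with the 10.005 -> 10.01 result in the .sql.out files below:

// Illustration only: assumed half-up rounding in Decimal.changePrecision.
import org.apache.spark.sql.types.Decimal

val d = Decimal(BigDecimal("10.005"))  // INTERVAL '10.005' SECOND, as a decimal
val fits = d.changePrecision(4, 2)     // cast target DECIMAL(4, 2)
assert(fits && d.toString == "10.01")  // rounded, not truncated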
10 changes: 10 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/cast.sql
@@ -116,3 +116,13 @@ select cast(interval '10' day as bigint);

select cast(interval '-1000' month as tinyint);
select cast(interval '1000000' second as smallint);

-- cast ANSI intervals to decimals
select cast(interval '-1' year as decimal(10, 0));
select cast(interval '1.000001' second as decimal(10, 6));
select cast(interval '08:11:10.001' hour to second as decimal(10, 4));
select cast(interval '1 01:02:03.1' day to second as decimal(8, 1));
select cast(interval '10.123' second as decimal(4, 2));
select cast(interval '10.005' second as decimal(4, 2));
select cast(interval '10.123' second as decimal(5, 2));
select cast(interval '10.123' second as decimal(1, 0));
Contributor:
how about decimal(4, 2)? will we fail or truncate?

Member Author:
I added a check. The result is truncated. Please take a look.

68 changes: 68 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
@@ -838,3 +838,71 @@ struct<>
-- !query output
org.apache.spark.SparkArithmeticException
[CAST_OVERFLOW] The value INTERVAL '1000000' SECOND of the type "INTERVAL SECOND" cannot be cast to "SMALLINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.


-- !query
select cast(interval '-1' year as decimal(10, 0))
-- !query schema
struct<CAST(INTERVAL '-1' YEAR AS DECIMAL(10,0)):decimal(10,0)>
-- !query output
-1


-- !query
select cast(interval '1.000001' second as decimal(10, 6))
-- !query schema
struct<CAST(INTERVAL '01.000001' SECOND AS DECIMAL(10,6)):decimal(10,6)>
-- !query output
1.000001


-- !query
select cast(interval '08:11:10.001' hour to second as decimal(10, 4))
-- !query schema
struct<CAST(INTERVAL '08:11:10.001' HOUR TO SECOND AS DECIMAL(10,4)):decimal(10,4)>
-- !query output
29470.0010


-- !query
select cast(interval '1 01:02:03.1' day to second as decimal(8, 1))
-- !query schema
struct<CAST(INTERVAL '1 01:02:03.1' DAY TO SECOND AS DECIMAL(8,1)):decimal(8,1)>
-- !query output
90123.1


-- !query
select cast(interval '10.123' second as decimal(4, 2))
-- !query schema
struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(4,2)):decimal(4,2)>
-- !query output
10.12


-- !query
select cast(interval '10.005' second as decimal(4, 2))
-- !query schema
struct<CAST(INTERVAL '10.005' SECOND AS DECIMAL(4,2)):decimal(4,2)>
-- !query output
10.01
Contributor:
@srielau I think the rounding behavior is correct.



-- !query
select cast(interval '10.123' second as decimal(5, 2))
-- !query schema
struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(5,2)):decimal(5,2)>
-- !query output
10.12


-- !query
select cast(interval '10.123' second as decimal(1, 0))
-- !query schema
struct<>
-- !query output
org.apache.spark.SparkArithmeticException
[CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== SQL(line 1, position 8) ==
select cast(interval '10.123' second as decimal(1, 0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
65 changes: 65 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/cast.sql.out
@@ -666,3 +666,68 @@ struct<>
-- !query output
org.apache.spark.SparkArithmeticException
[CAST_OVERFLOW] The value INTERVAL '1000000' SECOND of the type "INTERVAL SECOND" cannot be cast to "SMALLINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.


-- !query
select cast(interval '-1' year as decimal(10, 0))
-- !query schema
struct<CAST(INTERVAL '-1' YEAR AS DECIMAL(10,0)):decimal(10,0)>
-- !query output
-1


-- !query
select cast(interval '1.000001' second as decimal(10, 6))
-- !query schema
struct<CAST(INTERVAL '01.000001' SECOND AS DECIMAL(10,6)):decimal(10,6)>
-- !query output
1.000001


-- !query
select cast(interval '08:11:10.001' hour to second as decimal(10, 4))
-- !query schema
struct<CAST(INTERVAL '08:11:10.001' HOUR TO SECOND AS DECIMAL(10,4)):decimal(10,4)>
-- !query output
29470.0010


-- !query
select cast(interval '1 01:02:03.1' day to second as decimal(8, 1))
-- !query schema
struct<CAST(INTERVAL '1 01:02:03.1' DAY TO SECOND AS DECIMAL(8,1)):decimal(8,1)>
-- !query output
90123.1


-- !query
select cast(interval '10.123' second as decimal(4, 2))
-- !query schema
struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(4,2)):decimal(4,2)>
-- !query output
10.12


-- !query
select cast(interval '10.005' second as decimal(4, 2))
-- !query schema
struct<CAST(INTERVAL '10.005' SECOND AS DECIMAL(4,2)):decimal(4,2)>
-- !query output
10.01


-- !query
select cast(interval '10.123' second as decimal(5, 2))
-- !query schema
struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(5,2)):decimal(5,2)>
-- !query output
10.12


-- !query
select cast(interval '10.123' second as decimal(1, 0))
-- !query schema
struct<>
-- !query output
org.apache.spark.SparkArithmeticException
[CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.