Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,12 @@ object TypeCoercion {
SubtractTimestamps(l, Cast(r, TimestampType))
case Subtract(l @ DateType(), r @ TimestampType()) =>
SubtractTimestamps(Cast(l, TimestampType), r)
case Divide(l @ CalendarIntervalType(), r) => IntervalDivide(l, Cast(r, DecimalType(28, 9)))
case Multiply(l @ CalendarIntervalType(), r) =>
IntervalMultiply(l, Cast(r, DecimalType(28, 9)))
case Multiply(l, r @ CalendarIntervalType()) =>
IntervalMultiply(r, Cast(l, DecimalType(28, 9)))

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -2164,3 +2164,66 @@ case class SubtractDates(left: Expression, right: Expression)
}
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "expr1 _FUNC_ expr2 - Divide interval value `expr1` by `expr2`. It returns NULL if `expr2` is 0 or NULL.",
examples = """
Examples:
> SELECT interval '1 year 2 month' / 3.0;
interval 4 months 2 weeks 6 days
""",
since = "3.0.0")
// scalastyle:on line.size.limit
case class IntervalDivide(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def dataType: DataType = CalendarIntervalType

override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType, DecimalType)

override def nullSafeEval(interval: Any, divisor: Any): Any = {
IntervalUtils.divide(interval.asInstanceOf[CalendarInterval],
divisor.asInstanceOf[Decimal])
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (interval, divisor) => {
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
s"$iu.divide($interval, $divisor)"
})
}
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "expr1 _FUNC_ expr2 - Multiply interval value `expr1` by `expr2`. It returns NULL if `expr2` is 0 or NULL.",
examples = """
Examples:
> SELECT interval '4 months 2 weeks 6 days' * 3.0;
interval 1 years 8 weeks 4 days
> SELECT 3.0 * interval '4 months 2 weeks 6 days';
interval 1 years 8 weeks 4 days
""",
since = "3.0.0")
// scalastyle:on line.size.limit
case class IntervalMultiply(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def dataType: DataType = CalendarIntervalType

override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType, DecimalType)

override def nullSafeEval(interval: Any, divisor: Any): Any = {
IntervalUtils.multiply(interval.asInstanceOf[CalendarInterval],
divisor.asInstanceOf[Decimal])
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (interval, divisor) => {
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
s"$iu.multiply($interval, $divisor)"
})
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,22 @@ object IntervalUtils {
Decimal(result, 18, 6)
}

def divide(interval: CalendarInterval, divisor: Decimal): CalendarInterval = {
if (divisor == Decimal.ZERO || divisor == null) return null
val months = Decimal(interval.months) / divisor
val milliseconds = (Decimal(interval.microseconds) / divisor +
months.remainder(Decimal.ONE) * Decimal(MICROS_PER_MONTH)).toLong
new CalendarInterval(months.toInt, milliseconds.toLong)
}

def multiply(interval: CalendarInterval, multiplier: Decimal): CalendarInterval = {
if (multiplier == null) return null
val months = Decimal(interval.months) * multiplier
val milliseconds = (Decimal(interval.microseconds) * multiplier +
months.remainder(Decimal.ONE) * Decimal(MICROS_PER_MONTH)).toLong
new CalendarInterval(months.toInt, milliseconds.toLong)
}

/**
* Converts a string to [[CalendarInterval]] case-insensitively.
*
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/datetime.sql
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,14 @@ select date '2001-10-01' - 7;
select date '2001-10-01' - date '2001-09-28';
select date'2020-01-01' - timestamp'2019-10-06 10:11:12.345678';
select timestamp'2019-10-06 10:11:12.345678' - date'2020-01-01';

select interval '1 year 2 month' / null;
select interval '1 year 2 month' / 0;
select interval '1 year 2 month' / 3;
select interval '1 year 2 month' / 3.0;

SELECT interval '4 months 2 weeks 6 days' * null;
SELECT interval '4 months 2 weeks 6 days' * 0;
SELECT interval '4 months 2 weeks 6 days' * 3;
SELECT interval '4 months 2 weeks 6 days' * 3.0;
SELECT 3.0 * interval '4 months 2 weeks 6 days';
74 changes: 73 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/datetime.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 17
-- Number of queries: 26


-- !query 0
Expand Down Expand Up @@ -145,3 +145,75 @@ select timestamp'2019-10-06 10:11:12.345678' - date'2020-01-01'
struct<subtracttimestamps(TIMESTAMP('2019-10-06 10:11:12.345678'), CAST(DATE '2020-01-01' AS TIMESTAMP)):interval>
-- !query 16 output
interval -12 weeks -2 days -14 hours -48 minutes -47 seconds -654 milliseconds -322 microseconds


-- !query 17
select interval '1 year 2 month' / null
-- !query 17 schema
struct<intervaldivide(interval 1 years 2 months, CAST(NULL AS DECIMAL(28,9))):interval>
-- !query 17 output
NULL


-- !query 18
select interval '1 year 2 month' / 0
-- !query 18 schema
struct<intervaldivide(interval 1 years 2 months, CAST(0 AS DECIMAL(28,9))):interval>
-- !query 18 output
NULL


-- !query 19
select interval '1 year 2 month' / 3
-- !query 19 schema
struct<intervaldivide(interval 1 years 2 months, CAST(3 AS DECIMAL(28,9))):interval>
-- !query 19 output
interval 4 months 2 weeks 6 days


-- !query 20
select interval '1 year 2 month' / 3.0
-- !query 20 schema
struct<intervaldivide(interval 1 years 2 months, CAST(3.0 AS DECIMAL(28,9))):interval>
-- !query 20 output
interval 4 months 2 weeks 6 days


-- !query 21
SELECT interval '4 months 2 weeks 6 days' * null
-- !query 21 schema
struct<intervalmultiply(interval 4 months 2 weeks 6 days, CAST(NULL AS DECIMAL(28,9))):interval>
-- !query 21 output
NULL


-- !query 22
SELECT interval '4 months 2 weeks 6 days' * 0
-- !query 22 schema
struct<intervalmultiply(interval 4 months 2 weeks 6 days, CAST(0 AS DECIMAL(28,9))):interval>
-- !query 22 output
interval 0 microseconds


-- !query 23
SELECT interval '4 months 2 weeks 6 days' * 3
-- !query 23 schema
struct<intervalmultiply(interval 4 months 2 weeks 6 days, CAST(3 AS DECIMAL(28,9))):interval>
-- !query 23 output
interval 1 years 8 weeks 4 days


-- !query 24
SELECT interval '4 months 2 weeks 6 days' * 3.0
-- !query 24 schema
struct<intervalmultiply(interval 4 months 2 weeks 6 days, CAST(3.0 AS DECIMAL(28,9))):interval>
-- !query 24 output
interval 1 years 8 weeks 4 days


-- !query 25
SELECT 3.0 * interval '4 months 2 weeks 6 days'
-- !query 25 schema
struct<intervalmultiply(interval 4 months 2 weeks 6 days, CAST(3.0 AS DECIMAL(28,9))):interval>
-- !query 25 output
interval 1 years 8 weeks 4 days