
Commit ebc8fa5

yaooqinn authored and cloud-fan committed
[SPARK-31527][SQL] date add/subtract interval only allows day-precision intervals in ANSI mode
### What changes were proposed in this pull request?

To follow ANSI, the expressions `date + interval`, `interval + date` and `date - interval` should only accept intervals whose `microseconds` part is 0.

### Why are the changes needed?

Better ANSI compliance.

### Does this PR introduce any user-facing change?

No, this PR targets 3.0.0, in which this feature is newly added.

### How was this patch tested?

Added more unit tests.

Closes #28310 from yaooqinn/SPARK-31527.

Authored-by: Kent Yao <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
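For a concrete picture of the change, here is a sketch of a `spark-shell` session (illustrative only: the config key `spark.sql.ansi.enabled` is assumed to be the key behind `SQLConf.ANSI_ENABLED` used in the tests below, and the error text is the one added in `DateTimeUtils`):

```scala
// Assumed session sketch, not output captured from this patch's CI.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT date'2016-02-28' + interval 1 day").show()   // 2016-02-29: still allowed
spark.sql("SELECT date'2016-02-28' + interval 1 hour").show()
// java.lang.IllegalArgumentException: requirement failed:
// Cannot add hours, minutes or seconds, milliseconds, microseconds to a date

spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT date'2016-02-28' + interval 1 hour").show()
// Falls back to timestamp arithmetic, then truncates back to a date: 2016-02-28
```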
1 parent a911287 commit ebc8fa5

File tree

12 files changed: +888 −21 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 17 additions & 10 deletions
```diff
@@ -246,7 +246,7 @@ class Analyzer(
       ResolveLambdaVariables(conf) ::
       ResolveTimeZone(conf) ::
       ResolveRandomSeed ::
-      ResolveBinaryArithmetic(conf) ::
+      ResolveBinaryArithmetic ::
       TypeCoercion.typeCoercionRules(conf) ++
       extendedResolutionRules : _*),
     Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -268,17 +268,21 @@ class Analyzer(
   /**
    * For [[Add]]:
    * 1. if both side are interval, stays the same;
-   * 2. else if one side is interval, turns it to [[TimeAdd]];
-   * 3. else if one side is date, turns it to [[DateAdd]] ;
-   * 4. else stays the same.
+   * 2. else if one side is date and the other is interval,
+   *    turns it to [[DateAddInterval]];
+   * 3. else if one side is interval, turns it to [[TimeAdd]];
+   * 4. else if one side is date, turns it to [[DateAdd]];
+   * 5. else stays the same.
    *
    * For [[Subtract]]:
    * 1. if both side are interval, stays the same;
-   * 2. else if the right side is an interval, turns it to [[TimeSub]];
-   * 3. else if one side is timestamp, turns it to [[SubtractTimestamps]];
-   * 4. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]];
-   * 5. else if the left side is date, turns it to [[DateSub]];
-   * 6. else turns it to stays the same.
+   * 2. else if the left side is date and the right side is interval,
+   *    turns it to [[DateAddInterval(l, -r)]];
+   * 3. else if the right side is an interval, turns it to [[TimeSub]];
+   * 4. else if one side is timestamp, turns it to [[SubtractTimestamps]];
+   * 5. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]];
+   * 6. else if the left side is date, turns it to [[DateSub]];
+   * 7. else stays the same.
    *
    * For [[Multiply]]:
    * 1. If one side is interval, turns it to [[MultiplyInterval]];
@@ -288,19 +292,22 @@ class Analyzer(
    * 1. If the left side is interval, turns it to [[DivideInterval]];
    * 2. otherwise, stays the same.
    */
-  case class ResolveBinaryArithmetic(conf: SQLConf) extends Rule[LogicalPlan] {
+  object ResolveBinaryArithmetic extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p: LogicalPlan => p.transformExpressionsUp {
         case a @ Add(l, r) if a.childrenResolved => (l.dataType, r.dataType) match {
           case (CalendarIntervalType, CalendarIntervalType) => a
+          case (DateType, CalendarIntervalType) => DateAddInterval(l, r)
           case (_, CalendarIntervalType) => Cast(TimeAdd(l, r), l.dataType)
+          case (CalendarIntervalType, DateType) => DateAddInterval(r, l)
           case (CalendarIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
           case (DateType, dt) if dt != StringType => DateAdd(l, r)
           case (dt, DateType) if dt != StringType => DateAdd(r, l)
           case _ => a
         }
         case s @ Subtract(l, r) if s.childrenResolved => (l.dataType, r.dataType) match {
           case (CalendarIntervalType, CalendarIntervalType) => s
+          case (DateType, CalendarIntervalType) => DateAddInterval(l, UnaryMinus(r))
           case (_, CalendarIntervalType) => Cast(TimeSub(l, r), l.dataType)
           case (TimestampType, _) => SubtractTimestamps(l, r)
           case (_, TimestampType) => SubtractTimestamps(l, r)
```
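Note the ordering of the new cases: `(DateType, CalendarIntervalType)` must be matched before the generic `(_, CalendarIntervalType)` fallback, otherwise dates would keep resolving to `TimeAdd`. A toy standalone sketch of that dispatch, with stand-in types rather than the real catalyst classes and without the `StringType` guards:

```scala
sealed trait DT
case object DateT extends DT
case object IntervalT extends DT
case object OtherT extends DT

// Mirrors the match order in ResolveBinaryArithmetic for Add.
def resolveAdd(l: DT, r: DT): String = (l, r) match {
  case (IntervalT, IntervalT) => "Add"                   // 1. interval + interval: unchanged
  case (DateT, IntervalT)     => "DateAddInterval(l, r)" // 2. new in this patch
  case (IntervalT, DateT)     => "DateAddInterval(r, l)"
  case (_, IntervalT)         => "Cast(TimeAdd(l, r))"   // 3. timestamp-style add
  case (DateT, _)             => "DateAdd(l, r)"         // 4. date + integral days
  case _                      => "Add"                   // 5. unchanged
}

assert(resolveAdd(DateT, IntervalT) == "DateAddInterval(l, r)") // no longer the TimeAdd path
```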

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 63 additions & 0 deletions
```diff
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, Tim
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -1196,6 +1197,68 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
   }
 }
 
+/**
+ * Adds a date and an interval.
+ *
+ * When ANSI mode is on, the microseconds part of the interval needs to be 0, otherwise a runtime
+ * [[IllegalArgumentException]] will be raised.
+ * When ANSI mode is off, if the microseconds part of the interval is 0, we perform date + interval
+ * for better performance. If the microseconds part is not 0, the date will be converted to a
+ * timestamp to add the whole interval.
+ */
+case class DateAddInterval(
+    start: Expression,
+    interval: Expression,
+    timeZoneId: Option[String] = None,
+    ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
+  extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression {
+
+  override def left: Expression = start
+  override def right: Expression = interval
+
+  override def toString: String = s"$left + $right"
+  override def sql: String = s"${left.sql} + ${right.sql}"
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, CalendarIntervalType)
+
+  override def dataType: DataType = DateType
+
+  override def nullSafeEval(start: Any, interval: Any): Any = {
+    val itvl = interval.asInstanceOf[CalendarInterval]
+    if (ansiEnabled || itvl.microseconds == 0) {
+      DateTimeUtils.dateAddInterval(start.asInstanceOf[Int], itvl)
+    } else {
+      val startTs = DateTimeUtils.epochDaysToMicros(start.asInstanceOf[Int], zoneId)
+      val resultTs = DateTimeUtils.timestampAddInterval(
+        startTs, itvl.months, itvl.days, itvl.microseconds, zoneId)
+      DateTimeUtils.microsToEpochDays(resultTs, zoneId)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+    nullSafeCodeGen(ctx, ev, (sd, i) => if (ansiEnabled) {
+      s"""${ev.value} = $dtu.dateAddInterval($sd, $i);"""
+    } else {
+      val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
+      val startTs = ctx.freshName("startTs")
+      val resultTs = ctx.freshName("resultTs")
+      s"""
+         |if ($i.microseconds == 0) {
+         |  ${ev.value} = $dtu.dateAddInterval($sd, $i);
+         |} else {
+         |  long $startTs = $dtu.epochDaysToMicros($sd, $zid);
+         |  long $resultTs =
+         |    $dtu.timestampAddInterval($startTs, $i.months, $i.days, $i.microseconds, $zid);
+         |  ${ev.value} = $dtu.microsToEpochDays($resultTs, $zid);
+         |}
+         |""".stripMargin
+    })
+  }
+
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
+}
+
 /**
  * This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function
  * takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and
```
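The interpreted and generated code share the same branch structure. The following is a self-contained sketch of that logic using `java.time`, with a simplified stand-in for `CalendarInterval`; the real code routes through `DateTimeUtils` and Spark's microsecond timestamp representation, so edge-case behavior may differ slightly:

```scala
import java.time.{LocalDate, ZoneId}

// Simplified stand-in for org.apache.spark.unsafe.types.CalendarInterval.
final case class Interval(months: Int, days: Int, microseconds: Long)

def dateAddIntervalSketch(epochDays: Int, itvl: Interval, zone: ZoneId, ansi: Boolean): Int =
  if (ansi || itvl.microseconds == 0) {
    // Day-precision path; in ANSI mode any sub-day component is rejected here.
    require(itvl.microseconds == 0,
      "Cannot add hours, minutes or seconds, milliseconds, microseconds to a date")
    LocalDate.ofEpochDay(epochDays.toLong)
      .plusMonths(itvl.months.toLong)
      .plusDays(itvl.days.toLong)
      .toEpochDay.toInt
  } else {
    // Non-ANSI fallback: promote the date to a timestamp at the session zone,
    // add the whole interval, then truncate back to a local date.
    LocalDate.ofEpochDay(epochDays.toLong).atStartOfDay(zone)
      .plusMonths(itvl.months.toLong)
      .plusDays(itvl.days.toLong)
      .plusNanos(itvl.microseconds * 1000L)
      .toLocalDate.toEpochDay.toInt
  }

// 1970-01-01 + 25 hours lands on 1970-01-02 in non-ANSI mode:
// dateAddIntervalSketch(0, Interval(0, 0, 25L * 3600 * 1000000),
//   ZoneId.of("UTC"), ansi = false) == 1
```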

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 16 additions & 0 deletions
```diff
@@ -618,6 +618,22 @@ object DateTimeUtils {
     instantToMicros(resultTimestamp.toInstant)
   }
 
+  /**
+   * Adds the date and the interval's months and days.
+   * Returns a date value, expressed in days since 1.1.1970.
+   *
+   * @throws DateTimeException if the result exceeds the supported date range
+   * @throws IllegalArgumentException if the interval has a `microseconds` part
+   */
+  def dateAddInterval(
+      start: SQLDate,
+      interval: CalendarInterval): SQLDate = {
+    require(interval.microseconds == 0,
+      "Cannot add hours, minutes or seconds, milliseconds, microseconds to a date")
+    val ld = LocalDate.ofEpochDay(start).plusMonths(interval.months).plusDays(interval.days)
+    localDateToDays(ld)
+  }
+
   /**
    * Returns number of months between time1 and time2. time1 and time2 are expressed in
    * microseconds since 1.1.1970. If time1 is later than time2, the result is positive.
```
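For intuition, the helper is plain `LocalDate` arithmetic behind a day-precision guard. A quick REPL-style check with the same values the new `DateTimeUtilsSuite` test uses further down (epoch-day encoding assumed for `SQLDate`):

```scala
import java.time.LocalDate

val input = LocalDate.of(1997, 2, 28) // the suite's days(1997, 2, 28, 10, 30), date part only

// dateAddInterval(input, new CalendarInterval(36, 0, 0)) === days(2000, 2, 28)
assert(input.plusMonths(36) == LocalDate.of(2000, 2, 28))

// dateAddInterval(input, new CalendarInterval(36, 47, 0)) === days(2000, 4, 15)
assert(input.plusMonths(36).plusDays(47) == LocalDate.of(2000, 4, 15))

// dateAddInterval(input, new CalendarInterval(-13, 0, 0)) === days(1996, 1, 28)
assert(input.plusMonths(-13) == LocalDate.of(1996, 1, 28))
```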

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 35 additions & 0 deletions
```diff
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkFunSuite, SparkUpgradeException}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -358,6 +359,40 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(DateAdd, DateType, IntegerType)
   }
 
+  test("date add interval") {
+    val d = Date.valueOf("2016-02-28")
+    Seq("true", "false") foreach { flag =>
+      withSQLConf((SQLConf.ANSI_ENABLED.key, flag)) {
+        checkEvaluation(
+          DateAddInterval(Literal(d), Literal(new CalendarInterval(0, 1, 0))),
+          DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29")))
+        checkEvaluation(
+          DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 0))),
+          DateTimeUtils.fromJavaDate(Date.valueOf("2016-03-29")))
+        checkEvaluation(DateAddInterval(Literal(d), Literal.create(null, CalendarIntervalType)),
+          null)
+        checkEvaluation(DateAddInterval(Literal.create(null, DateType),
+          Literal(new CalendarInterval(1, 1, 0))),
+          null)
+      }
+    }
+
+    withSQLConf((SQLConf.ANSI_ENABLED.key, "true")) {
+      checkExceptionInExpression[IllegalArgumentException](
+        DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25 * MICROS_PER_HOUR))),
+        "Cannot add hours, minutes or seconds, milliseconds, microseconds to a date")
+    }
+
+    withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
+      checkEvaluation(
+        DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25))),
+        DateTimeUtils.fromJavaDate(Date.valueOf("2016-03-29")))
+      checkEvaluation(
+        DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25 * MICROS_PER_HOUR))),
+        DateTimeUtils.fromJavaDate(Date.valueOf("2016-03-30")))
+    }
+  }
+
   test("date_sub") {
     checkEvaluation(
       DateSub(Literal(Date.valueOf("2015-01-01")), Literal(1.toByte)),
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala

Lines changed: 4 additions & 0 deletions
```diff
@@ -176,5 +176,9 @@ class ExpressionSQLBuilderSuite extends SparkFunSuite {
       TimeSub('a, interval),
       "`a` - INTERVAL '1 hours'"
     )
+    checkSQL(
+      DateAddInterval('a, interval),
+      "`a` + INTERVAL '1 hours'"
+    )
   }
 }
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala

Lines changed: 9 additions & 1 deletion
```diff
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
 
@@ -391,6 +391,14 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
     assert(dateAddMonths(input, -13) === days(1996, 1, 28))
   }
 
+  test("date add interval with day precision") {
+    val input = days(1997, 2, 28, 10, 30)
+    assert(dateAddInterval(input, new CalendarInterval(36, 0, 0)) === days(2000, 2, 28))
+    assert(dateAddInterval(input, new CalendarInterval(36, 47, 0)) === days(2000, 4, 15))
+    assert(dateAddInterval(input, new CalendarInterval(-13, 0, 0)) === days(1996, 1, 28))
+    intercept[IllegalArgumentException](dateAddInterval(input, new CalendarInterval(36, 47, 1)))
+  }
+
   test("timestamp add months") {
     val ts1 = date(1997, 2, 28, 10, 30, 0)
     val ts2 = date(2000, 2, 28, 10, 30, 0, 123000)
```
Lines changed: 1 addition & 0 deletions

```diff
@@ -0,0 +1 @@
+--IMPORT datetime.sql
```
