Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -519,15 +519,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
* 1. if both side are interval, stays the same;
* 2. else if one side is date and the other is interval,
* turns it to [[DateAddInterval]];
* 3. else if one side is interval, turns it to [[TimeAdd]];
* 3. else if one side is interval, turns it to [[TimestampAddInterval]];
* 4. else if one side is date, turns it to [[DateAdd]] ;
* 5. else stays the same.
*
* For [[Subtract]]:
* 1. if both side are interval, stays the same;
* 2. else if the left side is date and the right side is interval,
* turns it to [[DateAddInterval(l, -r)]];
* 3. else if the right side is an interval, turns it to [[TimeAdd(l, -r)]];
* 3. else if the right side is an interval, turns it to [[TimestampAddInterval(l, -r)]];
* 4. else if one side is timestamp, turns it to [[SubtractTimestamps]];
* 5. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]];
* 6. else if the left side is date, turns it to [[DateSub]];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{
Literal,
SubtractDates,
SubtractTimestamps,
TimeAdd,
TimestampAddInterval,
UnaryMinus,
UnaryPositive
}
Expand Down Expand Up @@ -77,7 +77,7 @@ object AnsiStringPromotionTypeCoercion {
s.copy(left = Cast(s.left, DateType))
case s @ SubtractDates(_, right @ StringTypeExpression(), _) =>
s.copy(right = Cast(s.right, DateType))
case t @ TimeAdd(left @ StringTypeExpression(), _, _) =>
case t @ TimestampAddInterval(left @ StringTypeExpression(), _, _) =>
t.copy(start = Cast(t.start, TimestampType))
case t @ SubtractTimestamps(left @ StringTypeExpression(), _, _, _) =>
t.copy(left = Cast(t.left, t.right.dataType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.{
Subtract,
SubtractDates,
SubtractTimestamps,
TimeAdd,
TimestampAddInterval,
TimestampAddYMInterval,
UnaryMinus
}
Expand All @@ -62,9 +62,9 @@ object BinaryArithmeticWithDatetimeResolver {
case a @ Add(l, r, mode) =>
(l.dataType, r.dataType) match {
case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r))
case (DateType, _: DayTimeIntervalType) => TimeAdd(Cast(l, TimestampType), r)
case (DateType, _: DayTimeIntervalType) => TimestampAddInterval(Cast(l, TimestampType), r)
case (DayTimeIntervalType(DAY, DAY), DateType) => DateAdd(r, ExtractANSIIntervalDays(l))
case (_: DayTimeIntervalType, DateType) => TimeAdd(Cast(r, TimestampType), l)
case (_: DayTimeIntervalType, DateType) => TimestampAddInterval(Cast(r, TimestampType), l)
case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(l, r)
case (_: YearMonthIntervalType, DateType) => DateAddYMInterval(r, l)
case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
Expand All @@ -80,10 +80,12 @@ object BinaryArithmeticWithDatetimeResolver {
a.copy(right = Cast(a.right, a.left.dataType))
case (DateType, CalendarIntervalType) =>
DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI)
case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType)
case (_, CalendarIntervalType | _: DayTimeIntervalType) =>
Cast(TimestampAddInterval(l, r), l.dataType)
case (CalendarIntervalType, DateType) =>
DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI)
case (CalendarIntervalType | _: DayTimeIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
case (CalendarIntervalType | _: DayTimeIntervalType, _) =>
Cast(TimestampAddInterval(r, l), r.dataType)
case (DateType, dt) if dt != StringType => DateAdd(l, r)
case (dt, DateType) if dt != StringType => DateAdd(r, l)
case _ => a
Expand All @@ -93,7 +95,8 @@ object BinaryArithmeticWithDatetimeResolver {
case (DateType, DayTimeIntervalType(DAY, DAY)) =>
DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI))
case (DateType, _: DayTimeIntervalType) =>
DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI)))
DatetimeSub(l, r,
TimestampAddInterval(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI)))
case (DateType, _: YearMonthIntervalType) =>
DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
Expand All @@ -116,7 +119,8 @@ object BinaryArithmeticWithDatetimeResolver {
)
)
case (_, CalendarIntervalType | _: DayTimeIntervalType) =>
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType)
Cast(DatetimeSub(l, r,
TimestampAddInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType)
case _
if AnyTimestampTypeExpression.unapply(l) ||
AnyTimestampTypeExpression.unapply(r) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
collect(left, negate) ++ collect(right, negate)
case Subtract(left, right, _) =>
collect(left, negate) ++ collect(right, !negate)
case TimeAdd(left, right, _) =>
case TimestampAddInterval(left, right, _) =>
collect(left, negate) ++ collect(right, negate)
case DatetimeSub(_, _, child) => collect(child, negate)
case UnaryMinus(child, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import org.apache.spark.sql.catalyst.expressions.{
SpecialFrameBoundary,
SpecifiedWindowFrame,
SubtractTimestamps,
TimeAdd,
TimestampAddInterval,
WindowSpecDefinition
}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
Expand Down Expand Up @@ -700,7 +700,8 @@ abstract class TypeCoercionHelper {
val newRight = castIfNotSameType(s.right, TimestampNTZType)
s.copy(left = newLeft, right = newRight)

case t @ TimeAdd(StringTypeExpression(), _, _) => t.copy(start = Cast(t.start, TimestampType))
case t @ TimestampAddInterval(StringTypeExpression(), _, _) =>
t.copy(start = Cast(t.start, TimestampType))

case other => other
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ import org.apache.spark.sql.types.{DateType, StringType}
*
* Cast(
* DatetimeSub(
* TimeAdd(
* TimestampAddInterval(
* Literal('4 11:11', StringType),
* UnaryMinus(
* Literal(Interval('4 22:12' DAY TO MINUTE), DayTimeIntervalType(0,2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class ExpressionResolver(
aggregateExpressionResolver,
binaryArithmeticResolver
)
private val timeAddResolver = new TimeAddResolver(this)
private val timestampAddResolver = new TimestampAddResolver(this)
private val unaryMinusResolver = new UnaryMinusResolver(this)
private val subqueryExpressionResolver = new SubqueryExpressionResolver(this, resolver)
private val ordinalResolver = new OrdinalResolver(this)
Expand Down Expand Up @@ -262,8 +262,8 @@ class ExpressionResolver(
subqueryExpressionResolver.resolveScalarSubquery(unresolvedScalarSubquery)
case unresolvedListQuery: ListQuery =>
subqueryExpressionResolver.resolveListQuery(unresolvedListQuery)
case unresolvedTimeAdd: TimeAdd =>
timeAddResolver.resolve(unresolvedTimeAdd)
case unresolvedTimestampAdd: TimestampAddInterval =>
timestampAddResolver.resolve(unresolvedTimestampAdd)
case unresolvedUnaryMinus: UnaryMinus =>
unaryMinusResolver.resolve(unresolvedUnaryMinus)
case createNamedStruct: CreateNamedStruct =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,39 @@ import org.apache.spark.sql.catalyst.analysis.{
StringPromotionTypeCoercion,
TypeCoercion
}
import org.apache.spark.sql.catalyst.expressions.{Expression, TimeAdd}
import org.apache.spark.sql.catalyst.expressions.{Expression, TimestampAddInterval}

/**
* Helper resolver for [[TimeAdd]] which is produced by resolving [[BinaryArithmetic]] nodes.
* Helper resolver for [[TimestampAddInterval]] which is produced by resolving [[BinaryArithmetic]]
* nodes.
*/
class TimeAddResolver(expressionResolver: ExpressionResolver)
extends TreeNodeResolver[TimeAdd, Expression]
class TimestampAddResolver(expressionResolver: ExpressionResolver)
extends TreeNodeResolver[TimestampAddInterval, Expression]
with ResolvesExpressionChildren
with CoercesExpressionTypes {

private val traversals = expressionResolver.getExpressionTreeTraversals

protected override val ansiTransformations: CoercesExpressionTypes.Transformations =
TimeAddResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS
TimestampAddResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS
protected override val nonAnsiTransformations: CoercesExpressionTypes.Transformations =
TimeAddResolver.TYPE_COERCION_TRANSFORMATIONS
TimestampAddResolver.TYPE_COERCION_TRANSFORMATIONS

override def resolve(unresolvedTimeAdd: TimeAdd): Expression = {
val timeAddWithResolvedChildren =
withResolvedChildren(unresolvedTimeAdd, expressionResolver.resolve _)
val timeAddWithTypeCoercion: Expression = coerceExpressionTypes(
expression = timeAddWithResolvedChildren,
override def resolve(unresolvedTimestampAdd: TimestampAddInterval): Expression = {
val timestampAddWithResolvedChildren =
withResolvedChildren(unresolvedTimestampAdd, expressionResolver.resolve _)
val timestampAddWithTypeCoercion: Expression = coerceExpressionTypes(
expression = timestampAddWithResolvedChildren,
expressionTreeTraversal = traversals.current
)
TimezoneAwareExpressionResolver.resolveTimezone(
timeAddWithTypeCoercion,
timestampAddWithTypeCoercion,
traversals.current.sessionLocalTimeZone
)
}
}

object TimeAddResolver {
object TimestampAddResolver {
// Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]].
private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq(
StringPromotionTypeCoercion.apply,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1647,7 +1647,10 @@ case class NextDay(
/**
* Adds an interval to timestamp.
*/
case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None)
case class TimestampAddInterval(
start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes {
override def nullIntolerant: Boolean = true

Expand Down Expand Up @@ -1690,7 +1693,7 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): TimeAdd =
newLeft: Expression, newRight: Expression): TimestampAddInterval =
copy(start = newLeft, interval = newRight)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ object SupportedBinaryExpr {
case _: BinaryArithmetic => Some(expr, expr.children.head, expr.children.last)
case _: BinaryMathExpression => Some(expr, expr.children.head, expr.children.last)
case _: AddMonths | _: DateAdd | _: DateAddInterval | _: DateDiff | _: DateSub |
_: DateAddYMInterval | _: TimestampAddYMInterval | _: TimeAdd =>
_: DateAddYMInterval | _: TimestampAddYMInterval | _: TimestampAddInterval =>
Some(expr, expr.children.head, expr.children.last)
case _: FindInSet | _: RoundBase => Some(expr, expr.children.head, expr.children.last)
case BinaryPredicate(expr) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{
GreaterThan,
Literal,
NamedExpression,
TimeAdd
TimestampAddInterval
}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.types.{
Expand Down Expand Up @@ -188,7 +188,7 @@ class ResolutionValidatorSuite extends SparkFunSuite with SQLConfHelper {
Project(
projectList = Seq(
Alias(
child = TimeAdd(
child = TimestampAddInterval(
start = Cast(
child = Literal("2024-10-01"),
dataType = TimestampType,
Expand All @@ -205,7 +205,7 @@ class ResolutionValidatorSuite extends SparkFunSuite with SQLConfHelper {
),
child = LocalRelation(output = colInteger)
),
error = Some("TimezoneId is not set for TimeAdd")
error = Some("TimezoneId is not set for TimestampAddInterval")
)
}

Expand Down
Loading