@@ -63,7 +63,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.DAY
 
 object BinaryArithmeticWithDatetimeResolver {
   def resolve(expr: Expression): Expression = expr match {
-    case a @ Add(l, r, mode) =>
+    case a @ Add(l, r, context) =>
       (l.dataType, r.dataType) match {
         case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r))
         case (DateType, _: DayTimeIntervalType) => TimestampAddInterval(Cast(l, TimestampType), r)
@@ -83,7 +83,7 @@ object BinaryArithmeticWithDatetimeResolver {
         case (_: AnsiIntervalType, _: NullType) =>
           a.copy(right = Cast(a.right, a.left.dataType))
         case (DateType, CalendarIntervalType) =>
-          DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI)
+          DateAddInterval(l, r, ansiEnabled = context.evalMode == EvalMode.ANSI)
         case (_: TimeType, _: DayTimeIntervalType) => TimeAddInterval(l, r)
         case (_: DatetimeType, _: NullType) =>
           a.copy(right = Cast(a.right, DayTimeIntervalType.DEFAULT))
@@ -93,24 +93,26 @@ object BinaryArithmeticWithDatetimeResolver {
         case (_, CalendarIntervalType | _: DayTimeIntervalType) =>
           Cast(TimestampAddInterval(l, r), l.dataType)
         case (CalendarIntervalType, DateType) =>
-          DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI)
+          DateAddInterval(r, l, ansiEnabled = context.evalMode == EvalMode.ANSI)
         case (CalendarIntervalType | _: DayTimeIntervalType, _) =>
           Cast(TimestampAddInterval(r, l), r.dataType)
         case (DateType, dt) if dt != StringType => DateAdd(l, r)
         case (dt, DateType) if dt != StringType => DateAdd(r, l)
         case _ => a
       }
-    case s @ Subtract(l, r, mode) =>
+    case s @ Subtract(l, r, context) =>
       (l.dataType, r.dataType) match {
         case (DateType, DayTimeIntervalType(DAY, DAY)) =>
-          DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI))
+          DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), context.evalMode == EvalMode.ANSI))
         case (DateType, _: DayTimeIntervalType) =>
           DatetimeSub(l, r,
-            TimestampAddInterval(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI)))
+            TimestampAddInterval(Cast(l, TimestampType),
+              UnaryMinus(r, context.evalMode == EvalMode.ANSI)))
         case (DateType, _: YearMonthIntervalType) =>
-          DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
+          DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, context.evalMode == EvalMode.ANSI)))
         case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
-          DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
+          DatetimeSub(l, r, TimestampAddYMInterval(l,
+            UnaryMinus(r, context.evalMode == EvalMode.ANSI)))
         case (CalendarIntervalType, CalendarIntervalType) |
             (_: DayTimeIntervalType, _: DayTimeIntervalType) =>
           s
@@ -124,15 +126,15 @@ object BinaryArithmeticWithDatetimeResolver {
             r,
             DateAddInterval(
               l,
-              UnaryMinus(r, mode == EvalMode.ANSI),
-              ansiEnabled = mode == EvalMode.ANSI
+              UnaryMinus(r, context.evalMode == EvalMode.ANSI),
+              ansiEnabled = context.evalMode == EvalMode.ANSI
             )
           )
         case (_: TimeType, _: DayTimeIntervalType) =>
-          DatetimeSub(l, r, TimeAddInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
+          DatetimeSub(l, r, TimeAddInterval(l, UnaryMinus(r, context.evalMode == EvalMode.ANSI)))
         case (_, CalendarIntervalType | _: DayTimeIntervalType) =>
           Cast(DatetimeSub(l, r,
-            TimestampAddInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType)
+            TimestampAddInterval(l, UnaryMinus(r, context.evalMode == EvalMode.ANSI))), l.dataType)
         case _
             if AnyTimestampTypeExpression.unapply(l) ||
               AnyTimestampTypeExpression.unapply(r) =>
@@ -142,19 +144,19 @@ object BinaryArithmeticWithDatetimeResolver {
         case (_: TimeType, _: TimeType) => SubtractTimes(l, r)
         case _ => s
       }
-    case m @ Multiply(l, r, mode) =>
+    case m @ Multiply(l, r, context) =>
       (l.dataType, r.dataType) match {
-        case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == EvalMode.ANSI)
-        case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == EvalMode.ANSI)
+        case (CalendarIntervalType, _) => MultiplyInterval(l, r, context.evalMode == EvalMode.ANSI)
+        case (_, CalendarIntervalType) => MultiplyInterval(r, l, context.evalMode == EvalMode.ANSI)
         case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r)
         case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l)
         case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r)
         case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l)
         case _ => m
       }
-    case d @ Divide(l, r, mode) =>
+    case d @ Divide(l, r, context) =>
       (l.dataType, r.dataType) match {
-        case (CalendarIntervalType, _) => DivideInterval(l, r, mode == EvalMode.ANSI)
+        case (CalendarIntervalType, _) => DivideInterval(l, r, context.evalMode == EvalMode.ANSI)
         case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r)
         case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r)
         case _ => d
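The rename in this resolver is mechanical: every ANSI check that previously read the bare mode parameter now reads context.evalMode from the pattern-bound NumericEvalContext. For orientation, queries of roughly this shape exercise the rewrites above; a spark-shell sketch, with the exact resolved plan depending on session settings:

// DATE + day-time interval hits the (DateType, DayTimeIntervalType(DAY, DAY)) case
// and is rewritten to DateAdd(l, ExtractANSIIntervalDays(r)).
spark.sql("SELECT DATE'2024-01-01' + INTERVAL '3' DAY").show()

// DATE + calendar interval (make_interval produces CalendarIntervalType) is rewritten
// to DateAddInterval, whose ansiEnabled flag is now derived from context.evalMode.
spark.sql("SELECT DATE'2024-01-01' + make_interval(0, 1, 0, 3, 0, 0, 0)").show()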
@@ -1423,13 +1423,13 @@ trait CommutativeExpression extends Expression {
   protected def buildCanonicalizedPlan(
       collectOperands: PartialFunction[Expression, Seq[Expression]],
       buildBinaryOp: (Expression, Expression) => Expression,
-      evalMode: Option[EvalMode.Value] = None): Expression = {
+      evalContext: Option[NumericEvalContext] = None): Expression = {
     val operands = orderCommutative(collectOperands)
     val reorderResult =
       if (operands.length < SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)) {
         operands.reduce(buildBinaryOp)
       } else {
-        MultiCommutativeOp(operands, this.getClass, evalMode)(this)
+        MultiCommutativeOp(operands, this.getClass, evalContext)(this)
       }
     reorderResult
   }
@@ -1446,15 +1446,15 @@ trait CommutativeExpression extends Expression {
  * Add, Multiply, And, Or, BitwiseAnd, BitwiseOr, BitwiseXor.
  * @param operands A sequence of operands that produces a commutative expression tree.
  * @param opCls The class of the root operator of the expression tree.
- * @param evalMode The optional expression evaluation mode.
+ * @param evalContext The optional expression evaluation context.
  * @param originalRoot Root operator of the commutative expression tree before canonicalization.
  *                     This object reference is used to deduce the return dataType of Add and
  *                     Multiply operations when the input datatype is decimal.
  */
 case class MultiCommutativeOp(
     operands: Seq[Expression],
     opCls: Class[_],
-    evalMode: Option[EvalMode.Value])(originalRoot: Expression) extends Unevaluable {
+    evalContext: Option[NumericEvalContext])(originalRoot: Expression) extends Unevaluable {
   // Helper method to deduce the data type of a single operation.
   private def singleOpDataType(lType: DataType, rType: DataType): DataType = {
     originalRoot match {
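buildCanonicalizedPlan flattens a chain of one commutative operator into its operand list, orders it, and either re-reduces the operands pairwise or, past MULTI_COMMUTATIVE_OP_OPT_THRESHOLD, wraps them in MultiCommutativeOp, which now carries the optional NumericEvalContext instead of a bare eval mode. A REPL-ready standalone sketch of the reassociation idea (toy types, not Spark's own):

sealed trait Expr
case class Leaf(name: String) extends Expr
case class AddOp(left: Expr, right: Expr) extends Expr

// Flatten a commutative AddOp chain into its leaf operands.
def collectOperands(e: Expr): Seq[Expr] = e match {
  case AddOp(l, r) => collectOperands(l) ++ collectOperands(r)
  case leaf => Seq(leaf)
}

// Order the operands canonically and rebuild a single reduced tree.
def canonicalize(e: Expr): Expr =
  collectOperands(e).sortBy(_.toString).reduce(AddOp.apply)

// Differently associated trees now canonicalize identically, which is what lets
// semantically equal commutative expressions compare equal after canonicalization.
val t1 = AddOp(Leaf("b"), AddOp(Leaf("c"), Leaf("a")))
val t2 = AddOp(AddOp(Leaf("a"), Leaf("b")), Leaf("c"))
assert(canonicalize(t1) == canonicalize(t2))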
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Encapsulates the evaluation context for expressions, capturing SQL configuration
+ * state at expression construction time.
+ *
+ * This context must be stored as part of the expression's state to ensure deterministic
+ * evaluation. Without it, copying an expression or evaluating it in a different context
+ * (e.g., inside a view) could produce different results due to changed SQL configuration
+ * values.
+ *
+ * @param evalMode The error handling mode (LEGACY, ANSI, or TRY) that determines
+ *                 overflow behavior and exception handling for operations like
+ *                 arithmetic and casts.
+ * @param allowDecimalPrecisionLoss Whether decimal operations are allowed to lose precision
+ *                                  when the result type cannot represent the full precision.
+ *                                  Corresponds to
+ *                                  spark.sql.decimalOperations.allowPrecisionLoss.
+ */
+case class NumericEvalContext private(
+    evalMode: EvalMode.Value,
+    allowDecimalPrecisionLoss: Boolean
+)
+
+case object NumericEvalContext {
+
+  def apply(
+      evalMode: EvalMode.Value,
+      allowDecimalPrecisionLoss: Boolean = SQLConf.get.decimalOperationsAllowPrecisionLoss
+  ): NumericEvalContext = {
+    new NumericEvalContext(evalMode, allowDecimalPrecisionLoss)
+  }
+
+  def fromSQLConf(conf: SQLConf): NumericEvalContext = {
+    NumericEvalContext(
+      EvalMode.fromSQLConf(conf),
+      conf.decimalOperationsAllowPrecisionLoss)
+  }
+}
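Both constructors added here appear throughout the change: Sum captures a context from the session by default, while try_sum pins EvalMode.TRY. A short sketch using only identifiers from this diff (someChildExpr stands in for a hypothetical child expression):

// Capture both fields from the active session configuration:
val sessionCtx = NumericEvalContext.fromSQLConf(SQLConf.get)

// Pin the eval mode; allowDecimalPrecisionLoss still defaults from SQLConf.get
// through the companion's apply:
val tryCtx = NumericEvalContext(EvalMode.TRY)

// The context is stored on the expression, so copies of this Sum, or evaluation
// under a different session config (e.g. inside a view), reuse the captured values:
val trySum = Sum(someChildExpr, tryCtx)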
@@ -43,19 +43,19 @@ import org.apache.spark.sql.types._
   since = "1.0.0")
 case class Sum(
     child: Expression,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get))
+    evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get))
   extends DeclarativeAggregate
   with ImplicitCastInputTypes
   with UnaryLike[Expression]
   with SupportQueryContext {
 
-  def this(child: Expression) = this(child, EvalMode.fromSQLConf(SQLConf.get))
+  def this(child: Expression) = this(child, NumericEvalContext.fromSQLConf(SQLConf.get))
 
   private def shouldTrackIsEmpty: Boolean = resultType match {
     case _: DecimalType => true
     // For try_sum(), the result of following data types can be null on overflow.
     // Thus we need additional buffer to keep track of whether overflow happens.
-    case _: IntegralType | _: AnsiIntervalType if evalMode == EvalMode.TRY => true
+    case _: IntegralType | _: AnsiIntervalType if evalContext.evalMode == EvalMode.TRY => true
     case _ => false
   }
 
@@ -89,7 +89,7 @@ case class Sum(
 
   private def add(left: Expression, right: Expression): Expression = left.dataType match {
     case _: DecimalType => DecimalAddNoOverflowCheck(left, right, left.dataType)
-    case _ => Add(left, right, evalMode)
+    case _ => Add(left, right, evalContext)
   }
 
   override lazy val aggBufferAttributes = if (shouldTrackIsEmpty) {
@@ -176,7 +176,7 @@ case class Sum(
     resultType match {
       case d: DecimalType =>
         val checkOverflowInSum =
-          CheckOverflowInSum(sum, d, evalMode != EvalMode.ANSI, getContextOrNull())
+          CheckOverflowInSum(sum, d, evalContext.evalMode != EvalMode.ANSI, getContextOrNull())
         If(isEmpty, Literal.create(null, resultType), checkOverflowInSum)
       case _ if shouldTrackIsEmpty =>
         If(isEmpty, Literal.create(null, resultType), sum)
@@ -187,11 +187,12 @@ case class Sum(
   // The flag `evalMode` won't be shown in the `toString` or `toAggString` methods
   override def flatArguments: Iterator[Any] = Iterator(child)
 
-  override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) {
-    Some(origin.context)
-  } else {
-    None
-  }
+  override def initQueryContext(): Option[QueryContext] =
+    if (evalContext.evalMode == EvalMode.ANSI) {
+      Some(origin.context)
+    } else {
+      None
+    }
 
   override protected def withNewChildInternal(newChild: Expression): Expression =
     copy(child = newChild)
@@ -218,7 +219,7 @@ object TrySumExpressionBuilder extends ExpressionBuilder {
   override def build(funcName: String, expressions: Seq[Expression]): Expression = {
     val numArgs = expressions.length
     if (numArgs == 1) {
-      Sum(expressions.head, EvalMode.TRY)
+      Sum(expressions.head, NumericEvalContext(EvalMode.TRY))
     } else {
       throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(1, 2), numArgs)
     }
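The user-visible contract is unchanged by the refactor: with evalContext.evalMode == EvalMode.TRY, Sum tracks integral overflow through the isEmpty buffer and returns NULL instead of raising, while the ANSI path keeps the query context captured by initQueryContext() for its error message. A quick spark-shell check of the long-overflow behavior:

// Long.MaxValue + 1 overflows; try_sum yields NULL rather than ARITHMETIC_OVERFLOW.
spark.sql("SELECT try_sum(v) FROM VALUES (9223372036854775807L), (1L) AS t(v)").show()

// Plain sum under ANSI mode throws instead, pointing back at the captured context:
spark.conf.set("spark.sql.ansi.enabled", "true")
// spark.sql("SELECT sum(v) FROM VALUES (9223372036854775807L), (1L) AS t(v)").show()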