diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/BinaryArithmeticWithDatetimeResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/BinaryArithmeticWithDatetimeResolver.scala index 08407bbe96cc..cdfd942ca09a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/BinaryArithmeticWithDatetimeResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/BinaryArithmeticWithDatetimeResolver.scala @@ -63,7 +63,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.DAY object BinaryArithmeticWithDatetimeResolver { def resolve(expr: Expression): Expression = expr match { - case a @ Add(l, r, mode) => + case a @ Add(l, r, context) => (l.dataType, r.dataType) match { case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r)) case (DateType, _: DayTimeIntervalType) => TimestampAddInterval(Cast(l, TimestampType), r) @@ -83,7 +83,7 @@ object BinaryArithmeticWithDatetimeResolver { case (_: AnsiIntervalType, _: NullType) => a.copy(right = Cast(a.right, a.left.dataType)) case (DateType, CalendarIntervalType) => - DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI) + DateAddInterval(l, r, ansiEnabled = context.evalMode == EvalMode.ANSI) case (_: TimeType, _: DayTimeIntervalType) => TimeAddInterval(l, r) case (_: DatetimeType, _: NullType) => a.copy(right = Cast(a.right, DayTimeIntervalType.DEFAULT)) @@ -93,24 +93,26 @@ object BinaryArithmeticWithDatetimeResolver { case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(TimestampAddInterval(l, r), l.dataType) case (CalendarIntervalType, DateType) => - DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI) + DateAddInterval(r, l, ansiEnabled = context.evalMode == EvalMode.ANSI) case (CalendarIntervalType | _: DayTimeIntervalType, _) => Cast(TimestampAddInterval(r, l), r.dataType) case (DateType, dt) if dt != StringType => DateAdd(l, r) case (dt, DateType) if dt != StringType => DateAdd(r, l) case _ => a } - case s @ Subtract(l, r, mode) => + case s @ Subtract(l, r, context) => (l.dataType, r.dataType) match { case (DateType, DayTimeIntervalType(DAY, DAY)) => - DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI)) + DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), context.evalMode == EvalMode.ANSI)) case (DateType, _: DayTimeIntervalType) => DatetimeSub(l, r, - TimestampAddInterval(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI))) + TimestampAddInterval(Cast(l, TimestampType), + UnaryMinus(r, context.evalMode == EvalMode.ANSI))) case (DateType, _: YearMonthIntervalType) => - DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))) + DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, context.evalMode == EvalMode.ANSI))) case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) => - DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))) + DatetimeSub(l, r, TimestampAddYMInterval(l, + UnaryMinus(r, context.evalMode == EvalMode.ANSI))) case (CalendarIntervalType, CalendarIntervalType) | (_: DayTimeIntervalType, _: DayTimeIntervalType) => s @@ -124,15 +126,15 @@ object BinaryArithmeticWithDatetimeResolver { r, DateAddInterval( l, - UnaryMinus(r, mode == EvalMode.ANSI), - ansiEnabled = mode == EvalMode.ANSI + UnaryMinus(r, context.evalMode == EvalMode.ANSI), + ansiEnabled = context.evalMode == EvalMode.ANSI ) ) case (_: TimeType, _: DayTimeIntervalType) => - DatetimeSub(l, r, TimeAddInterval(l, UnaryMinus(r, 
mode == EvalMode.ANSI))) + DatetimeSub(l, r, TimeAddInterval(l, UnaryMinus(r, context.evalMode == EvalMode.ANSI))) case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(DatetimeSub(l, r, - TimestampAddInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType) + TimestampAddInterval(l, UnaryMinus(r, context.evalMode == EvalMode.ANSI))), l.dataType) case _ if AnyTimestampTypeExpression.unapply(l) || AnyTimestampTypeExpression.unapply(r) => @@ -142,19 +144,19 @@ object BinaryArithmeticWithDatetimeResolver { case (_: TimeType, _: TimeType) => SubtractTimes(l, r) case _ => s } - case m @ Multiply(l, r, mode) => + case m @ Multiply(l, r, context) => (l.dataType, r.dataType) match { - case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == EvalMode.ANSI) - case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == EvalMode.ANSI) + case (CalendarIntervalType, _) => MultiplyInterval(l, r, context.evalMode == EvalMode.ANSI) + case (_, CalendarIntervalType) => MultiplyInterval(r, l, context.evalMode == EvalMode.ANSI) case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r) case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l) case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r) case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l) case _ => m } - case d @ Divide(l, r, mode) => + case d @ Divide(l, r, context) => (l.dataType, r.dataType) match { - case (CalendarIntervalType, _) => DivideInterval(l, r, mode == EvalMode.ANSI) + case (CalendarIntervalType, _) => DivideInterval(l, r, context.evalMode == EvalMode.ANSI) case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r) case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r) case _ => d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f706741fc98c..b61f7ee0ee16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -1423,13 +1423,13 @@ trait CommutativeExpression extends Expression { protected def buildCanonicalizedPlan( collectOperands: PartialFunction[Expression, Seq[Expression]], buildBinaryOp: (Expression, Expression) => Expression, - evalMode: Option[EvalMode.Value] = None): Expression = { + evalContext: Option[NumericEvalContext] = None): Expression = { val operands = orderCommutative(collectOperands) val reorderResult = if (operands.length < SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)) { operands.reduce(buildBinaryOp) } else { - MultiCommutativeOp(operands, this.getClass, evalMode)(this) + MultiCommutativeOp(operands, this.getClass, evalContext)(this) } reorderResult } @@ -1446,7 +1446,7 @@ trait CommutativeExpression extends Expression { * Add, Multiply, And, Or, BitwiseAnd, BitwiseOr, BitwiseXor. * @param operands A sequence of operands that produces a commutative expression tree. * @param opCls The class of the root operator of the expression tree. - * @param evalMode The optional expression evaluation mode. + * @param evalContext The optional expression evaluation context. * @param originalRoot Root operator of the commutative expression tree before canonicalization. * This object reference is used to deduce the return dataType of Add and * Multiply operations when the input datatype is decimal. 
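Why the canonicalization key is the whole context and not just the eval mode: two Add trees can share an evalMode yet differ in allowDecimalPrecisionLoss, and flattening their operands into one commutative group could change the derived decimal type. A minimal, self-contained sketch of that grouping rule (Ctx, Node, Leaf, AddNode and operands are illustrative names, not Spark's API):

object CanonicalizeSketch {
  // Mirrors the shape of NumericEvalContext: both config-derived settings.
  case class Ctx(ansi: Boolean, allowDecimalPrecisionLoss: Boolean)

  sealed trait Node
  case class Leaf(name: String) extends Node
  case class AddNode(l: Node, r: Node, ctx: Ctx) extends Node

  // Flatten a chain of additions, but only across nodes that share the root's
  // context; a nested addition built under a different context stays opaque.
  def operands(root: AddNode): Seq[Node] = {
    def go(n: Node): Seq[Node] = n match {
      case AddNode(l, r, c) if c == root.ctx => go(l) ++ go(r)
      case other => Seq(other)
    }
    go(root)
  }

  def main(args: Array[String]): Unit = {
    val loss = Ctx(ansi = false, allowDecimalPrecisionLoss = true)
    val noLoss = loss.copy(allowDecimalPrecisionLoss = false)
    val tree = AddNode(AddNode(Leaf("a"), Leaf("b"), noLoss), Leaf("c"), loss)
    // Prints List(AddNode(Leaf(a),Leaf(b),Ctx(false,false)), Leaf(c)):
    // the inner addition is not merged because its context differs.
    println(operands(tree))
  }
}

This is the same reason the pattern guards in Add.canonicalized and Multiply.canonicalized compare the full evalContext before collecting operands.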
@@ -1454,7 +1454,7 @@ trait CommutativeExpression extends Expression { case class MultiCommutativeOp( operands: Seq[Expression], opCls: Class[_], - evalMode: Option[EvalMode.Value])(originalRoot: Expression) extends Unevaluable { + evalContext: Option[NumericEvalContext])(originalRoot: Expression) extends Unevaluable { // Helper method to deduce the data type of a single operation. private def singleOpDataType(lType: DataType, rType: DataType): DataType = { originalRoot match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/NumericEvalContext.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/NumericEvalContext.scala new file mode 100644 index 000000000000..28ec58d42475 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/NumericEvalContext.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.internal.SQLConf + +/** + * Encapsulates the evaluation context for expressions, capturing SQL configuration + * state at expression construction time. + * + * This context must be stored as part of the expression's state to ensure deterministic + * evaluation. Without it, copying an expression or evaluating it in a different context + * (e.g., inside a view) could produce different results due to changed SQL configuration + * values. + * + * @param evalMode The error handling mode (LEGACY, ANSI, or TRY) that determines + * overflow behavior and exception handling for operations like + * arithmetic and casts. + * @param allowDecimalPrecisionLoss Whether decimal operations are allowed to lose precision + * when the result type cannot represent the full precision. + * Corresponds to + * spark.sql.decimalOperations.allowPrecisionLoss. 
+ */
+case class NumericEvalContext private(
+    evalMode: EvalMode.Value,
+    allowDecimalPrecisionLoss: Boolean
+)
+
+object NumericEvalContext {
+
+  def apply(
+      evalMode: EvalMode.Value,
+      allowDecimalPrecisionLoss: Boolean = SQLConf.get.decimalOperationsAllowPrecisionLoss
+  ): NumericEvalContext = {
+    new NumericEvalContext(evalMode, allowDecimalPrecisionLoss)
+  }
+
+  def fromSQLConf(conf: SQLConf): NumericEvalContext = {
+    NumericEvalContext(
+      EvalMode.fromSQLConf(conf),
+      conf.decimalOperationsAllowPrecisionLoss)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index dfd41ad12a28..d066a87fc791 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -43,19 +43,19 @@ import org.apache.spark.sql.types._
   since = "1.0.0")
 case class Sum(
     child: Expression,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get))
+    evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get))
   extends DeclarativeAggregate
   with ImplicitCastInputTypes
   with UnaryLike[Expression]
   with SupportQueryContext {
 
-  def this(child: Expression) = this(child, EvalMode.fromSQLConf(SQLConf.get))
+  def this(child: Expression) = this(child, NumericEvalContext.fromSQLConf(SQLConf.get))
 
   private def shouldTrackIsEmpty: Boolean = resultType match {
     case _: DecimalType => true
     // For try_sum(), the result of following data types can be null on overflow.
     // Thus we need additional buffer to keep track of whether overflow happens.
-    case _: IntegralType | _: AnsiIntervalType if evalMode == EvalMode.TRY => true
+    case _: IntegralType | _: AnsiIntervalType if evalContext.evalMode == EvalMode.TRY => true
     case _ => false
   }
 
@@ -89,7 +89,7 @@ case class Sum(
 
   private def add(left: Expression, right: Expression): Expression = left.dataType match {
     case _: DecimalType => DecimalAddNoOverflowCheck(left, right, left.dataType)
-    case _ => Add(left, right, evalMode)
+    case _ => Add(left, right, evalContext)
   }
 
   override lazy val aggBufferAttributes = if (shouldTrackIsEmpty) {
@@ -176,7 +176,7 @@ case class Sum(
     resultType match {
       case d: DecimalType =>
         val checkOverflowInSum =
-          CheckOverflowInSum(sum, d, evalMode != EvalMode.ANSI, getContextOrNull())
+          CheckOverflowInSum(sum, d, evalContext.evalMode != EvalMode.ANSI, getContextOrNull())
         If(isEmpty, Literal.create(null, resultType), checkOverflowInSum)
       case _ if shouldTrackIsEmpty =>
         If(isEmpty, Literal.create(null, resultType), sum)
@@ -187,11 +187,12 @@ case class Sum(
   // The flag `evalMode` won't be shown in the `toString` or `toAggString` methods
   override def flatArguments: Iterator[Any] = Iterator(child)
 
-  override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) {
-    Some(origin.context)
-  } else {
-    None
-  }
+  override def initQueryContext(): Option[QueryContext] =
+    if (evalContext.evalMode == EvalMode.ANSI) {
+      Some(origin.context)
+    } else {
+      None
+    }
 
   override protected def withNewChildInternal(newChild: Expression): Expression =
     copy(child = newChild)
@@ -218,7 +219,7 @@ object TrySumExpressionBuilder extends ExpressionBuilder {
   override def build(funcName: String, expressions: Seq[Expression]): Expression = {
     val numArgs = expressions.length
     if (numArgs == 1) {
-      Sum(expressions.head, EvalMode.TRY)
+      Sum(expressions.head, 
NumericEvalContext(EvalMode.TRY)) } else { throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(1, 2), numArgs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 032e04ce2cdd..1c93a6586761 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -196,7 +196,9 @@ abstract class BinaryArithmetic extends BinaryOperator with SupportQueryContext override def contextIndependentFoldable: Boolean = left.contextIndependentFoldable && right.contextIndependentFoldable - protected val evalMode: EvalMode.Value + val evalContext: NumericEvalContext + + def evalMode: EvalMode.Value = evalContext.evalMode private lazy val internalDataType: DataType = (left.dataType, right.dataType) match { case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => @@ -224,7 +226,7 @@ abstract class BinaryArithmetic extends BinaryOperator with SupportQueryContext // When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale // needed are out of the range of available values, the scale is reduced up to 6, in order to // prevent the truncation of the integer part of the decimals. - protected def allowPrecisionLoss: Boolean = SQLConf.get.decimalOperationsAllowPrecisionLoss + protected def allowPrecisionLoss: Boolean = evalContext.allowDecimalPrecisionLoss protected def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { throw SparkException.internalError( @@ -405,11 +407,12 @@ object BinaryArithmetic { case class Add( left: Expression, right: Expression, - evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic + evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get)) + extends BinaryArithmetic with CommutativeExpression { def this(left: Expression, right: Expression) = - this(left, right, EvalMode.fromSQLConf(SQLConf.get)) + this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -465,9 +468,9 @@ case class Add( override lazy val canonicalized: Expression = { val reorderResult = buildCanonicalizedPlan( - { case Add(l, r, em) if em == evalMode => Seq(l, r) }, - { case (l: Expression, r: Expression) => Add(l, r, evalMode)}, - Some(evalMode) + { case Add(l, r, em) if em == evalContext => Seq(l, r) }, + { case (l: Expression, r: Expression) => Add(l, r, evalContext)}, + Some(evalContext) ) if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) { reorderResult @@ -479,6 +482,11 @@ case class Add( } } +object Add { + def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Add = + new Add(left, right, NumericEvalContext(evalMode)) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns `expr1`-`expr2`.", examples = """ @@ -491,10 +499,11 @@ case class Add( case class Subtract( left: Expression, right: Expression, - evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic { + evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get)) + extends BinaryArithmetic { def this(left: Expression, right: Expression) = - this(left, right, EvalMode.fromSQLConf(SQLConf.get)) + this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get)) override def 
inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -555,6 +564,11 @@ case class Subtract( newLeft: Expression, newRight: Expression): Subtract = copy(left = newLeft, right = newRight) } +object Subtract { + def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Subtract = + new Subtract(left, right, NumericEvalContext(evalMode)) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns `expr1`*`expr2`.", examples = """ @@ -567,11 +581,12 @@ case class Subtract( case class Multiply( left: Expression, right: Expression, - evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic + evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get)) + extends BinaryArithmetic with CommutativeExpression { def this(left: Expression, right: Expression) = - this(left, right, EvalMode.fromSQLConf(SQLConf.get)) + this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = NumericType @@ -620,13 +635,18 @@ case class Multiply( override lazy val canonicalized: Expression = { buildCanonicalizedPlan( - { case Multiply(l, r, em) if em == evalMode => Seq(l, r) }, - { case (l: Expression, r: Expression) => Multiply(l, r, evalMode) }, - Some(evalMode) + { case Multiply(l, r, ec) if ec == evalContext => Seq(l, r) }, + { case (l: Expression, r: Expression) => Multiply(l, r, evalContext) }, + Some(evalContext) ) } } +object Multiply { + def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Multiply = + new Multiply(left, right, NumericEvalContext(evalMode)) +} + // Common base trait for Divide and Remainder, since these two classes are almost identical trait DivModLike extends BinaryArithmetic { @@ -779,10 +799,11 @@ trait DivModLike extends BinaryArithmetic { case class Divide( left: Expression, right: Expression, - evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike { + evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get)) + extends DivModLike { def this(left: Expression, right: Expression) = - this(left, right, EvalMode.fromSQLConf(SQLConf.get)) + this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get)) // `try_divide` has exactly the same behavior as the legacy divide, so here it only executes // the error code path when `evalMode` is `ANSI`. @@ -834,6 +855,11 @@ case class Divide( newLeft: Expression, newRight: Expression): Divide = copy(left = newLeft, right = newRight) } +object Divide { + def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Divide = + new Divide(left, right, NumericEvalContext(evalMode)) +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Divide `expr1` by `expr2`. It returns NULL if an operand is NULL or `expr2` is 0. 
The result is casted to long.", @@ -850,10 +876,11 @@ case class Divide( case class IntegralDivide( left: Expression, right: Expression, - evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike { + evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get)) + extends DivModLike { def this(left: Expression, right: Expression) = this(left, right, - EvalMode.fromSQLConf(SQLConf.get)) + NumericEvalContext.fromSQLConf(SQLConf.get)) override def checkDivideOverflow: Boolean = left.dataType match { case LongType if failOnError => true @@ -912,6 +939,11 @@ case class IntegralDivide( copy(left = newLeft, right = newRight) } +object IntegralDivide { + def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): IntegralDivide = + new IntegralDivide(left, right, NumericEvalContext(evalMode)) +} + @ExpressionDescription( usage = "expr1 % expr2, or mod(expr1, expr2) - Returns the remainder after `expr1`/`expr2`.", examples = """ @@ -926,10 +958,11 @@ case class IntegralDivide( case class Remainder( left: Expression, right: Expression, - evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike { + evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get)) + extends DivModLike { def this(left: Expression, right: Expression) = - this(left, right, EvalMode.fromSQLConf(SQLConf.get)) + this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = NumericType @@ -994,6 +1027,11 @@ case class Remainder( newLeft: Expression, newRight: Expression): Remainder = copy(left = newLeft, right = newRight) } +object Remainder { + def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Remainder = + new Remainder(left, right, NumericEvalContext(evalMode)) +} + @ExpressionDescription( usage = "_FUNC_(expr1, expr2) - Returns the positive value of `expr1` mod `expr2`.", examples = """ @@ -1008,10 +1046,11 @@ case class Remainder( case class Pmod( left: Expression, right: Expression, - evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic { + evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get)) + extends BinaryArithmetic { def this(left: Expression, right: Expression) = - this(left, right, EvalMode.fromSQLConf(SQLConf.get)) + this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get)) override def toString: String = s"pmod($left, $right)" @@ -1199,6 +1238,11 @@ case class Pmod( copy(left = newLeft, right = newRight) } +object Pmod { + def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Pmod = + new Pmod(left, right, NumericEvalContext(evalMode)) +} + /** * A function that returns the least value of all parameters, skipping null values. * It takes at least 2 parameters, and returns null iff all parameters are null. 
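The companion apply(left, right, evalMode) overloads added above preserve the old three-argument call shape, while NumericEvalContext.apply's default argument snapshots allowDecimalPrecisionLoss from the active SQLConf at construction time. A REPL-style usage sketch under the signatures introduced in this patch (assumes an active session so SQLConf.get resolves):

import org.apache.spark.sql.catalyst.expressions.{Add, EvalMode, Literal, NumericEvalContext}

// Old-style call site: compiles unchanged via the new overload; the
// precision-loss flag is captured from SQLConf.get right now.
val byMode = Add(Literal(1), Literal(2), EvalMode.ANSI)

// Explicit context: pins both settings independently of the session config.
val byCtx = Add(Literal(1), Literal(2),
  NumericEvalContext(EvalMode.ANSI, allowDecimalPrecisionLoss = false))

// Because the context is part of the case class's state, copies made during
// analysis and optimization carry it along instead of re-reading the conf.
assert(byMode.evalContext.evalMode == EvalMode.ANSI)

The bitwise operators below keep a fixed LEGACY context instead, since bitwise arithmetic is affected by neither the eval mode nor the decimal precision config.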
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 26743ca6ff15..5fac0a93bf9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -39,7 +39,9 @@ import org.apache.spark.sql.types._ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic with CommutativeExpression { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + override def evalMode: EvalMode.Value = EvalMode.LEGACY + + override val evalContext: NumericEvalContext = NumericEvalContext(evalMode) override def inputType: AbstractDataType = IntegralType @@ -86,7 +88,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic with CommutativeExpression { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + override def evalMode: EvalMode.Value = EvalMode.LEGACY + + override val evalContext: NumericEvalContext = NumericEvalContext(evalMode) override def inputType: AbstractDataType = IntegralType @@ -133,7 +137,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic with CommutativeExpression { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + override def evalMode: EvalMode.Value = EvalMode.LEGACY + + override val evalContext: NumericEvalContext = NumericEvalContext(evalMode) override def inputType: AbstractDataType = IntegralType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 649ce2478825..72be3031ace6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -285,34 +285,35 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper val n1 = makeNum(p1, s1) val n2 = makeNum(p2, s2) - val mulActual = Multiply( - Literal(Decimal(BigDecimal(n1), p1, s1)), - Literal(Decimal(BigDecimal(n2), p2, s2)) - ) - val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2)) - - val divActual = Divide( - Literal(Decimal(BigDecimal(n1), p1, s1)), - Literal(Decimal(BigDecimal(n2), p2, s2)) - ) - val divExact = new java.math.BigDecimal(n1) - .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN) - - val remActual = Remainder( - Literal(Decimal(BigDecimal(n1), p1, s1)), - Literal(Decimal(BigDecimal(n2), p2, s2)) - ) - val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2)) - - val quotActual = IntegralDivide( - Literal(Decimal(BigDecimal(n1), p1, s1)), - Literal(Decimal(BigDecimal(n2), p2, s2)) - ) - val quotExact = - new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2)) - Seq(true, false).foreach { allowPrecLoss => withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) { + val mulActual = Multiply( + Literal(Decimal(BigDecimal(n1), p1, s1)), + 
Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2)) + + val divActual = Divide( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val divExact = new java.math.BigDecimal(n1) + .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN) + + val remActual = Remainder( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2)) + + val quotActual = IntegralDivide( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val quotExact = + new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2)) + + val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2) val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP)) val mulExpected = @@ -483,7 +484,11 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } test("Remainder/Pmod: exception should contain SQL text context") { - Seq(("%", Remainder), ("pmod", Pmod)).foreach { case (symbol, exprBuilder) => + type BinaryOpFn = (Expression, Expression, EvalMode.Value) => BinaryArithmetic + Seq[(String, BinaryOpFn)]( + ("%", Remainder.apply), + ("pmod", Pmod.apply) + ).foreach { case (symbol, exprBuilder) => val query = s"1L $symbol 0L" val o = Origin( line = Some(1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index c1c041509c35..6ee0029b6839 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -933,7 +933,8 @@ class BrokenColumnarAdd( left: ColumnarExpression, right: ColumnarExpression, failOnError: Boolean = false) - extends Add(left, right, EvalMode.fromBoolean(failOnError)) with ColumnarExpression { + extends Add(left, right, NumericEvalContext(EvalMode.fromBoolean(failOnError))) + with ColumnarExpression { override def supportsColumnar: Boolean = left.supportsColumnar && right.supportsColumnar diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala index 49997b5b0c18..592869968917 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala @@ -23,7 +23,7 @@ import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLo import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, EvalMode, Expression} +import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, EvalMode, Expression, NumericEvalContext} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.classic.ClassicConversions._ @@ -109,7 +109,7 @@ object V2FunctionBenchmark extends SqlBasedBenchmark { left: Expression, right: Expression, override val nullable: Boolean) extends BinaryArithmetic { - protected override val evalMode: EvalMode.Value = 
EvalMode.LEGACY
+    override val evalContext: NumericEvalContext = NumericEvalContext(EvalMode.LEGACY)
     override def inputType: AbstractDataType = NumericType
     override def symbol: String = "+"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
index f715353fd431..050a004a9353 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
@@ -1333,4 +1333,63 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-53968: reading the view after allowPrecisionLoss is changed") {
+    import org.apache.spark.sql.internal.SQLConf
+    val partsTableName = "parts_tbl"
+    val ordersTableName = "orders_tbl"
+    val viewName = "view_spark_53968"
+    withTable(partsTableName, ordersTableName) {
+      spark.sql(s"""CREATE TABLE $partsTableName (
+        | part_number STRING
+        |) USING PARQUET
+        |""".stripMargin)
+      spark.sql(s"INSERT INTO $partsTableName VALUES ('part1'), ('part2')")
+
+      spark.sql(s"""CREATE TABLE $ordersTableName
+        |USING PARQUET AS
+        |SELECT * FROM VALUES
+        |('part1', CAST(100 AS DECIMAL(38,18)), CAST(NULL AS DECIMAL(38,18))),
+        |('part2', CAST(100 AS DECIMAL(38,18)), CAST(0 AS DECIMAL(38,18))),
+        |('part3', CAST(200.23 AS DECIMAL(38,18)), CAST(100 AS DECIMAL(38,18)))
+        |AS t(part_number, unit_price, shipping_price);
+        |""".stripMargin)
+
+      Seq((true, false), (false, true)).foreach { case (oldValue, newValue) =>
+        withView(viewName) {
+          withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> oldValue.toString) {
+            spark.sql(s"""
+              |CREATE OR REPLACE VIEW $viewName AS
+              |WITH order_details AS (
+              |  SELECT
+              |    orders.part_number,
+              |    orders.unit_price
+              |    + COALESCE(orders.shipping_price, CAST(0 AS DECIMAL(38, 18)))
+              |    AS total_price
+              |  FROM $ordersTableName orders
+              |)
+              |SELECT
+              |  od.total_price
+              |FROM order_details od LEFT JOIN $partsTableName pt
+              |  ON pt.part_number = od.part_number
+              |ORDER BY od.total_price
+            """.stripMargin)
+
+            val expectedResults = Seq(
+              Row(BigDecimal("100.00000000000000000")),
+              Row(BigDecimal("100.00000000000000000")),
+              Row(BigDecimal("300.23000000000000000")))
+
+            checkAnswer(spark.sql(s"SELECT * FROM $viewName"), expectedResults)
+
+            // Re-run the query with the new value of the config; we should get the same result.
+            withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> newValue.toString) {
+
+              checkAnswer(spark.sql(s"SELECT * FROM $viewName"), expectedResults)
+            }
+          }
+        }
+      }
+    }
+  }
 }
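A note on what the new view test pins down: the decimal result type of the view's addition depends on spark.sql.decimalOperations.allowPrecisionLoss, so an expression that re-reads the live conf while deriving its type can change type between view creation and a later read. A self-contained sketch of the documented decimal-addition typing rule (it illustrates the behavior of DecimalType.adjustPrecisionScale and DecimalType.bounded; it is not the production code path):

object DecimalAddTypeSketch {
  val MAX = 38
  val MINIMUM_ADJUSTED_SCALE = 6

  // Result (precision, scale) of DECIMAL(p1,s1) + DECIMAL(p2,s2):
  // unbounded precision = max(p1-s1, p2-s2) + max(s1, s2) + 1, scale = max(s1, s2).
  def addResultType(p1: Int, s1: Int, p2: Int, s2: Int,
      allowPrecisionLoss: Boolean): (Int, Int) = {
    val scale = math.max(s1, s2)
    val precision = math.max(p1 - s1, p2 - s2) + scale + 1
    if (precision <= MAX) {
      (precision, scale)
    } else if (allowPrecisionLoss) {
      // Sacrifice scale to protect integral digits, keeping at least 6 fractional digits.
      val intDigits = precision - scale
      val minScale = math.min(scale, MINIMUM_ADJUSTED_SCALE)
      (MAX, math.max(MAX - intDigits, minScale))
    } else {
      // Simply cap both precision and scale.
      (math.min(precision, MAX), math.min(scale, MAX))
    }
  }

  def main(args: Array[String]): Unit = {
    println(addResultType(38, 18, 38, 18, allowPrecisionLoss = true))   // (38,17)
    println(addResultType(38, 18, 38, 18, allowPrecisionLoss = false))  // (38,18)
  }
}

For DECIMAL(38,18) + DECIMAL(38,18) this yields (38,17) with precision loss allowed and (38,18) without, which is the instability the patch removes by capturing allowDecimalPrecisionLoss in NumericEvalContext at construction time; the test's second checkAnswer asserts that the captured behavior survives a config flip.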