diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index a2daec0b1ade1..91c9457af7de3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.internal.SQLConf /** @@ -33,6 +34,15 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(bindReferences(expressions, inputSchema)) + private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled + private[this] lazy val runtime = + new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries) + private[this] val exprs = if (subExprEliminationEnabled) { + runtime.proxyExpressions(expressions) + } else { + expressions + } + private[this] val buffer = new Array[Any](expressions.size) override def initialize(partitionIndex: Int): Unit = { @@ -76,11 +86,15 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable }.toArray override def apply(input: InternalRow): InternalRow = { + if (subExprEliminationEnabled) { + runtime.setInput(input) + } + var i = 0 while (i < validExprs.length) { - val (expr, ordinal) = validExprs(i) + val (_, ordinal) = validExprs(i) // Store the result into buffer first, to make the projection atomic (needed by aggregation) - buffer(ordinal) = expr.eval(input) + buffer(ordinal) = exprs(ordinal).eval(input) i += 1 } i = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala index 70789dac1d87a..0e71892db666b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -30,6 +31,15 @@ import org.apache.spark.sql.types._ */ class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection { + private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled + private[this] lazy val runtime = + new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries) + private[this] val exprs = if (subExprEliminationEnabled) { + runtime.proxyExpressions(expressions) + } else { + expressions + } + private[this] val mutableRow = new SpecificInternalRow(expressions.map(_.dataType)) private[this] val exprsWithWriters = expressions.zipWithIndex.filter { @@ -49,7 +59,7 @@ class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection } } } - (e, f) + (exprs(i), f) } private def generateSafeValueConverter(dt: DataType): Any => Any = dt match { @@ -97,6 +107,10 @@ class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection } override def apply(row: InternalRow): InternalRow = { + if (subExprEliminationEnabled) { + runtime.setInput(row) + } + var i = 0 while (i < exprsWithWriters.length) { val (expr, writer) = exprsWithWriters(i) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index c31310bc54023..8f030b45e5d3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -80,4 +80,50 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { assert(errMsg.contains("MutableProjection cannot use UnsafeRow for output data types:")) } } + + test("SPARK-33473: subexpression elimination for interpreted MutableProjection") { + Seq("true", "false").foreach { enabled => + withSQLConf( + SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> enabled, + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString) { + val one = BoundReference(0, DoubleType, true) + val two = BoundReference(1, DoubleType, true) + + val mul = Multiply(one, two) + val mul2 = Multiply(mul, mul) + val sqrt = Sqrt(mul2) + val sum = Add(mul2, sqrt) + + val proj = MutableProjection.create(Seq(sum)) + val result = (d1: Double, d2: Double) => + ((d1 * d2) * (d1 * d2)) + Math.sqrt((d1 * d2) * (d1 * d2)) + + val inputRows = Seq( + InternalRow.fromSeq(Seq(1.0, 2.0)), + InternalRow.fromSeq(Seq(2.0, 3.0)), + InternalRow.fromSeq(Seq(1.0, null)), + InternalRow.fromSeq(Seq(null, 2.0)), + InternalRow.fromSeq(Seq(3.0, 4.0)), + InternalRow.fromSeq(Seq(null, null)) + ) + val expectedResults = Seq( + result(1.0, 2.0), + result(2.0, 3.0), + null, + null, + result(3.0, 4.0), + null + ) + + inputRows.zip(expectedResults).foreach { case (inputRow, expected) => + val projRow = proj.apply(inputRow) + if (expected != null) { + assert(projRow.getDouble(0) == expected) + } else { + assert(projRow.isNullAt(0)) + } + } + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 4c9bcfe8f93a6..180665e653727 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -23,13 +23,14 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** * A test suite for generated projections */ -class GeneratedProjectionSuite extends SparkFunSuite { +class GeneratedProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { test("generated projections on wider table") { val N = 1000 @@ -246,4 +247,50 @@ class GeneratedProjectionSuite extends SparkFunSuite { val row2 = mutableProj(result) assert(result === row2) } + + test("SPARK-33473: subexpression elimination for interpreted SafeProjection") { + Seq("true", "false").foreach { enabled => + withSQLConf( + SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> enabled, + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString) { + val one = BoundReference(0, DoubleType, true) + val two = BoundReference(1, DoubleType, true) + + val mul = Multiply(one, two) + val mul2 = Multiply(mul, mul) + val sqrt = Sqrt(mul2) + val sum = Add(mul2, sqrt) + + val proj = SafeProjection.create(Seq(sum)) + val result = (d1: Double, d2: Double) => + ((d1 * d2) * (d1 * d2)) + Math.sqrt((d1 * d2) * (d1 * d2)) + + val inputRows = Seq( + InternalRow.fromSeq(Seq(1.0, 2.0)), + InternalRow.fromSeq(Seq(2.0, 3.0)), + InternalRow.fromSeq(Seq(1.0, null)), + InternalRow.fromSeq(Seq(null, 2.0)), + InternalRow.fromSeq(Seq(3.0, 4.0)), + InternalRow.fromSeq(Seq(null, null)) + ) + val expectedResults = Seq( + result(1.0, 2.0), + result(2.0, 3.0), + null, + null, + result(3.0, 4.0), + null + ) + + inputRows.zip(expectedResults).foreach { case (inputRow, expected) => + val projRow = proj.apply(inputRow) + if (expected != null) { + assert(projRow.getDouble(0) == expected) + } else { + assert(projRow.isNullAt(0)) + } + } + } + } + } }