diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index f8d0df1b6e470..69c760b5a00b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -162,6 +162,9 @@ case class BroadcastHashJoinExec(
               // Anti Join: Drop the row on the streamed side if it is a match on the build
               hashed.get(lookupKey) == null
             }
+          }).map(row => {
+            numOutputRows += 1
+            row
           })
         }
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index f5cfbbf5a65ee..07b35713fe5b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, SQLHadoopMapReduceCommitProtocol}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -842,6 +842,29 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
       assert(createTableAsSelect.metrics("numOutputRows").value == 1)
     }
   }
+
+  test("SPARK-41003: BHJ LeftAnti does not update numOutputRows when codegen is disabled") {
+    Seq(true, false).foreach { enableWholeStage =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enableWholeStage.toString) {
+        withSQLConf(SQLConf.OPTIMIZE_NULL_AWARE_ANTI_JOIN.key -> "true") {
+          withTable("t1", "t2") {
+            spark.range(4).write.saveAsTable("t1")
+            spark.range(2).write.saveAsTable("t2")
+            val df = sql("SELECT * FROM t1 WHERE id NOT IN (SELECT id FROM t2)")
+            df.collect()
+            val plan = df.queryExecution.executedPlan
+
+            val joins = plan.collect {
+              case s: BroadcastHashJoinExec => s
+            }
+
+            assert(joins.size === 1)
+            testMetricsInSparkPlanOperator(joins.head, Map("numOutputRows" -> 2))
+          }
+        }
+      }
+    }
+  }
 }
 
 case class CustomFileCommitProtocol(