From a07a4db2d63758f30dbdb014be6b438ef6f47e12 Mon Sep 17 00:00:00 2001
From: Steve Vaughan Jr
Date: Thu, 28 Mar 2024 12:41:57 -0400
Subject: [PATCH] test: Switch to check posted event

---
 .../write/PartitionMetricsWriteInfo.java |  3 +-
 .../metric/SQLMetricsTestUtils.scala     | 71 +++++++------------
 2 files changed, 29 insertions(+), 45 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java
index 71dafb1847587..82455796444e1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.connector.write;

+import java.io.Serializable;
 import java.util.Collections;
 import java.util.Map;
 import java.util.TreeMap;
@@ -27,7 +28,7 @@
  * This is patterned after {@code org.apache.spark.util.AccumulatorV2}
  * <p>
  */
-public class PartitionMetricsWriteInfo {
+public class PartitionMetricsWriteInfo implements Serializable {

   private final Map metrics = new TreeMap<>();

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
index 3849b824143c0..00580c73c65b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
@@ -23,16 +23,14 @@ import scala.collection.mutable
 import scala.collection.mutable.HashMap

 import org.apache.spark.TestUtils
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerTaskEnd}
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanInfo}
-import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
-import org.apache.spark.sql.execution.datasources.V1WriteCommand
+import org.apache.spark.sql.connector.write.SparkListenerSQLPartitionMetrics
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanInfo}
 import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore}
 import org.apache.spark.sql.internal.SQLConf.WHOLESTAGE_CODEGEN_ENABLED
 import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.util.QueryExecutionListener


 trait SQLMetricsTestUtils extends SQLTestUtils {
@@ -108,40 +106,28 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
     assert(totalNumBytes > 0)
   }

-  private class CaptureWriteCommand extends QueryExecutionListener {
+  private class CapturePartitionMetrics extends SparkListener {

-    val v1WriteCommands: mutable.Buffer[V1WriteCommand] = mutable.Buffer[V1WriteCommand]()
+    val events: mutable.Buffer[SparkListenerSQLPartitionMetrics] =
+      mutable.Buffer[SparkListenerSQLPartitionMetrics]()

-    override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-      if (qe.executedPlan.isInstanceOf[ExecutedCommandExec] ||
-        qe.executedPlan.isInstanceOf[DataWritingCommandExec]) {
-        qe.optimizedPlan match {
-          case _: V1WriteCommand =>
-            val executedPlanCmd = qe.executedPlan.asInstanceOf[DataWritingCommandExec].cmd
-            v1WriteCommands += executedPlanCmd.asInstanceOf[V1WriteCommand]
-
-          // All other commands
-          case _ =>
-            logDebug(f"Query execution data is not currently supported for query: " +
-              f"${qe.toString} with plan class ${qe.executedPlan.getClass.getName} " +
-              f" and executed plan : ${qe.executedPlan}")
-        }
+    override def onOtherEvent(event: SparkListenerEvent): Unit = {
+      event match {
+        case metrics: SparkListenerSQLPartitionMetrics => events += metrics
+        case _ =>
       }
     }
-
-    override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
-
   }


-  protected def withQueryExecutionListener[L <: QueryExecutionListener]
+  protected def withSparkListener[L <: SparkListener]
     (spark: SparkSession, listener: L)
     (body: L => Unit): Unit = {
-    spark.listenerManager.register(listener)
+    spark.sparkContext.addSparkListener(listener)
     try {
       body(listener)
     } finally {
-      spark.listenerManager.unregister(listener)
+      spark.sparkContext.removeSparkListener(listener)
     }
   }

@@ -149,8 +135,8 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
   protected def testMetricsNonDynamicPartition(
       dataFormat: String,
       tableName: String): Unit = {
-    val listener = new CaptureWriteCommand()
-    withQueryExecutionListener(spark, listener) { _ =>
+    val listener = new CapturePartitionMetrics()
+    withSparkListener(spark, listener) { _ =>
       withTable(tableName) {
         Seq((1, 2)).toDF("i", "j")
           .write.format(dataFormat).mode("overwrite").saveAsTable(tableName)
@@ -167,19 +153,16 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
       }
     }

-    // Verify that there were 2 write command for the entire write process. This test creates the
-    // table and performs a repartitioning
-    assert(listener.v1WriteCommands.length == 2)
-    assert(listener.v1WriteCommands.forall(
-      v1WriteCommand => v1WriteCommand.partitionMetrics.isEmpty))
+    // Verify that there are no partition metrics for the entire write process.
+    assert(listener.events.isEmpty)
   }

   protected def testMetricsDynamicPartition(
       provider: String,
       dataFormat: String,
       tableName: String): Unit = {
-    val listener = new CaptureWriteCommand()
-    withQueryExecutionListener(spark, listener) { _ =>
+    val listener = new CapturePartitionMetrics()
+    withSparkListener(spark, listener) { _ =>
       withTable(tableName) {
         withTempPath { dir =>
           spark.sql(
@@ -208,18 +191,18 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
       }
     }

-    // Verify that there was a single write command for the entire write process
-    assert(listener.v1WriteCommands.length == 1)
-    val v1WriteCommand = listener.v1WriteCommands.head
+    // Verify that there is a single event for partition metrics for the entire write process.
+    // This test creates the table and performs a repartitioning, but only one action actually
+    // results in collecting partition metrics.
+    assert(listener.events.length == 1)
+    val event = listener.events.head

     // Verify the number of partitions
-    assert(v1WriteCommand.partitionMetrics.keySet.size == 40)
+    assert(event.metrics.keySet.size == 40)
     // Verify the number of files per partition
-    assert(v1WriteCommand.partitionMetrics.values.forall(
-      partitionStats => partitionStats.numFiles == 1))
+    event.metrics.values.forEach(partitionStats => assert(partitionStats.numFiles == 1))
     // Verify the number of rows per partition
-    assert(v1WriteCommand.partitionMetrics.values.forall(
-      partitionStats => partitionStats.numRows == 2))
+    event.metrics.values.forEach(partitionStats => assert(partitionStats.numRecords == 2))
   }

 /**
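
For reviewers, a minimal sketch of how a concrete suite would exercise the updated helpers. It
mirrors the way the existing metrics suites call these trait methods; the suite name, test names,
table name "t1", and the parquet provider/format below are illustrative and are not part of this
patch. With this change the helpers assert on the posted SparkListenerSQLPartitionMetrics event
rather than on captured V1 write commands.

    import org.apache.spark.sql.execution.metric.SQLMetricsTestUtils
    import org.apache.spark.sql.test.SharedSparkSession

    // Hypothetical suite mixing in the patched trait.
    class ExamplePartitionMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils {

      test("writing data out metrics: parquet") {
        // Plain (non-partitioned) write: no partition-metrics event is expected.
        testMetricsNonDynamicPartition("parquet", "t1")
      }

      test("writing data out metrics with dynamic partition: parquet") {
        // Dynamic-partition write: expects one event describing 40 partitions,
        // each with 1 file and 2 rows.
        testMetricsDynamicPartition("parquet", "parquet", "t1")
      }
    }

Because the listener subscribes to the shared listener bus instead of the per-session
QueryExecutionListener manager, the helpers register and remove it around the write via
withSparkListener and inspect the collected events only after the block has completed.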