From a07a4db2d63758f30dbdb014be6b438ef6f47e12 Mon Sep 17 00:00:00 2001
From: Steve Vaughan Jr
Date: Thu, 28 Mar 2024 12:41:57 -0400
Subject: [PATCH] test: Switch to checking the posted event
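
Rather than registering a QueryExecutionListener and capturing the V1WriteCommand
instances produced by a write, register a SparkListener and capture the
SparkListenerSQLPartitionMetrics events posted to the listener bus, then verify the
partition metrics carried by the posted event. PartitionMetricsWriteInfo is marked
Serializable so that it can be carried by the posted event.

The updated tests follow roughly this pattern (a sketch using the helpers added below;
the table setup and the write itself are elided):

    val listener = new CapturePartitionMetrics()
    withSparkListener(spark, listener) { _ =>
      // ... perform a dynamic-partition write of 40 partitions with 2 rows each ...
    }
    // The listener buffers every SparkListenerSQLPartitionMetrics event it observes,
    // so the per-partition statistics can be asserted on directly.
    assert(listener.events.length == 1)
    assert(listener.events.head.metrics.keySet.size == 40)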
---
.../write/PartitionMetricsWriteInfo.java | 3 +-
.../metric/SQLMetricsTestUtils.scala | 71 +++++++------------
2 files changed, 29 insertions(+), 45 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java
index 71dafb1847587..82455796444e1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PartitionMetricsWriteInfo.java
@@ -17,6 +17,7 @@
package org.apache.spark.sql.connector.write;
+import java.io.Serializable;
import java.util.Collections;
import java.util.Map;
import java.util.TreeMap;
@@ -27,7 +28,7 @@
* This is patterned after {@code org.apache.spark.util.AccumulatorV2}
*
*/
-public class PartitionMetricsWriteInfo {
+public class PartitionMetricsWriteInfo implements Serializable {
  private final Map<String, PartitionMetrics> metrics = new TreeMap<>();
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
index 3849b824143c0..00580c73c65b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
@@ -23,16 +23,14 @@ import scala.collection.mutable
import scala.collection.mutable.HashMap
import org.apache.spark.TestUtils
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerTaskEnd}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanInfo}
-import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
-import org.apache.spark.sql.execution.datasources.V1WriteCommand
+import org.apache.spark.sql.connector.write.SparkListenerSQLPartitionMetrics
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanInfo}
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore}
import org.apache.spark.sql.internal.SQLConf.WHOLESTAGE_CODEGEN_ENABLED
import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.util.QueryExecutionListener
trait SQLMetricsTestUtils extends SQLTestUtils {
@@ -108,40 +106,28 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
    assert(totalNumBytes > 0)
  }
-  private class CaptureWriteCommand extends QueryExecutionListener {
+  private class CapturePartitionMetrics extends SparkListener {
-    val v1WriteCommands: mutable.Buffer[V1WriteCommand] = mutable.Buffer[V1WriteCommand]()
+    val events: mutable.Buffer[SparkListenerSQLPartitionMetrics] =
+      mutable.Buffer[SparkListenerSQLPartitionMetrics]()
-    override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-      if (qe.executedPlan.isInstanceOf[ExecutedCommandExec] ||
-        qe.executedPlan.isInstanceOf[DataWritingCommandExec]) {
-        qe.optimizedPlan match {
-          case _: V1WriteCommand =>
-            val executedPlanCmd = qe.executedPlan.asInstanceOf[DataWritingCommandExec].cmd
-            v1WriteCommands += executedPlanCmd.asInstanceOf[V1WriteCommand]
-
-          // All other commands
-          case _ =>
-            logDebug(f"Query execution data is not currently supported for query: " +
-              f"${qe.toString} with plan class ${qe.executedPlan.getClass.getName} " +
-              f" and executed plan : ${qe.executedPlan}")
-        }
+    override def onOtherEvent(event: SparkListenerEvent): Unit = {
+      event match {
+        case metrics: SparkListenerSQLPartitionMetrics => events += metrics
+        case _ =>
      }
    }
-
-    override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
-
  }
-  protected def withQueryExecutionListener[L <: QueryExecutionListener]
+  protected def withSparkListener[L <: SparkListener]
      (spark: SparkSession, listener: L)
      (body: L => Unit): Unit = {
-    spark.listenerManager.register(listener)
+    spark.sparkContext.addSparkListener(listener)
    try {
      body(listener)
    }
    finally {
-      spark.listenerManager.unregister(listener)
+      spark.sparkContext.removeSparkListener(listener)
    }
  }
@@ -149,8 +135,8 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
  protected def testMetricsNonDynamicPartition(
      dataFormat: String,
      tableName: String): Unit = {
-    val listener = new CaptureWriteCommand()
-    withQueryExecutionListener(spark, listener) { _ =>
+    val listener = new CapturePartitionMetrics()
+    withSparkListener(spark, listener) { _ =>
      withTable(tableName) {
        Seq((1, 2)).toDF("i", "j")
          .write.format(dataFormat).mode("overwrite").saveAsTable(tableName)
@@ -167,19 +153,16 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
      }
    }
-    // Verify that there were 2 write command for the entire write process. This test creates the
-    // table and performs a repartitioning
-    assert(listener.v1WriteCommands.length == 2)
-    assert(listener.v1WriteCommands.forall(
-      v1WriteCommand => v1WriteCommand.partitionMetrics.isEmpty))
+    // Verify that there are no partition metrics for the entire write process.
+    assert(listener.events.isEmpty)
  }
  protected def testMetricsDynamicPartition(
      provider: String,
      dataFormat: String,
      tableName: String): Unit = {
-    val listener = new CaptureWriteCommand()
-    withQueryExecutionListener(spark, listener) { _ =>
+    val listener = new CapturePartitionMetrics()
+    withSparkListener(spark, listener) { _ =>
      withTable(tableName) {
        withTempPath { dir =>
          spark.sql(
@@ -208,18 +191,18 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
      }
    }
-    // Verify that there was a single write command for the entire write process
-    assert(listener.v1WriteCommands.length == 1)
-    val v1WriteCommand = listener.v1WriteCommands.head
+    // Verify that there is a single event for partition metrics for the entire write process.
+    // This test creates the table and performs a repartitioning, but only one action actually
+    // results in collecting partition metrics.
+    assert(listener.events.length == 1)
+    val event = listener.events.head
    // Verify the number of partitions
-    assert(v1WriteCommand.partitionMetrics.keySet.size == 40)
+    assert(event.metrics.keySet.size == 40)
    // Verify the number of files per partition
-    assert(v1WriteCommand.partitionMetrics.values.forall(
-      partitionStats => partitionStats.numFiles == 1))
+    event.metrics.values.forEach(partitionStats => assert(partitionStats.numFiles == 1))
    // Verify the number of rows per partition
-    assert(v1WriteCommand.partitionMetrics.values.forall(
-      partitionStats => partitionStats.numRows == 2))
+    event.metrics.values.forEach(partitionStats => assert(partitionStats.numRecords == 2))
  }
  /**