diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
index c3d831482704b..592ae04a055d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.command
 
 import java.net.URI
 
-import scala.collection.mutable
-
 import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.SparkContext
@@ -29,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, LogicalPlan, UnaryCommand}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
-import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, PartitionTaskStats}
+import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.SerializableConfiguration
@@ -55,12 +53,10 @@ trait DataWritingCommand extends UnaryCommand with CTEInChildren {
     DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames)
 
   lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
-  lazy val partitionMetrics: mutable.Map[String, PartitionTaskStats] =
-    BasicWriteJobStatsTracker.partitionMetrics
 
   def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = {
     val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
-    new BasicWriteJobStatsTracker(serializableHadoopConf, metrics, partitionMetrics)
+    new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
   }
 
   def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
index 1da22957359fc..62158eb0b8193 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
@@ -229,16 +229,15 @@ class BasicWriteTaskStatsTracker(
 class BasicWriteJobStatsTracker(
     serializableHadoopConf: SerializableConfiguration,
     @transient val driverSideMetrics: Map[String, SQLMetric],
-    @transient val driverSidePartitionMetrics: mutable.Map[String, PartitionTaskStats],
     taskCommitTimeMetric: SQLMetric)
   extends WriteJobStatsTracker {
 
+  val partitionMetrics: PartitionMetricsWriteInfo = new PartitionMetricsWriteInfo()
+
   def this(
       serializableHadoopConf: SerializableConfiguration,
-      metrics: Map[String, SQLMetric],
-      partitionMetrics: mutable.Map[String, PartitionTaskStats]) = {
-    this(serializableHadoopConf, metrics - TASK_COMMIT_TIME, partitionMetrics,
-      metrics(TASK_COMMIT_TIME))
+      metrics: Map[String, SQLMetric]) = {
+    this(serializableHadoopConf, metrics - TASK_COMMIT_TIME, metrics(TASK_COMMIT_TIME))
   }
 
   override def newTaskInstance(): WriteTaskStatsTracker = {
@@ -265,12 +264,7 @@ class BasicWriteJobStatsTracker(
       // Check if we know the mapping of the internal row to a partition path
       if (partitionsMap.contains(s._1)) {
         val path = partitionsMap(s._1)
-        val current = partitionMetrics(path)
-        driverSidePartitionMetrics(path) = BasicWritePartitionTaskStats(
-          current.numFiles + s._2.numFiles,
-          current.numBytes + s._2.numBytes,
-          current.numRows + s._2.numRows
-        )
+        partitionMetrics.update(path, s._2.numBytes, s._2.numRows, s._2.numFiles)
       }
     })
   }
@@ -284,14 +278,8 @@ class BasicWriteJobStatsTracker(
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, driverSideMetrics.values.toList)
 
-    val partitionMetricsWriteInfo = new PartitionMetricsWriteInfo()
-    driverSidePartitionMetrics.foreach(entry => {
-      val key = entry._1
-      val value = entry._2
-      partitionMetricsWriteInfo.update(key, value.numBytes, value.numRows, value.numFiles)
-    })
     SQLPartitionMetrics.postDriverMetricUpdates(sparkContext, executionId,
-      partitionMetricsWriteInfo)
+      partitionMetrics)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala
index 2f9bdd73f3812..7a060a9bc8fe1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala
@@ -128,8 +128,7 @@ trait FileWrite extends Write {
     val partitionMetrics: mutable.Map[String, PartitionTaskStats] =
      BasicWriteJobStatsTracker.partitionMetrics
     val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
-    val statsTracker = new BasicWriteJobStatsTracker(serializableHadoopConf, metrics,
-      partitionMetrics)
+    val statsTracker = new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
     // TODO: after partitioning is supported in V2:
     //       1. filter out partition columns in `dataColumns`.
     //       2. Don't use Seq.empty for `partitionColumns`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index a493c08844992..ea8db3c99de92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -140,8 +140,7 @@ class FileStreamSink(
 
   private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = {
     val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
-    new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics,
-      BasicWriteJobStatsTracker.partitionMetrics)
+    new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics)
   }
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
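
For context, a minimal standalone sketch of the accumulation pattern this patch moves to: the job-level tracker owns a PartitionMetricsWriteInfo-style accumulator and folds each task's per-partition stats into it via update(path, bytes, rows, files), instead of mutating a shared mutable.Map handed in by every caller. The class and method bodies below are assumptions for illustration only, not the actual Spark PartitionMetricsWriteInfo implementation.

// Illustrative sketch only: simplified stand-ins for PartitionMetricsWriteInfo and the
// tracker that owns it. The update(path, bytes, rows, files) shape mirrors the diff;
// everything else here is assumed, not taken from Spark.
import scala.collection.mutable

final case class PartitionStats(numBytes: Long, numRows: Long, numFiles: Int)

final class PartitionMetricsAccumulator {
  private val stats = mutable.Map.empty[String, PartitionStats]

  // Fold one task's contribution for a partition path into the running totals.
  def update(path: String, bytes: Long, rows: Long, files: Int): Unit = synchronized {
    val cur = stats.getOrElse(path, PartitionStats(0L, 0L, 0))
    stats(path) = PartitionStats(cur.numBytes + bytes, cur.numRows + rows, cur.numFiles + files)
  }

  def snapshot: Map[String, PartitionStats] = synchronized(stats.toMap)
}

object PartitionMetricsSketch extends App {
  // The tracker owns its accumulator, as BasicWriteJobStatsTracker does after this patch.
  val partitionMetrics = new PartitionMetricsAccumulator

  // Analogue of processStats: each task's per-partition stats are merged on the driver.
  partitionMetrics.update("ds=2024-01-01", bytes = 1024L, rows = 10L, files = 1)
  partitionMetrics.update("ds=2024-01-01", bytes = 2048L, rows = 20L, files = 1)

  println(partitionMetrics.snapshot("ds=2024-01-01")) // PartitionStats(3072,30,2)
}

Owning the accumulator inside BasicWriteJobStatsTracker is what lets processStats call partitionMetrics.update directly and lets the final step pass partitionMetrics straight to SQLPartitionMetrics.postDriverMetricUpdates, dropping the extra constructor parameter and the driver-side copy loop.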