diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
index a56007f5d5d95..c9de8c7e1a9d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.command
 
 import org.apache.hadoop.conf.Configuration
 
+import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
-import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.util.SerializableConfiguration
 
 /**
@@ -73,4 +74,26 @@ object DataWritingCommand {
       attr.withName(outputName)
     }
   }
+
+  /**
+   * When executing CTAS operators, Spark uses [[InsertIntoHadoopFsRelationCommand]] or
+   * [[InsertIntoHiveTable]] to write the data. Both commands inherit their metrics from
+   * [[DataWritingCommand]], but after they run, only their own metrics are updated, through
+   * [[BasicWriteJobStatsTracker]]. We therefore also need to propagate those metrics to the
+   * command that actually invokes [[InsertIntoHadoopFsRelationCommand]] or
+   * [[InsertIntoHiveTable]].
+   *
+   * @param sparkContext Current SparkContext.
+   * @param command Command that actually writes the data.
+   * @param metrics Metrics of the wrapping DataWritingCommand, to be updated and posted.
+   */
+  def propogateMetrics(
+      sparkContext: SparkContext,
+      command: DataWritingCommand,
+      metrics: Map[String, SQLMetric]): Unit = {
+    command.metrics.foreach { case (key, metric) => metrics(key).set(metric.value) }
+    SQLMetrics.postDriverMetricUpdates(sparkContext,
+      sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY),
+      metrics.values.toSeq)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 995d6273ea588..bb54457afdc78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -220,7 +220,7 @@ case class CreateDataSourceTableAsSelectCommand(
       catalogTable = if (tableExists) Some(table) else None)
 
     try {
-      dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan)
+      dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan, metrics)
     } catch {
       case ex: AnalysisException =>
         logError(s"Failed to write to table ${table.identifier.unquotedString}", ex)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index ea0bc4fcd451b..b75152096c5b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
 import org.apache.spark.sql.internal.SQLConf
@@ -518,7 +519,8 @@ case class DataSource(
       mode: SaveMode,
       data: LogicalPlan,
       outputColumnNames: Seq[String],
-      physicalPlan: SparkPlan): BaseRelation = {
+      physicalPlan: SparkPlan,
+      metrics: Map[String, SQLMetric]): BaseRelation = {
     val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
     if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
       throw new AnalysisException("Cannot save interval data type into external storage.")
@@ -546,6 +548,7 @@ case class DataSource(
           partitionColumns = resolvedPartCols,
           outputColumnNames = outputColumnNames)
         resolved.run(sparkSession, physicalPlan)
+        DataWritingCommand.propogateMetrics(sparkSession.sparkContext, resolved, metrics)
         // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
         copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
       case _ =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 2b703c06fa900..dd99368e3a87b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 import org.apache.spark.sql.functions._
@@ -782,4 +783,20 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
       }
     }
   }
+
+  test("SPARK-34567: Add metrics for CTAS operator") {
+    withTable("t") {
+      val df = sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a")
+      val dataWritingCommandExec =
+        df.queryExecution.executedPlan.asInstanceOf[DataWritingCommandExec]
+      dataWritingCommandExec.executeCollect()
+      val createTableAsSelect = dataWritingCommandExec.cmd
+      assert(createTableAsSelect.metrics.contains("numFiles"))
+      assert(createTableAsSelect.metrics("numFiles").value == 1)
+      assert(createTableAsSelect.metrics.contains("numOutputBytes"))
+      assert(createTableAsSelect.metrics("numOutputBytes").value > 0)
+      assert(createTableAsSelect.metrics.contains("numOutputRows"))
+      assert(createTableAsSelect.metrics("numOutputRows").value == 1)
+    }
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index ccaa4502d9d2a..283c254b39602 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -55,6 +55,7 @@ trait CreateHiveTableAsSelectBase extends DataWritingCommand {
 
       val command = getWritingCommand(catalog, tableDesc, tableExists = true)
       command.run(sparkSession, child)
+      DataWritingCommand.propogateMetrics(sparkSession.sparkContext, command, metrics)
     } else {
       // TODO ideally, we should get the output data ready first and then
      // add the relation into catalog, just in case of failure occurs while data
@@ -69,6 +70,7 @@ trait CreateHiveTableAsSelectBase extends DataWritingCommand {
         val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier)
         val command = getWritingCommand(catalog, createdTableMeta, tableExists = false)
         command.run(sparkSession, child)
+        DataWritingCommand.propogateMetrics(sparkSession.sparkContext, command, metrics)
       } catch {
         case NonFatal(e) =>
           // drop the created table.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
index 4d6dafd598a2e..a2de43d737704 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.hive.execution
 
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.metric.SQLMetricsTestUtils
+import org.apache.spark.sql.hive.HiveUtils
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 
 // Disable AQE because metric info is different with AQE on/off
@@ -34,4 +36,29 @@ class SQLMetricsSuite extends SQLMetricsTestUtils with TestHiveSingleton
       testMetricsDynamicPartition("hive", "hive", "t1")
     }
   }
+
+  test("SPARK-34567: Add metrics for CTAS operator") {
+    Seq(false, true).foreach { canOptimized =>
+      withSQLConf(HiveUtils.CONVERT_METASTORE_CTAS.key -> canOptimized.toString) {
+        withTable("t") {
+          val df = sql(s"CREATE TABLE t STORED AS PARQUET AS SELECT 1 as a")
+          val dataWritingCommandExec =
+            df.queryExecution.executedPlan.asInstanceOf[DataWritingCommandExec]
+          dataWritingCommandExec.executeCollect()
+          val createTableAsSelect = dataWritingCommandExec.cmd
+          if (canOptimized) {
+            assert(createTableAsSelect.isInstanceOf[OptimizedCreateHiveTableAsSelectCommand])
+          } else {
+            assert(createTableAsSelect.isInstanceOf[CreateHiveTableAsSelectCommand])
+          }
+          assert(createTableAsSelect.metrics.contains("numFiles"))
+          assert(createTableAsSelect.metrics("numFiles").value == 1)
+          assert(createTableAsSelect.metrics.contains("numOutputBytes"))
+          assert(createTableAsSelect.metrics("numOutputBytes").value > 0)
+          assert(createTableAsSelect.metrics.contains("numOutputRows"))
+          assert(createTableAsSelect.metrics("numOutputRows").value == 1)
+        }
+      }
+    }
+  }
 }
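
For a quick manual check of the new behavior outside the test suites, the spark-shell sketch below mirrors the core SQLMetricsSuite test added above. It is only a sketch: it assumes a spark session built from a branch with this patch applied, that executedPlan for a CTAS query is still a DataWritingCommandExec (as it is in the test), and the table name t_demo is just an example.

import org.apache.spark.sql.execution.command.DataWritingCommandExec

// Run a simple CTAS and grab the physical write node, as the new test does.
val df = spark.sql("CREATE TABLE t_demo USING PARQUET AS SELECT 1 AS a")
val writeExec = df.queryExecution.executedPlan.asInstanceOf[DataWritingCommandExec]
writeExec.executeCollect()

// With propogateMetrics in place, the CTAS command itself now carries the write
// metrics copied from the underlying InsertIntoHadoopFsRelationCommand.
assert(writeExec.cmd.metrics("numFiles").value == 1)
assert(writeExec.cmd.metrics("numOutputRows").value == 1)
assert(writeExec.cmd.metrics("numOutputBytes").value > 0)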